Skip to content

Commit

Permalink
fix more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 18, 2024
1 parent de9a776 commit 0249b17
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
5 changes: 4 additions & 1 deletion optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,10 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
if decoder_w_p_model is not None:
decoder_w_p_model.request = decoder_w_p_model.request.request

datasets = [nncf.Dataset(encoder_calibration_data), nncf.Dataset(decoder_calibration_data),]
datasets = [
nncf.Dataset(encoder_calibration_data),
nncf.Dataset(decoder_calibration_data),
]
if decoder_w_p_model is not None:
datasets.append(nncf.Dataset(decoder_w_p_calibration_data))
return datasets
Expand Down
5 changes: 3 additions & 2 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,9 @@ def test_seq2seq_load_from_hub(self):
with TemporaryDirectory() as tmpdirname:
ov_exported_pipe.save_pretrained(tmpdirname)
folder_contents = os.listdir(tmpdirname)
self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
if not ov_exported_pipe.model.decoder.stateful:
self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
ov_exported_pipe = optimum_pipeline("text2text-generation", tmpdirname, accelerator="openvino")
self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)

Expand Down
2 changes: 1 addition & 1 deletion tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):

for inputs_dict in calibration_data:
for k, v in inputs_dict.items():
if k == "input_ids":
if k in ["input_ids", "beam_idx"]:
continue

x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()
Expand Down

0 comments on commit 0249b17

Please sign in to comment.