diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index f1d9cb9d000..9187b851fc0 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -4602,14 +4602,14 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): ) self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) - self.assertEqual( - outputs_model_with_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, - ) - self.assertEqual( - outputs_model_without_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, - ) + + if model_arch == "whisper" and check_if_transformers_greater("4.43"): + gen_length = self.GENERATION_LENGTH + 2 + else: + gen_length = self.GENERATION_LENGTH + 1 + + self.assertEqual(outputs_model_with_pkv.shape[1], gen_length) + self.assertEqual(outputs_model_without_pkv.shape[1], gen_length) self.GENERATION_LENGTH = generation_length if os.environ.get("TEST_LEVEL", 0) == "1":