diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 99e21da562a..65fa4197f80 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -4491,11 +4491,11 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual( outputs_model_with_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 2, + self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, ) self.assertEqual( outputs_model_without_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 2, + self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, ) self.GENERATION_LENGTH = generation_length