From 8f8e6ca3bbdbcbcd4d903257cd25e5ad7d328300 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 29 Jul 2024 17:58:00 +0200 Subject: [PATCH] typo --- tests/onnxruntime/test_modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 99e21da562a..65fa4197f80 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -4491,11 +4491,11 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual( outputs_model_with_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 2, + self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, ) self.assertEqual( outputs_model_without_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 2, + self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, ) self.GENERATION_LENGTH = generation_length