From 1518508467d96b3866fc4ebcb7a5b3a2e0df2aa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 22 May 2024 15:07:51 +0200 Subject: [PATCH] Avoid extra chunk in speech recognition (#29539) --- .../pipelines/automatic_speech_recognition.py | 3 +-- .../test_pipelines_automatic_speech_recognition.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 123dbcdb67afd7..2e8682d96a65e0 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -67,8 +67,7 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, if dtype is not None: processed = processed.to(dtype=dtype) _stride_left = 0 if chunk_start_idx == 0 else stride_left - # all right strides must be full, otherwise it is the last item - is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len + is_last = chunk_end_idx >= inputs_len _stride_right = 0 if is_last else stride_right chunk_len = chunk.shape[0] diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index a1ab2947830ba9..bc619769e113f6 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -1569,10 +1569,10 @@ def test_chunk_iterator_stride(self): "input_values" ] outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10)) - self.assertEqual(len(outs), 2) - self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)]) - self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)]) - self.assertEqual([o["is_last"] for o in outs], [False, True]) + self.assertEqual(len(outs), 1) + self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)]) + self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)]) + self.assertEqual([o["is_last"] for o in outs], [True]) outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10)) self.assertEqual(len(outs), 2)