From a0f4f3174f4aee87dd88ffda95579f7450934fc8 Mon Sep 17 00:00:00 2001
From: VictorAtIfInsurance <143422373+VictorAtIfInsurance@users.noreply.github.com>
Date: Mon, 25 Nov 2024 11:36:44 +0100
Subject: [PATCH] allow unused input parameters passthrough when chunking in asr pipelines (#33889)

* allow unused parameter passthrough when chunking in asr pipelines

* format code

* format

* run fixup

* update tests

* update parameters to pipline in test

* updates parametrs in tests

* change spelling in gitignore

* revert .gitignore to main

* add git ignore of devcontainer folder

* assert asr output follows expected inference output type

* run fixup

* Remove .devcontainer from .gitignore

* remove compliance check
---
 .../pipelines/automatic_speech_recognition.py |  2 +-
 ..._pipelines_automatic_speech_recognition.py | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index f4ffdf6445381c..09958b5fca195b 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -434,7 +434,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
             for item in chunk_iter(
                 inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
             ):
-                yield item
+                yield {**item, **extra}
         else:
             if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
                 processed = self.feature_extractor(
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index b21e8cd25f2408..e8cd8febca006e 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1443,6 +1443,25 @@ def test_chunking_fast(self):
         self.assertEqual(output, [{"text": ANY(str)}])
         self.assertEqual(output[0]["text"][:6], "ZBT ZC")
 
+    @require_torch
+    def test_input_parameter_passthrough(self):
+        """Test that chunked vs non chunked versions of ASR pipelines returns the same structure for the same inputs."""
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="hf-internal-testing/tiny-random-wav2vec2",
+        )
+
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        audio = ds[40]["audio"]["array"]
+
+        inputs = {"raw": audio, "sampling_rate": 16_000, "id": 1}
+
+        chunked_output = speech_recognizer(inputs.copy(), chunk_length_s=30)
+        non_chunked_output = speech_recognizer(inputs.copy())
+        assert (
+            chunked_output.keys() == non_chunked_output.keys()
+        ), "The output structure should be the same for chunked vs non-chunked versions of asr pipelines."
+
     @require_torch
     def test_return_timestamps_ctc_fast(self):
         speech_recognizer = pipeline(
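
A minimal usage sketch of the behavior this patch enables, mirroring the new test above (it assumes a transformers build that includes this change; the tiny model and dummy dataset are the same ones the test uses):

from datasets import load_dataset
from transformers import pipeline

asr = pipeline(
    task="automatic-speech-recognition",
    model="hf-internal-testing/tiny-random-wav2vec2",
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]

# "id" is not consumed by the feature extractor. Before this patch, the chunked
# path dropped such extra keys in preprocess, while the non-chunked path kept them.
inputs = {"raw": audio, "sampling_rate": 16_000, "id": 1}

chunked = asr(inputs.copy(), chunk_length_s=30)  # chunked path: preprocess now yields {**item, **extra}
non_chunked = asr(inputs.copy())                 # non-chunked path

# With the patch applied, both calls return the same output structure.
assert chunked.keys() == non_chunked.keys()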