Skip to content

Commit

Permalink
allow unused input parameters passthrough when chunking in asr pipeli…
Browse files Browse the repository at this point in the history
…nes (#33889)

* allow unused parameter passthrough when chunking in asr pipelines

* format code

* format

* run fixup

* update tests

* update parameters to pipline in test

* updates parametrs in tests

* change spelling in gitignore

* revert .gitignore to main

* add git ignore of devcontainer folder

* assert asr output follows expected inference output type

* run fixup

* Remove .devcontainer from .gitignore

* remove compliance check
  • Loading branch information
VictorAtIfInsurance authored Nov 25, 2024
1 parent 4dc1a69 commit a0f4f31
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
for item in chunk_iter(
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
):
yield item
yield {**item, **extra}
else:
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
processed = self.feature_extractor(
Expand Down
19 changes: 19 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,25 @@ def test_chunking_fast(self):
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "ZBT ZC")

@require_torch
def test_input_parameter_passthrough(self):
"""Test that chunked vs non chunked versions of ASR pipelines returns the same structure for the same inputs."""
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="hf-internal-testing/tiny-random-wav2vec2",
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]

inputs = {"raw": audio, "sampling_rate": 16_000, "id": 1}

chunked_output = speech_recognizer(inputs.copy(), chunk_length_s=30)
non_chunked_output = speech_recognizer(inputs.copy())
assert (
chunked_output.keys() == non_chunked_output.keys()
), "The output structure should be the same for chunked vs non-chunked versions of asr pipelines."

@require_torch
def test_return_timestamps_ctc_fast(self):
speech_recognizer = pipeline(
Expand Down

0 comments on commit a0f4f31

Please sign in to comment.