From 6cd2c7b2d390be362570d25364dbd8415ee4d292 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 3 Oct 2024 16:01:39 +0200 Subject: [PATCH 1/3] Fix Whisper shortform --- src/transformers/models/whisper/generation_whisper.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 32e54e0a121d7b..2f19f67f51bd13 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -891,9 +891,13 @@ def generate_with_fallback( # remove all padding tokens if seek_sequence[-1] == generation_config.pad_token_id: num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - if return_token_timestamps and not is_shortform: - seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] + if is_shortform and generation_config.pad_token_id == generation_config.eos_token_id: + num_paddings -= 1 + + if num_paddings != 0: + seek_sequence = seek_sequence[:-num_paddings] + if return_token_timestamps and not is_shortform: + seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] # check which sequences in batch need fallback & which should be skipped needs_fallback[i], should_skip[i] = self._need_fallback( From a03ace5d3b8e1459723877ad2ad1ec805da1bc16 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 3 Oct 2024 16:09:34 +0200 Subject: [PATCH 2/3] [run-slow] whisper From fde551425c9914dda66dc522d5d889a4ec15b80e Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 3 Oct 2024 16:16:05 +0200 Subject: [PATCH 3/3] [run-slow] whisper