diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index a3de765137b84d..26a5388a5cd08b 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -891,9 +891,13 @@ def generate_with_fallback( # remove all padding tokens if seek_sequence[-1] == generation_config.pad_token_id: num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - if return_token_timestamps and not is_shortform: - seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] + if is_shortform and generation_config.pad_token_id == generation_config.eos_token_id: + num_paddings -= 1 + + if num_paddings != 0: + seek_sequence = seek_sequence[:-num_paddings] + if return_token_timestamps and not is_shortform: + seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] # check which sequences in batch need fallback & which should be skipped needs_fallback[i], should_skip[i] = self._need_fallback(