diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index d9a684183cabd0..7da4e49de98c9a 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -686,6 +686,7 @@ def generate( do_condition_on_prev_tokens=do_condition_on_prev_tokens, is_shortform=is_shortform, batch_size=batch_size, + attention_mask=attention_mask, kwargs=kwargs, ) @@ -790,6 +791,7 @@ def generate_with_fallback( do_condition_on_prev_tokens, is_shortform, batch_size, + attention_mask, kwargs, ): kwargs = copy.copy(kwargs) @@ -837,6 +839,7 @@ def generate_with_fallback( prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, **generate_kwargs, )