From 9b19efac4dc63e447d4f71080190628f9a49c853 Mon Sep 17 00:00:00 2001 From: benniekiss <63211101+benniekiss@users.noreply.github.com> Date: Wed, 28 Aug 2024 03:53:58 -0400 Subject: [PATCH] [whisper] pass attention_mask to generate_with_fallback() (#33145) pass attention_mask to generate_with_fallback --- src/transformers/models/whisper/generation_whisper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index d9a684183cabd0..7da4e49de98c9a 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -686,6 +686,7 @@ def generate( do_condition_on_prev_tokens=do_condition_on_prev_tokens, is_shortform=is_shortform, batch_size=batch_size, + attention_mask=attention_mask, kwargs=kwargs, ) @@ -790,6 +791,7 @@ def generate_with_fallback( do_condition_on_prev_tokens, is_shortform, batch_size, + attention_mask, kwargs, ): kwargs = copy.copy(kwargs) @@ -837,6 +839,7 @@ def generate_with_fallback( prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, **generate_kwargs, )