From 7a7f27697ad17d4ff03dbe203095be8b71759b55 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Tue, 3 Dec 2024 13:56:59 +0100
Subject: [PATCH] Fix `BertGeneration` (#35043)

fix

Co-authored-by: ydshieh
---
 .../models/bert_generation/modeling_bert_generation.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index db4a378577562a..800ea2bef1d631 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -785,9 +785,7 @@ def forward(
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask = None
-        if not use_cache:
-            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
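
Note on the fix (not part of the patch): as the diff shows, the removed guard left
extended_attention_mask as None whenever use_cache was True, so the self-attention
mask was effectively dropped on that path; the fix computes it unconditionally. For
context, below is a minimal sketch of what a call like get_extended_attention_mask
produces for a 2D padding mask. The function name extend_attention_mask and the fp32
dtype are assumptions for illustration; this is not the library implementation.

import torch

def extend_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, sketching the usual 2D-mask extension:
    # [batch_size, seq_length] -> [batch_size, 1, 1, seq_length], so the
    # mask broadcasts across attention heads and query positions.
    extended = attention_mask[:, None, None, :].to(torch.float32)
    # 1 -> 0.0 (attend), 0 -> large negative (suppressed by the softmax).
    return (1.0 - extended) * torch.finfo(torch.float32).min

mask = torch.tensor([[1, 1, 1, 0]])       # last position is padding
print(extend_attention_mask(mask).shape)  # torch.Size([1, 1, 1, 4])

With the old guard, a cached generation step would skip this conversion entirely and
run self-attention with no additive mask, letting padded positions leak into the
attention scores; always building the extended mask restores correct masking.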