Fix BertGeneration (#35043)
fix

Co-authored-by: ydshieh <[email protected]>
ydshieh authored Dec 3, 2024
1 parent 901f504 commit 7a7f276
Showing 1 changed file with 1 addition and 3 deletions.
@@ -785,9 +785,7 @@ def forward(

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask = None
-        if not use_cache:
-            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
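
For context on what the change fixes: before this commit, extended_attention_mask was left as None whenever use_cache was true, so the user-supplied attention_mask was effectively dropped during cached decoding; the fix computes the extended mask unconditionally. Below is a minimal sketch of what get_extended_attention_mask produces for a plain 2D mask (the real ModuleUtilsMixin method also handles 3D masks and builds a causal mask for decoders); the standalone function name here is invented for illustration.

import torch

def extend_attention_mask(attention_mask: torch.Tensor,
                          dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Sketch of the 2D case of get_extended_attention_mask:
    # [batch_size, seq_length] -> [batch_size, 1, 1, seq_length],
    # broadcastable across all attention heads, with 0.0 where attention
    # is allowed and a large negative value where it is masked, so the
    # result can be added directly to attention scores before softmax.
    extended = attention_mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])       # last position is padding
print(extend_attention_mask(mask).shape)  # torch.Size([1, 1, 1, 4])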
