diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 09ce72c8b1b231..24e60eddba221b 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -629,11 +629,13 @@ def forward(
         if (
             self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
         ):  # efficient SDPA and no padding
-            attention_mask = attention_mask * torch.tril(
-                torch.ones_like(attention_mask), diagonal=-self.sliding_window
+            min_dtype = torch.finfo(hidden_states.dtype).min
+            sliding_window_mask = torch.tril(
+                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
             )
-            if attention_mask.shape[1] <= 1:  # when decoding
-                attention_mask = attention_mask[:, -self.sliding_window :]
+            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+            if attention_mask.shape[-1] <= 1:  # when decoding
+                attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states
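
For reference, below is a minimal, self-contained sketch of the masking scheme the added lines implement (not code from the repository; batch size, head count, sequence lengths, window size, and dtype are invented for illustration). The 4-D additive attention mask holds 0.0 at allowed positions and the dtype's minimum value at blocked ones, and positions further in the past than the sliding window are blocked by writing min_dtype into them with torch.where rather than by multiplying the mask.

# Standalone illustration of the new sliding-window masking
# (hypothetical shapes and window size; not part of the patch).
import torch

batch, num_heads, q_len, kv_len = 1, 1, 6, 6
sliding_window = 3
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# 4-D additive causal mask: 0.0 where attention is allowed,
# min_dtype at blocked (future) positions.
causal = torch.triu(torch.full((q_len, kv_len), min_dtype, dtype=dtype), diagonal=1)
attention_mask = causal[None, None, :, :].expand(batch, num_heads, q_len, kv_len).clone()

# True at key positions more than `sliding_window` tokens before each
# query position; these must be blocked as well.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)

# Write min_dtype into the blocked positions instead of multiplying the mask,
# preserving the additive (0.0 / min_dtype) convention.
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# Each query row now attends to at most `sliding_window` positions (itself included).
print((attention_mask[0, 0] == 0).int())

The trailing hunk lines make the decode-time slice consistent with this layout: the last `sliding_window` entries are taken along the final (key-length) dimension of the 4-D mask instead of along dimension 1.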