diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 6111261830b8f0..5504caf1484139 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -247,9 +247,12 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
         return_lse=output_attentions,
     )
     if not output_attentions:
-        return attn_output, None
+        attn_weights = None
     else:
-        return attn_output[0], attn_output[1]
+        attn_output, attn_weights = attn_output
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights


 def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
@@ -280,6 +283,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
         is_causal=is_causal,
         scale=config.scaling,
     )
+    attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
@@ -362,7 +366,7 @@ def forward(
         if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
-            attention_type = "eager"
+            attention_type = "flex_attention"
         else:
             attention_type = self.config._attn_implementation
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 8d86238632365f..87e090aa8195cb 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -290,9 +290,12 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
         return_lse=output_attentions,
     )
     if not output_attentions:
-        return attn_output, None
+        attn_weights = None
     else:
-        return attn_output[0], attn_output[1]
+        attn_output, attn_weights = attn_output
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights


 def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
@@ -323,6 +326,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
         is_causal=is_causal,
         scale=config.scaling,
     )
+    attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
@@ -405,7 +409,7 @@ def forward(
         if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
-            attention_type = "eager"
+            attention_type = "flex_attention"
         else:
             attention_type = self.config._attn_implementation
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 06116c4dbafbd4..d65c961bcf65f3 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -385,7 +385,7 @@ def test_model_9b_bf16_flex_attention(self):
         model = AutoModelForCausalLM.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
         ).to(torch_device)
-
+        assert model.config._attn_implementation == "flex_attention"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
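
Reviewer note: below is a minimal, standalone sketch of the calling contract the patched helpers follow after this diff, assuming PyTorch >= 2.1 (for the scale keyword of scaled_dot_product_attention). The function name sdpa_attention_forward_sketch and the example tensors are illustrative, not part of the patch; the point is that inputs arrive as (batch, num_heads, seq_len, head_dim) and the helper hands back the output already transposed to (batch, seq_len, num_heads, head_dim), with attention weights as the second return value (None when they are not computed).

# Illustrative sketch only, not the patched library code: it mirrors the
# return-layout contract introduced by this diff for the attention helpers.
import torch
import torch.nn.functional as F


def sdpa_attention_forward_sketch(query, key, value, mask=None, scaling=None):
    # query/key/value arrive as (batch, num_heads, seq_len, head_dim),
    # as in the helpers touched by this diff.
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=mask, scale=scaling  # `scale` kwarg needs PyTorch >= 2.1
    )
    # The change in this diff: transpose back to (batch, seq_len, num_heads, head_dim)
    # inside the helper, so the caller no longer has to.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None  # second element is attn_weights (None for SDPA)


q = k = v = torch.randn(1, 8, 16, 64)  # (batch=1, heads=8, seq=16, head_dim=64)
out, weights = sdpa_attention_forward_sketch(q, k, v, scaling=64 ** -0.5)
assert out.shape == (1, 16, 8, 64) and weights is None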