[FlexAttention] Update gemma2 (#34942)
* update tests

* now maybe this fixes the previously failing tests!

* nit default

* Update src/transformers/models/gemma2/modular_gemma2.py

Co-authored-by: Anton Vlasjuk <[email protected]>

* fix-copies

---------

Co-authored-by: Anton Vlasjuk <[email protected]>
ArthurZucker and vasqu authored Nov 27, 2024
1 parent 6c3f168 commit 4c1388f
Showing 3 changed files with 15 additions and 7 deletions.
10 changes: 7 additions & 3 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -247,9 +247,12 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
         return_lse=output_attentions,
     )
     if not output_attentions:
-        return attn_output, None
+        attn_weights = None
     else:
-        return attn_output[0], attn_output[1]
+        attn_output, attn_weights = attn_output
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights
 
 
 def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
@@ -280,6 +283,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
         is_causal=is_causal,
         scale=config.scaling,
     )
+    attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
 
 
@@ -362,7 +366,7 @@ def forward(
 
         if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
-            attention_type = "eager"
+            attention_type = "flex_attention"
         else:
             attention_type = self.config._attn_implementation

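The transpose added in both attention paths standardizes the layout of the returned tensor: PyTorch's flex_attention and scaled_dot_product_attention produce outputs shaped (batch, num_heads, seq_len, head_dim), while the calling attention module (not shown in these hunks) flattens heads into the hidden size and therefore expects the sequence dimension ahead of the heads. A minimal, self-contained sketch of that shape change, with made-up sizes that are not taken from the diff:

import torch

# Illustrative sizes only; none of these numbers come from Gemma 2.
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

# SDPA / flex_attention return the output with heads ahead of the sequence dimension.
attn_output = torch.randn(batch, num_heads, seq_len, head_dim)

# The added line moves the sequence dimension forward so the caller can flatten
# the heads into the hidden size.
attn_output = attn_output.transpose(1, 2).contiguous()
print(attn_output.shape)                              # torch.Size([2, 16, 8, 64])
print(attn_output.reshape(batch, seq_len, -1).shape)  # torch.Size([2, 16, 512])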
10 changes: 7 additions & 3 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -290,9 +290,12 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
         return_lse=output_attentions,
     )
     if not output_attentions:
-        return attn_output, None
+        attn_weights = None
     else:
-        return attn_output[0], attn_output[1]
+        attn_output, attn_weights = attn_output
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights
 
 
 def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
@@ -323,6 +326,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
         is_causal=is_causal,
         scale=config.scaling,
     )
+    attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
 
 
@@ -405,7 +409,7 @@ def forward(
 
         if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
            logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
-            attention_type = "eager"
+            attention_type = "flex_attention"
         else:
             attention_type = self.config._attn_implementation

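The last hunk in each modeling file also makes the output_attentions fallback agree with its own warning: sdpa and flash_attention_2 now fall back to flex_attention rather than eager when attention weights are requested. A self-contained sketch of that decision; resolve_attention_type is a hypothetical helper name, not an identifier from the diff:

def resolve_attention_type(attn_implementation: str, output_attentions: bool) -> str:
    # Mirrors the updated branch: the fallback now matches the warning text.
    if output_attentions and attn_implementation in ["sdpa", "flash_attention_2"]:
        return "flex_attention"  # previously "eager"
    return attn_implementation


print(resolve_attention_type("sdpa", output_attentions=True))   # flex_attention
print(resolve_attention_type("sdpa", output_attentions=False))  # sdpa
print(resolve_attention_type("eager", output_attentions=True))  # eager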
2 changes: 1 addition & 1 deletion tests/models/gemma2/test_modeling_gemma2.py
@@ -385,7 +385,7 @@ def test_model_9b_bf16_flex_attention(self):
         model = AutoModelForCausalLM.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
         ).to(torch_device)
-
+        assert model.config._attn_implementation == "flex_attention"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

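The new assertion verifies that the requested implementation actually lands on the model config. A usage sketch along the same lines, assuming a GPU and a Gemma 2 checkpoint id (the test's model_id is not shown in this hunk), using only standard transformers APIs:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # assumed checkpoint; substitute the one you use
device = "cuda"                 # FlexAttention is exercised on GPU in the test

model = AutoModelForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(device)
assert model.config._attn_implementation == "flex_attention"

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["Hello I am doing"], return_tensors="pt", padding=True).to(device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(out, skip_special_tokens=True))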
