Fix flex_attention in training mode (#35605)
* fix flex

* add test

* style
Cyrilvallez authored and ArthurZucker committed Jan 10, 2025
1 parent 7cf6230 commit 59e28c3
Showing 2 changed files with 14 additions and 1 deletion.
src/transformers/integrations/flex_attention.py (1 addition, 1 deletion)
@@ -27,7 +27,7 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         if softcap is not None:
             score = softcap * torch.tanh(score / softcap)
         if causal_mask is not None:
-            score += causal_mask[b][0][q_idx][kv_idx]
+            score = score + causal_mask[b][0][q_idx][kv_idx]
         return score
 
     attn_output, attention_weights = flex_attention(
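Why the out-of-place form matters: flex_attention runs the score_mod callback under vmap/tracing, and once gradients are enabled (i.e. in training mode) the in-place update "score += ..." can raise an error, whereas "score = score + ..." stays traceable and differentiable. The following is a minimal sketch of the fixed pattern outside of transformers; the tensor shapes, the all-zeros stand-in mask, and the CPU fallback are illustrative assumptions (the upstream test added below requires a GPU).

import torch
from torch.nn.attention.flex_attention import flex_attention

# Toy dimensions: batch, heads, sequence length, head dim (illustrative values).
B, H, S, D = 1, 2, 8, 16
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.randn(B, H, S, D, device=device, requires_grad=True)
k = torch.randn(B, H, S, D, device=device, requires_grad=True)
v = torch.randn(B, H, S, D, device=device, requires_grad=True)
causal_mask = torch.zeros(B, 1, S, S, device=device)  # stand-in for the additive attention mask
softcap = None

def causal_mod(score, b, h, q_idx, kv_idx):
    if softcap is not None:
        score = softcap * torch.tanh(score / softcap)
    if causal_mask is not None:
        # Out-of-place add (the fix); the in-place "score += ..." could fail once grads are enabled.
        score = score + causal_mask[b][0][q_idx][kv_idx]
    return score

# Forward pass with requires_grad inputs, mirroring what the new test below checks.
out = flex_attention(q, k, v, score_mod=causal_mod)
print(out.shape)  # (B, H, S, D)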
tests/test_modeling_common.py (13 additions, 0 deletions)
@@ -4790,6 +4790,19 @@ def test_forward_with_num_logits_to_keep(self):
         # Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
         self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
 
+    @require_torch_gpu
+    def test_flex_attention_with_grads(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flex_attn:
+                self.skipTest(reason="This model does not support flex attention")
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config._attn_implementation = "flex_attention"
+            model = model_class(config).to(device=torch_device, dtype=torch.float16)
+            self.assertTrue(model.config._attn_implementation == "flex_attention")
+
+            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
+            _ = model(inputs_dict["input_ids"].to(torch_device))
+
 
 global_rng = random.Random()

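For completeness, here is a rough sketch of how a user reaches this code path from the library side when fine-tuning with flex attention; the checkpoint name, dtype, and device handling are illustrative assumptions, not taken from this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"  # illustrative; any model class with _supports_flex_attn = True
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flex_attention",
    torch_dtype=torch.float16,
).to("cuda")
model.train()  # gradients enabled, so flex_attention_forward runs in training mode

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")

# Before this fix, the forward pass could raise because of the in-place add in the score_mod.
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()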
