Patch GPTNeoX to use adequate FA2 if position_ids is provided (#35318)
taha-yassine authored Dec 23, 2024
1 parent 5e7aede commit 2bb6098
Showing 1 changed file with 3 additions and 0 deletions.
src/transformers/models/gpt_neox/modeling_gpt_neox.py: 3 additions & 0 deletions
@@ -148,6 +148,7 @@ def flash_attention_forward(
     norm_factor,
     attention_dropout,
     training,
+    position_ids=None,
     target_dtype=None,
     **_kwargs,
 ):
@@ -173,6 +174,7 @@ def flash_attention_forward(
         attention_mask,
         query_length,
         dropout=attention_dropout,
+        position_ids=position_ids,
         softmax_scale=norm_factor,
         is_causal=True,
         use_top_left_mask=flash_attn_uses_top_left_mask,
@@ -353,6 +355,7 @@ def forward(
             key,
             value,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             norm_factor=self.norm_factor,
             attention_dropout=self.config.attention_dropout,
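For context (not part of the commit): the point of threading position_ids through to the FA2 path is that flash-attention's variable-length kernels can infer packed-sequence boundaries from position_ids that restart at 0, rather than relying on a padded attention mask. Below is a minimal sketch of that idea, assuming the common packing convention where each packed sequence's positions start again at 0; the tensor values and the boundary-recovery code are illustrative, not the library's implementation.

import torch

# Two sequences of lengths 3 and 4 packed into a single batch row.
# Under the packing convention assumed here, position_ids restart at 0
# at every sequence boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])

# Recover cumulative sequence lengths (cu_seqlens) from the restarts;
# a varlen flash-attention call consumes boundaries like these instead
# of a padded attention mask.
starts = (position_ids[0] == 0).nonzero(as_tuple=True)[0]
lengths = torch.diff(starts, append=torch.tensor([position_ids.shape[1]]))
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)]).to(torch.int32)

print(cu_seqlens)  # tensor([0, 3, 7], dtype=torch.int32)

The three one-line additions in the diff simply forward a tensor the caller already has; without them, position_ids never reaches _flash_attention_forward and the FA2 path presumably has to fall back to mask-based handling for packed inputs.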
