From 2bb60982ac072566aee933d08840f15d801ee10b Mon Sep 17 00:00:00 2001
From: Taha Yassine <40228615+taha-yassine@users.noreply.github.com>
Date: Mon, 23 Dec 2024 13:45:55 +0100
Subject: [PATCH] Patch GPTNeoX to use adequate FA2 if position_ids is
 provided (#35318)

---
 src/transformers/models/gpt_neox/modeling_gpt_neox.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index f512938e75f9a7..98418cb02d65ba 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -148,6 +148,7 @@ def flash_attention_forward(
     norm_factor,
     attention_dropout,
     training,
+    position_ids=None,
     target_dtype=None,
     **_kwargs,
 ):
@@ -173,6 +174,7 @@ def flash_attention_forward(
         attention_mask,
         query_length,
         dropout=attention_dropout,
+        position_ids=position_ids,
         softmax_scale=norm_factor,
         is_causal=True,
         use_top_left_mask=flash_attn_uses_top_left_mask,
@@ -353,6 +355,7 @@ def forward(
             key,
             value,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             norm_factor=self.norm_factor,
             attention_dropout=self.config.attention_dropout,
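
Note (illustration, not part of the patch): forwarding position_ids into the FA2 path lets the shared _flash_attention_forward utility recognize packed rows (position ids that restart at 0 mid-sequence) and attend to each packed document separately instead of treating the whole row as one sequence. The sketch below shows how a caller might exercise this; the checkpoint name (EleutherAI/pythia-70m), the requirement of an installed flash-attn package, and the CUDA device are assumptions for the example, not details taken from the patch.

# Minimal sketch: two prompts packed into one row, no padding.
# Assumes flash-attn is installed and a CUDA GPU is available;
# the checkpoint name is only an example of a GPT-NeoX model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")

seq_a = tokenizer("First packed sequence.", return_tensors="pt").input_ids[0]
seq_b = tokenizer("Second one.", return_tensors="pt").input_ids[0]

# Pack both sequences into a single batch row; position_ids restart at 0 for
# the second sequence, which is the signal the FA2 utility uses to keep the
# two documents from attending to each other.
input_ids = torch.cat([seq_a, seq_b]).unsqueeze(0).to("cuda")
position_ids = torch.cat(
    [torch.arange(len(seq_a)), torch.arange(len(seq_b))]
).unsqueeze(0).to("cuda")

with torch.no_grad():
    out = model(input_ids=input_ids, position_ids=position_ids)
print(out.logits.shape)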