From 9f4de68fe2019fd50750593064ad410c680e9e3d Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Sun, 22 Dec 2024 13:23:51 +0100
Subject: [PATCH] fix

---
 src/transformers/integrations/flash_attention.py | 3 +++
 src/transformers/models/gpt2/modeling_gpt2.py    | 6 +++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index b8407bc29c6a8a..65674f0329ada1 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -44,6 +44,9 @@ def flash_attention_forward(
     else:
         target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
+    # FA2 always relies on the value set in the module, so remove it if present in kwargs
+    kwargs.pop("is_causal", None)
+
     attn_output = _flash_attention_forward(
         query,
         key,
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index ad53c7804ebeea..854c21576b5048 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -295,9 +295,9 @@ def forward(
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 
-        query_states = query_states.reshape(shape_q).transpose(1, 2)
-        key_states = key_states.reshape(shape_kv).transpose(1, 2)
-        value_states = value_states.reshape(shape_kv).transpose(1, 2)
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)
 
         if layer_past is not None:
             past_key, past_value = layer_past
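
Note (not part of the patch): a minimal standalone sketch of the `view` vs `reshape`
distinction the GPT-2 hunk relies on, using made-up shapes rather than the actual
model code. `Tensor.view` never copies and raises if the memory layout is
incompatible, while `Tensor.reshape` silently falls back to a copy; since the q/k/v
projections produce contiguous tensors, `view` works and makes the no-copy intent
explicit. The first change is simpler: `kwargs.pop("is_causal", None)` just drops the
key if a caller passed it, so it cannot conflict with the causal flag already set on
the module.

import torch

hidden = torch.randn(2, 5, 768)              # (batch, seq_len, hidden_size), contiguous
num_heads, head_dim = 12, 64

shape = (*hidden.shape[:-1], num_heads, head_dim)
split = hidden.view(shape).transpose(1, 2)   # (batch, num_heads, seq_len, head_dim)
assert split.data_ptr() == hidden.data_ptr() # same storage, no copy was made

t = hidden.transpose(1, 2)                   # non-contiguous view
try:
    t.view(-1)                               # raises: layout incompatible with view
except RuntimeError:
    pass                                     # t.reshape(-1) would have copied instead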