diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index aacd0fdad1ba3b..af66be643a58cb 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -306,7 +306,7 @@ def forward(
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
@@ -316,7 +316,6 @@ def forward(
 
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
@@ -927,7 +926,7 @@ def forward(
         key_states = self.k_proj(torch.cat([context, latents], dim=-2))
         value_states = self.v_proj(torch.cat([context, latents], dim=-2))
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
@@ -992,7 +991,6 @@ def forward(
             value_states = value_states.to(target_dtype)
 
         # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index f49f55f57797f8..398a3ee7ef46c9 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -423,7 +423,7 @@ def forward(
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
@@ -498,7 +498,6 @@ def forward(
             value_states = value_states.to(target_dtype)
 
         # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
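
The pattern is the same in both files: the two removed statements form an exact inverse pair. `query_states` was transposed to `[batch, heads, seq, dim]` and then transposed straight back to the `[batch, seq, heads, dim]` layout Flash Attention expects, so dropping both is safe. Keys and values keep their transposes because, per the TODO comment above, the KV cache still operates on the `[batch, heads, seq, dim]` layout. A minimal standalone sketch (not part of the patch; toy dimensions and the variable names `before`/`after` are illustrative) confirming the round trip is a no-op:

```python
# Sanity check that transpose(1, 2) applied twice is the identity, so the
# removed query transposes change neither the shape nor the values.
import torch

bsz, q_len, num_heads, head_dim = 2, 16, 4, 8  # assumed toy sizes
query_states = torch.randn(bsz, q_len, num_heads * head_dim)

# Before the patch: view, transpose to [bsz, heads, seq, dim], then transpose
# back to the [bsz, seq, heads, dim] layout Flash Attention expects.
before = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2).transpose(1, 2)

# After the patch: keep the original [bsz, seq, heads, dim] layout throughout.
after = query_states.view(bsz, q_len, num_heads, head_dim)

assert before.shape == after.shape == (bsz, q_len, num_heads, head_dim)
assert torch.equal(before, after)  # the transpose pair was a pure no-op
```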