From 6528edbdc2093ba629402281c3c08a1f52eb2bd6 Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Tue, 2 Jul 2024 15:16:32 +0100
Subject: [PATCH] [whisper] remove un-necessary transpose for fa2 attention

---
 src/transformers/models/whisper/modeling_whisper.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index f1467a55e03b9b..033b65e744dac2 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -402,7 +402,7 @@ def forward(
         bsz, tgt_len, _ = hidden_states.size()
 
         # get query proj
-        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
 
         if past_key_value is not None:
             is_updated = past_key_value.is_updated.get(self.layer_idx)
@@ -431,7 +431,6 @@ def forward(
 
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
         # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
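
The snippet below is a minimal, self-contained sketch (not the transformers code itself) of why this patch lets the later `query_states.transpose(1, 2)` be dropped: the old `self._shape` helper produced the [batch, heads, seq, head_dim] layout used by eager/SDPA attention, which then had to be transposed back into the [batch, seq, heads, head_dim] layout that Flash Attention 2 expects, whereas the new `torch.reshape` call produces that layout directly. The dimension values and the toy `q_proj` module are illustrative assumptions, not values from the model.

import torch

# Toy dimensions standing in for the attention module's attributes (illustrative only).
bsz, tgt_len, num_heads, head_dim = 2, 5, 8, 64
embed_dim = num_heads * head_dim

hidden_states = torch.randn(bsz, tgt_len, embed_dim)
q_proj = torch.nn.Linear(embed_dim, embed_dim)

# Old path: shape the projection to [bsz, num_heads, tgt_len, head_dim]
# (roughly what `self._shape` does), then transpose back to
# [bsz, tgt_len, num_heads, head_dim] before calling the flash kernel.
shaped = q_proj(hidden_states).view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2)
old_query_for_fa2 = shaped.transpose(1, 2)

# New path from the patch: reshape the projection straight into the FA2 layout,
# so the extra transpose of query_states is no longer needed.
new_query = torch.reshape(q_proj(hidden_states), (bsz, tgt_len, num_heads, head_dim))

assert torch.equal(old_query_for_fa2, new_query)
print(new_query.shape)  # torch.Size([2, 5, 8, 64])

Note that only the query projection can take this shortcut: key and value states still go through the KV cache, which stores them as [batch, heads, seq, head_dim], so their transposes remain (as the TODO in the diff points out).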