[whisper] remove un-necessary transpose for fa2 attention
sanchit-gandhi committed Jul 2, 2024
1 parent a970195 commit 6528edb
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -402,7 +402,7 @@ def forward(
         bsz, tgt_len, _ = hidden_states.size()
 
         # get query proj
-        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
 
         if past_key_value is not None:
             is_updated = past_key_value.is_updated.get(self.layer_idx)
@@ -431,7 +431,6 @@ def forward(

         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
         # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
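For context, a minimal standalone sketch (illustrative shapes, not Whisper's real config; `q_proj_out` stands in for `self.q_proj(hidden_states)`) of why the query transpose becomes redundant: assuming `_shape` follows the usual view-then-transpose pattern to [batch_size, num_heads, sequence_length, head_dim], reshaping the projection output directly into the [batch_size, sequence_length, num_heads, head_dim] layout that Flash Attention 2 consumes makes the subsequent `transpose(1, 2)` a pure round trip. The key/value states still come out of the cache in [batch_size, num_heads, sequence_length, head_dim], which is why their transposes remain in the diff above.

# Illustrative sketch only: shows the old and new query layouts agree.
import torch

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
q_proj_out = torch.randn(bsz, tgt_len, num_heads * head_dim)  # stand-in for q_proj(hidden_states)

# Old path: _shape-style view to [bsz, num_heads, tgt_len, head_dim], then transpose back for FA2.
old_fa2 = q_proj_out.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2).transpose(1, 2)

# New path: reshape straight into the FA2 layout, no transpose needed.
new_fa2 = torch.reshape(q_proj_out, (bsz, tgt_len, num_heads, head_dim))

assert torch.equal(old_fa2, new_fa2)  # identical values, one fewer transpose
print(new_fa2.shape)  # torch.Size([2, 5, 4, 8])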
