
Commit

propagate
sanchit-gandhi committed Jul 22, 2024
1 parent a84474b commit bc88467
Showing 2 changed files with 3 additions and 6 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -301,7 +301,7 @@ def forward(
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

@@ -311,7 +311,6 @@ def forward(
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

@@ -817,7 +816,7 @@ def forward(
         key_states = self.k_proj(torch.cat([context, latents], dim=-2))
         value_states = self.v_proj(torch.cat([context, latents], dim=-2))

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

@@ -882,7 +881,6 @@ def forward(
         value_states = value_states.to(target_dtype)

         # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

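The change is the same in both attention blocks: query_states is viewed straight into the [batch, seq_len, num_heads, head_dim] layout that flash attention consumes, instead of being transposed to [batch, num_heads, seq_len, head_dim] and transposed back just before the kernel call. Only the query can skip the round trip here; per the TODO above, key/value still pass through the KV cache in the [batch, num_heads, seq_len, head_dim] layout. A minimal sketch (with hypothetical sizes, outside the model code) showing that the two paths yield the same tensor:

import torch

bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
hidden = torch.randn(bsz, q_len, num_heads * head_dim)

# Before: view to [b, h, s, d] for intermediate ops, then undo it for flash attention
before = hidden.view(bsz, q_len, num_heads, head_dim).transpose(1, 2).transpose(1, 2)

# After: keep the flash-attention layout [b, s, h, d] from the start
after = hidden.view(bsz, q_len, num_heads, head_dim)

assert torch.equal(before, after)  # identical values, one fewer pair of transposes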
3 changes: 1 addition & 2 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -406,7 +406,7 @@ def forward(
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

@@ -469,7 +469,6 @@ def forward(
         value_states = value_states.to(target_dtype)

         # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

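For reference, the layout expectation comes from the flash-attention kernel itself. A minimal sketch of a direct call, assuming the flash-attn package is installed (its flash_attn_func takes q/k/v shaped [batch, seq_len, num_heads, head_dim], in fp16/bf16, on GPU); the modeling code above reaches this kernel through transformers' internal wrapper rather than calling it directly:

import torch
from flash_attn import flash_attn_func  # requires the flash-attn package

bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
shape = (bsz, seq_len, num_heads, head_dim)
q = torch.randn(shape, dtype=torch.float16, device="cuda")
k = torch.randn(shape, dtype=torch.float16, device="cuda")
v = torch.randn(shape, dtype=torch.float16, device="cuda")

# No transposes needed: the kernel consumes [batch, seq_len, num_heads, head_dim]
# directly, which is the layout this commit now keeps for query_states.
out = flash_attn_func(q, k, v, causal=True)  # output has the same layout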
