
Commit

Update src/transformers/models/llama/modeling_llama.py
ArthurZucker authored Feb 15, 2024
1 parent 9fbe901 commit 7afe7d9
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions src/transformers/models/llama/modeling_llama.py
@@ -432,11 +432,6 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails
-        if cache_position is not None:
-            key_states = key_states[:, :, : cache_position[-1] + 1, :]
-            value_states = value_states[:, :, : cache_position[-1] + 1, :]
-
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
         query_states = query_states.transpose(1, 2)
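For context, the deleted branch trimmed the statically allocated KV cache down to the slots written so far (everything up to cache_position[-1] + 1) before handing the tensors to Flash Attention 2. Below is a minimal standalone sketch of that slicing, assuming the usual [batch, num_kv_heads, cache_len, head_dim] layout and made-up shapes; only the names key_states, value_states, and cache_position come from the diff, the rest is illustrative.

import torch

# Toy shapes for illustration only; real values come from the model config.
batch, num_kv_heads, max_cache_len, head_dim = 1, 8, 16, 64

# A static cache is pre-allocated to max_cache_len; later slots are still empty padding.
key_states = torch.zeros(batch, num_kv_heads, max_cache_len, head_dim)
value_states = torch.zeros(batch, num_kv_heads, max_cache_len, head_dim)

# Absolute position(s) of the tokens in the current forward pass, e.g. decoding token 4.
cache_position = torch.tensor([3])

# The removed code sliced the padded cache down to the valid prefix before calling FA2:
key_states = key_states[:, :, : cache_position[-1] + 1, :]
value_states = value_states[:, :, : cache_position[-1] + 1, :]

print(key_states.shape)    # torch.Size([1, 8, 4, 64])
print(value_states.shape)  # torch.Size([1, 8, 4, 64])

With this commit the slicing is gone from the Flash Attention 2 path, so the key/value tensors are passed on at whatever length past_key_value.update returns.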
