From 7afe7d93014a7db874f01ab8549e63de0fad6d98 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Thu, 15 Feb 2024 05:34:30 +0100
Subject: [PATCH] Update src/transformers/models/llama/modeling_llama.py

---
 src/transformers/models/llama/modeling_llama.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8b887200d22bb4..9d7f2f1b7a2f7e 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -432,11 +432,6 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails
-        if cache_position is not None:
-            key_states = key_states[:, :, : cache_position[-1] + 1, :]
-            value_states = value_states[:, :, : cache_position[-1] + 1, :]
-
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
         query_states = query_states.transpose(1, 2)
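
For reference, a minimal PyTorch sketch (not part of the patch) of what the removed slicing did. A static KV cache is pre-allocated to a fixed maximum length, and the deleted lines trimmed the returned key/value tensors down to the positions written so far before handing them to Flash Attention 2. The shapes and names below (max_cache_len, num_kv_heads, head_dim) are illustrative assumptions, not values taken from the model.

import torch

# Illustrative sizes only; a real static cache is allocated from the model config.
batch_size, num_kv_heads, max_cache_len, head_dim = 1, 2, 16, 8

# A static cache returns tensors padded out to max_cache_len; only the first
# few positions hold real keys/values, the rest are still zeros.
key_states = torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim)
value_states = torch.zeros_like(key_states)

# Suppose 4 tokens have been written so far, at positions 0..3.
cache_position = torch.arange(4)
key_states[:, :, :4, :] = torch.randn(batch_size, num_kv_heads, 4, head_dim)
value_states[:, :, :4, :] = torch.randn(batch_size, num_kv_heads, 4, head_dim)

# The deleted lines cut the padded cache down to the valid prefix so FA2
# would not attend over the empty slots:
key_states = key_states[:, :, : cache_position[-1] + 1, :]
value_states = value_states[:, :, : cache_position[-1] + 1, :]

print(key_states.shape)  # torch.Size([1, 2, 4, 8]) -- a data-dependent shape

Because the slice length depends on cache_position, i.e. on runtime data, the resulting shapes are dynamic, which is consistent with the removed comment's note that compile fails; dropping the slice keeps the tensors at their static, pre-allocated size.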