
Commit

make sure we raise an error for static cache with FA2 enabled
ArthurZucker committed Feb 15, 2024
1 parent b3fc042 commit 5fdb2da
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions in src/transformers/models/llama/modeling_llama.py
@@ -431,13 +431,11 @@ def forward(
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-            if (
-                cache_position is not None
-            ):  # we slice for static kv cache to be supported in FA2. Not sure it's a must as compile fails
-                key_states, value_states = (
-                    key_states[:, :, : cache_position[-1] + 1, :],
-                    value_states[:, :, : cache_position[-1] + 1, :],
-                )
+
+            # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails
+            if cache_position is not None:
+                key_states = key_states[:, :, : cache_position[-1] + 1, :]
+                value_states = value_states[:, :, : cache_position[-1] + 1, :]
 
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
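
To make the slicing above concrete, here is a small self-contained sketch (not the library code; shapes and sizes are made up for illustration). With a static cache, the key/value tensors are pre-allocated to the full `max_cache_len`, and `cache_position[-1] + 1` is the number of positions actually written, so the slice trims the padded tail before the tensors are handed to flash attention:

```python
import torch

# Illustrative shapes: [batch, num_kv_heads, max_cache_len, head_dim]
key_states = torch.zeros(1, 8, 16, 64)
value_states = torch.zeros(1, 8, 16, 64)

# Suppose 5 tokens have been processed; the static cache is still padded to 16.
cache_position = torch.arange(5)

# Same slice as in the diff: keep only the positions written so far.
key_states = key_states[:, :, : cache_position[-1] + 1, :]
value_states = value_states[:, :, : cache_position[-1] + 1, :]

print(key_states.shape)    # torch.Size([1, 8, 5, 64])
print(value_states.shape)  # torch.Size([1, 8, 5, 64])
```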
@@ -792,6 +790,12 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()
 
     def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+        if self.config.attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+            raise ValueError(
+                "`static` cache implementation is not compatible with `config.attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
             causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
@@ -1028,11 +1032,9 @@ def forward(
 
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
-            # since the static cache is padded, you have to pass the attention mask raw.
-            # similar to https://github.com/facebookresearch/llama/commit/e9077bd24177a74aa79f406bef7d4b57fe393157
             if attention_mask is not None and 0.0 in attention_mask:
-                return None
-            return attention_mask
+                return attention_mask
+            return None
 
         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
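
A standalone sketch of the corrected FA2 branch (illustrative only, not the library function): flash attention applies causality itself, so the 2D padding mask is forwarded only when it actually contains padded (0) positions; otherwise `None` is returned and no mask is built:

```python
from typing import Optional

import torch


def fa2_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the updated branch above: pass the raw 2D mask only if it has padding.
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None


padded = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding
full = torch.ones(1, 3)                   # no padding anywhere

print(fa2_attention_mask(padded))  # tensor([[1., 1., 0.]]) - mask is kept
print(fa2_attention_mask(full))    # None - FA2 handles causal masking on its own
```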
