From 9719202d37492a43c323d03a81aed23d14d98dec Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 2 May 2024 15:24:33 +0100
Subject: [PATCH] Generate: fix `SinkCache` on Llama models (#30581)

---
 src/transformers/cache_utils.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index ceca9d3eeb3592..2e29e19ade46a4 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -207,7 +207,9 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
+        self.cos_sin_rerotation_cache = {}
+        self._cos_cache = None
+        self._sin_cache = None
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
@@ -225,7 +227,7 @@ def _apply_key_rotary_pos_emb(
     def _get_rerotation_cos_sin(
         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
+        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
             # Upcast to float32 temporarily for better accuracy
             cos = cos.to(torch.float32)
             sin = sin.to(torch.float32)
@@ -238,11 +240,11 @@
             rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
             rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
 
-            self.cos_sin_cache[key_states.shape[-2]] = (
+            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                 rerotation_cos.to(key_states.dtype).unsqueeze(0),
                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
             )
-        return self.cos_sin_cache[key_states.shape[-2]]
+        return self.cos_sin_rerotation_cache[key_states.shape[-2]]
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -292,6 +294,21 @@ def update(
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
+        # Update the sin/cos cache, which holds sin/cos values for all possible positions
+        if using_rope and layer_idx == 0:
+            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+            # after all RoPE models have a llama-like cache utilization.
+            if cos.dim() == 2:
+                self._cos_cache = cos
+                self._sin_cache = sin
+            else:
+                if self._cos_cache is None:
+                    self._cos_cache = cos[0, ...]
+                    self._sin_cache = sin[0, ...]
+                elif self._cos_cache.shape[0] < self.window_length:
+                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
+
         # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
             # Empty cache
@@ -312,7 +329,7 @@
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
+                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                 )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
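
Note (not part of the patch): a minimal sketch of the code path this fix touches, i.e. passing a `SinkCache` as `past_key_values` to `generate()` on a RoPE model and generating past the window so `SinkCache.update()` re-rotates cached keys. The checkpoint name, window/sink sizes, token counts, and prompt below are illustrative assumptions.

# Illustrative sketch: exercising SinkCache with a Llama-style (RoPE) model.
# The checkpoint, window_length, num_sink_tokens, and prompt are assumptions for the example.
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint; any Llama-like RoPE model applies
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The key to a good sink cache is", return_tensors="pt")

# Keep a sliding window of 256 tokens plus 4 "sink" tokens pinned at the start of the sequence.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)

# Once generation exceeds `window_length`, SinkCache.update() shifts the window and re-rotates
# the cached keys -- the step that now reads from the accumulated `_cos_cache`/`_sin_cache`.
output_ids = model.generate(**inputs, max_new_tokens=512, past_key_values=past_key_values)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))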