Skip to content

Commit

Permalink
Generate: fix SinkCache on Llama models (#30581)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored May 2, 2024
1 parent 66abe13 commit 9719202
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.cos_sin_rerotation_cache = {}
self._cos_cache = None
self._sin_cache = None
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

@staticmethod
Expand All @@ -225,7 +227,7 @@ def _apply_key_rotary_pos_emb(
def _get_rerotation_cos_sin(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states.shape[-2] not in self.cos_sin_cache:
if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
# Upcast to float32 temporarily for better accuracy
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
Expand All @@ -238,11 +240,11 @@ def _get_rerotation_cos_sin(
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

self.cos_sin_cache[key_states.shape[-2]] = (
self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
rerotation_cos.to(key_states.dtype).unsqueeze(0),
rerotation_sin.to(key_states.dtype).unsqueeze(0),
)
return self.cos_sin_cache[key_states.shape[-2]]
return self.cos_sin_rerotation_cache[key_states.shape[-2]]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
Expand Down Expand Up @@ -292,6 +294,21 @@ def update(
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

# Update the sin/cos cache, which holds sin/cos values for all possible positions
if using_rope and layer_idx == 0:
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
# after all RoPE models have a llama-like cache utilization.
if cos.dim() == 2:
self._cos_cache = cos
self._sin_cache = sin
else:
if self._cos_cache is None:
self._cos_cache = cos[0, ...]
self._sin_cache = sin[0, ...]
elif self._cos_cache.shape[0] < self.window_length:
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
# Empty cache
Expand All @@ -312,7 +329,7 @@ def update(
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if using_rope:
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
key_states, cos[: self.window_length], sin[: self.window_length]
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
)
if partial_rotation_size is not None:
keys_to_keep, keys_pass = (
Expand Down

0 comments on commit 9719202

Please sign in to comment.