[Core generation] Adds support for static KV cache (#27931)
Co-authored-by: fxmarty <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
4 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 59c9e40 commit a130b12
Showing 19 changed files with 473 additions and 231 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -373,3 +373,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- update
- get_seq_length
- reorder_cache

[[autodoc]] StaticCache
- update
- get_seq_length
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1337,7 +1337,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -6073,7 +6073,7 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
92 changes: 92 additions & 0 deletions src/transformers/cache_utils.py
@@ -1,8 +1,12 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from .configuration_utils import PretrainedConfig


@dataclass
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
@@ -320,3 +324,91 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.
Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the cache tensors.
"""

def __init__(
self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility.
cache_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. The `StaticCache` needs the `position_ids` from
`cache_kwargs` to know which slots of the cache it should overwrite.
Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("position_ids")
k_out = self.key_cache
v_out = self.value_cache

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

self.seen_tokens += key_states.shape[-2]
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_cache_len

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

def to_legacy_cache(self):
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
return None
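
As a usage note (not part of the diff): the snippet below is a minimal sketch of how the pre-allocated cache behaves in isolation, using only the constructor and methods added above. The tiny `LlamaConfig` sizes are made up purely for illustration.

import torch
from transformers import LlamaConfig, StaticCache

# Tiny illustrative config; StaticCache only reads hidden_size, num_attention_heads and max_position_embeddings.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=2, max_position_embeddings=128)
cache = StaticCache(config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)

# The buffers are allocated once with a fixed shape of (max_batch_size, num_heads, max_cache_len, head_dim);
# nothing grows during decoding, which is what makes the cache friendly to torch.compile.
print(cache.key_cache.shape)  # torch.Size([1, 4, 32, 16])

# update() overwrites the slots selected by the `position_ids` tensor passed through `cache_kwargs`.
head_dim = config.hidden_size // config.num_attention_heads
new_keys = torch.randn(1, config.num_attention_heads, 4, head_dim)
new_values = torch.randn(1, config.num_attention_heads, 4, head_dim)
cache.update(new_keys, new_values, layer_idx=0, cache_kwargs={"position_ids": torch.arange(4)})
print(cache.get_seq_length())  # 4 tokens seen so far
print(cache.get_max_length())  # 32
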
8 changes: 8 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -250,6 +250,11 @@ class GenerationConfig(PushToHubMixin):
reduce by 1
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
> Parameters specific to the caching mechanism:
cache_implementation (`str`, *optional*, defaults to `None`):
Name of the cache class that should be used when generating (e.g. `"static"` for `StaticCache`).
> Wild card
generation_kwargs:
@@ -321,6 +326,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)

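
For illustration (not part of the commit), the new flag is an ordinary keyword argument on `GenerationConfig`, so a minimal sketch is:

from transformers import GenerationConfig

# "static" is the only value wired up by this PR; it is looked up in
# NEED_SETUP_CACHE_CLASSES_MAPPING (see the generation/utils.py hunk below) and resolves to StaticCache.
generation_config = GenerationConfig(max_new_tokens=32, cache_implementation="static")
print(generation_config.cache_implementation)  # "static"
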
19 changes: 18 additions & 1 deletion src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -92,6 +92,10 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}


@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
@@ -1398,6 +1402,19 @@ def generate(
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# if we don't pass `past_key_values` and a cache_implementation is specified
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
"past_key_values", False
):
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# 7. determine generation mode
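
Putting the pieces together, a hedged end-to-end sketch of the new path (not part of the diff): the checkpoint name is purely illustrative, and the model class must define `_setup_cache` as checked above (this PR adds it to the Llama family; other architectures raise the `ValueError` instead).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint; assumes the model class defines `_setup_cache`
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Static KV caches keep a fixed shape, so", return_tensors="pt").to(model.device)

# `cache_implementation="static"` is folded into the GenerationConfig, resolved through
# NEED_SETUP_CACHE_CLASSES_MAPPING, and the cache is pre-allocated via model._setup_cache(...)
# before decoding starts.
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
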
@@ -63,7 +63,7 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
4 changes: 2 additions & 2 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -88,7 +88,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -130,7 +130,7 @@ def _get_unpad_data(attention_mask):
)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
4 changes: 2 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -527,7 +527,7 @@ def attention_mask_func(attention_scores, ltor_mask):


class GPTNeoXRotaryEmbedding(nn.Module):
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

Expand Down Expand Up @@ -617,7 +617,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -235,7 +235,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

2 changes: 1 addition & 1 deletion src/transformers/models/idefics/modeling_idefics.py
@@ -513,7 +513,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
