[Core generation] Adds support for static KV cache #27931

Merged on Feb 8, 2024

Changes from all commits

Commits (121)
17b8b38
initial commit
ArthurZucker Dec 10, 2023
80ef815
lol
ArthurZucker Dec 10, 2023
2639b5d
nits
ArthurZucker Dec 10, 2023
9f2e1e4
nits nits nits nits nits
ArthurZucker Dec 11, 2023
271260c
Merge branch 'main' of github.com:huggingface/transformers into stati…
ArthurZucker Dec 21, 2023
5be65ff
Merge branch 'main' of github.com:huggingface/transformers into stati…
ArthurZucker Jan 4, 2024
c6b6d35
some nits and some testing
ArthurZucker Jan 4, 2024
90224dd
nits
ArthurZucker Jan 4, 2024
24ffbfb
Wrong implementation but creates good masks in general and is pretty …
ArthurZucker Jan 5, 2024
cd95e98
what seems to work for now
ArthurZucker Jan 5, 2024
7cd3655
nites
ArthurZucker Jan 5, 2024
eeebc66
re-init cache
ArthurZucker Jan 5, 2024
5819a85
make it automatic
ArthurZucker Jan 5, 2024
216dd8f
nits and nits
ArthurZucker Jan 5, 2024
a48ae88
more nits
ArthurZucker Jan 5, 2024
aeefa26
nits
ArthurZucker Jan 8, 2024
e05f8da
nits
ArthurZucker Jan 9, 2024
07f5cdc
more nits
ArthurZucker Jan 9, 2024
f769b0e
nits
ArthurZucker Jan 10, 2024
bb6a160
fastest working cache for now
ArthurZucker Jan 10, 2024
dd1e42c
also include the attention mask
ArthurZucker Jan 10, 2024
a3b0003
updates
ArthurZucker Jan 11, 2024
dacd0ff
current state
ArthurZucker Jan 11, 2024
021f674
working code
ArthurZucker Jan 11, 2024
98af852
dummy mask for now
ArthurZucker Jan 15, 2024
8594670
Merge branch 'main' of github.com:huggingface/transformers into stati…
ArthurZucker Jan 15, 2024
60af293
Merge branch 'static-cache' of github.com:huggingface/transformers in…
ArthurZucker Jan 15, 2024
05166fe
Merge branch 'main' of github.com:huggingface/transformers into stati…
ArthurZucker Jan 16, 2024
9c1a3b4
a better design
ArthurZucker Jan 24, 2024
d5395af
some fix
ArthurZucker Jan 24, 2024
a20a183
make outputs match
ArthurZucker Jan 24, 2024
bce7653
fastest yet
ArthurZucker Jan 24, 2024
0e59f70
remove chunck qkv
ArthurZucker Jan 24, 2024
e573000
cleanup
ArthurZucker Jan 25, 2024
fce7e46
some test
ArthurZucker Jan 28, 2024
24ef3cf
goat changes
ArthurZucker Jan 29, 2024
344309f
nits
ArthurZucker Jan 29, 2024
42e5a38
dynamic was not working anymore
ArthurZucker Jan 29, 2024
6637755
cache reverts
ArthurZucker Jan 29, 2024
6ec92df
small nits
ArthurZucker Jan 29, 2024
d784927
sdpa
ArthurZucker Jan 29, 2024
0332d3f
Merge branch 'static-cache' of github.com:huggingface/transformers in…
ArthurZucker Jan 29, 2024
4e40703
make sure sdpa passed
ArthurZucker Jan 29, 2024
770c5e6
nit
ArthurZucker Jan 29, 2024
7bd1fca
cleqnups
ArthurZucker Jan 29, 2024
25fd440
cleanup
ArthurZucker Feb 1, 2024
4c3220f
nits
ArthurZucker Feb 1, 2024
d51acfa
Merge branch 'main' of github.com:huggingface/transformers into stati…
ArthurZucker Feb 1, 2024
2b2e0c2
pass sdpa
ArthurZucker Feb 1, 2024
4b93379
make sure dynamic is BC
ArthurZucker Feb 1, 2024
ab07e80
update check on the attn weight
ArthurZucker Feb 1, 2024
77ccdce
Merge branch 'static-cache' of https://github.com/huggingface/transfo…
ArthurZucker Feb 1, 2024
ad6832a
faster?
ArthurZucker Feb 1, 2024
1cb6a16
add `_reset_cache`
ArthurZucker Feb 1, 2024
d044263
Merge branch 'static-cache' of github.com:huggingface/transformers in…
ArthurZucker Feb 1, 2024
c838352
nit
ArthurZucker Feb 1, 2024
e80b6a1
Merge branch 'static-cache' of https://github.com/huggingface/transfo…
ArthurZucker Feb 1, 2024
8308809
nit
ArthurZucker Feb 1, 2024
0132a2c
Merge branch 'static-cache' of github.com:huggingface/transformers in…
ArthurZucker Feb 1, 2024
87b3064
merges
ArthurZucker Feb 1, 2024
4d88605
Styling
ArthurZucker Feb 1, 2024
011931e
nites
ArthurZucker Feb 1, 2024
e838f57
revert some BC breaking changes
ArthurZucker Feb 1, 2024
c23815a
make all tests pass
ArthurZucker Feb 1, 2024
c985064
torch long not float for attention mask
ArthurZucker Feb 1, 2024
6a954d5
try to remove the guard
ArthurZucker Feb 1, 2024
45760d6
BC
ArthurZucker Feb 1, 2024
64f5455
even more cleanup
ArthurZucker Feb 1, 2024
f103454
fix `past_key_value.get_usable_length(kv_seq_len, self.layer_idx)`
ArthurZucker Feb 1, 2024
c7b5d2c
pushh a fast version
ArthurZucker Feb 1, 2024
538ccf0
what actually works
ArthurZucker Feb 1, 2024
ce42624
no contigious()
ArthurZucker Feb 1, 2024
33832d2
push for eager as well
ArthurZucker Feb 2, 2024
8a53f53
simplest and best way to do it yet
ArthurZucker Feb 2, 2024
f560fe5
merge
ArthurZucker Feb 2, 2024
5f90ed4
style
ArthurZucker Feb 2, 2024
e5c731e
Merge branch 'main' of github.com:huggingface/transformers into stati…
ArthurZucker Feb 2, 2024
b6c9180
dix dtype
ArthurZucker Feb 2, 2024
8de700f
fix dtype issues
ArthurZucker Feb 2, 2024
e92b1a0
nits
ArthurZucker Feb 2, 2024
d9f7f16
nit
ArthurZucker Feb 2, 2024
d98f277
support export to torchscript
ArthurZucker Feb 2, 2024
65217de
Credit helpers
ArthurZucker Feb 2, 2024
a219236
nits
ArthurZucker Feb 2, 2024
7a6b57d
handle SDPA edge cases
ArthurZucker Feb 5, 2024
2822423
handle sdpa quircks
ArthurZucker Feb 5, 2024
70df80e
revert performance break
ArthurZucker Feb 5, 2024
b4fbf3f
Apply suggestions from code review
ArthurZucker Feb 6, 2024
70d5ded
fix merges
ArthurZucker Feb 6, 2024
ec22fb1
revert removing ```
ArthurZucker Feb 6, 2024
9968b0e
add another test
ArthurZucker Feb 6, 2024
dc885ca
update test
ArthurZucker Feb 6, 2024
0c2a66f
Merge branch 'static-cache' of https://github.com/huggingface/transfo…
ArthurZucker Feb 6, 2024
e087adc
use a model that is not protected
ArthurZucker Feb 6, 2024
c0cf294
only test generation
ArthurZucker Feb 6, 2024
da720c8
update the cache utils to define the position_ids in the cache class
ArthurZucker Feb 6, 2024
8f4c49d
fix static cache
ArthurZucker Feb 6, 2024
c22d564
add subtest to llama tests
ArthurZucker Feb 6, 2024
89929b9
update testing suite
ArthurZucker Feb 6, 2024
d4b24ee
nuke whatever we can
ArthurZucker Feb 6, 2024
d7e400e
smthing wrong with cache
ArthurZucker Feb 6, 2024
9d9eec3
nit
ArthurZucker Feb 6, 2024
4eb8a9e
latest changes
ArthurZucker Feb 7, 2024
dad35d6
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Feb 7, 2024
6f516a0
don't use einsum
ArthurZucker Feb 7, 2024
f25ac8e
nit
ArthurZucker Feb 7, 2024
17f0350
remove one unused var
ArthurZucker Feb 7, 2024
b91efbb
update test value
ArthurZucker Feb 7, 2024
256c324
let style be happy
ArthurZucker Feb 7, 2024
327b77a
make sure cache tests are slow
ArthurZucker Feb 7, 2024
8509e91
slow was removed add it back to test cach utils
ArthurZucker Feb 7, 2024
60aa86d
fix flash_attention_2
ArthurZucker Feb 8, 2024
7de4ace
very small nit
ArthurZucker Feb 8, 2024
453df24
revert test change
ArthurZucker Feb 8, 2024
0a1f8d2
make mistral the default copied from
ArthurZucker Feb 8, 2024
040b2f1
fix copies
ArthurZucker Feb 8, 2024
1763ec7
nits
ArthurZucker Feb 8, 2024
c4242c8
finishup
ArthurZucker Feb 8, 2024
af097af
fixup
ArthurZucker Feb 8, 2024
5bbde6f
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Feb 8, 2024
7f8ca33
skip tests
ArthurZucker Feb 8, 2024
4 changes: 4 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -373,3 +373,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- update
- get_seq_length
- reorder_cache

[[autodoc]] StaticCache
- update
- get_seq_length
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1337,7 +1337,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -6073,7 +6073,7 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
92 changes: 92 additions & 0 deletions src/transformers/cache_utils.py
@@ -1,8 +1,12 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from .configuration_utils import PretrainedConfig


@dataclass
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
@@ -320,3 +324,91 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.

Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer's device.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""

def __init__(
self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility.
cache_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. The `StaticCache` needs the `position_ids`
to know at which positions of the pre-allocated cache the new states should be written.

Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("position_ids")
k_out = self.key_cache
v_out = self.value_cache

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

self.seen_tokens += key_states.shape[-2]
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_cache_len

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

def to_legacy_cache(self):
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
return None
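
For reference, a minimal sketch of how this `StaticCache` could be exercised by hand; the toy `LlamaConfig`, the tensor sizes and the manual `cache_kwargs` plumbing are illustrative assumptions, since during `generate()` the cache is normally created per layer through the model's `_setup_cache` helper.

```python
import torch

from transformers import LlamaConfig, StaticCache

# Toy config so the sketch is self-contained; real usage would reuse the model's config.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, max_position_embeddings=128)

# Pre-allocates one buffer of shape (max_batch_size, num_heads, max_cache_len, head_dim)
# for the keys and one for the values.
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# Fake key/value states for a 4-token prefill: (batch, num_heads, seq_len, head_dim).
head_dim = config.hidden_size // config.num_attention_heads
key_states = torch.randn(1, config.num_attention_heads, 4, head_dim)
value_states = torch.randn_like(key_states)

# Indexing with a tensor of positions avoids an extra device copy, as stressed in the
# `update` docstring. `layer_idx` is kept for backward compatibility and ignored here.
position_ids = torch.arange(4)
key_cache, value_cache = cache.update(
    key_states, value_states, layer_idx=0, cache_kwargs={"position_ids": position_ids}
)

print(key_cache.shape)         # torch.Size([1, 4, 16, 16]) -- always the full pre-allocated shape
print(cache.get_seq_length())  # 4
```

Because the buffers never change shape during decoding, `torch.compile` can trace the forward pass without shape-driven recompilation, which is the motivation for this cache class.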
8 changes: 8 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -250,6 +250,11 @@ class GenerationConfig(PushToHubMixin):
reduce by 1
- `"constant"`: `num_assistant_tokens` stays unchanged during generation

> Parameters specific to the caching mechanism:

cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.

> Wild card

generation_kwargs:
@@ -321,6 +326,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)

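As a quick, hedged illustration of the new flag (the values are arbitrary):

```python
from transformers import GenerationConfig

# Asks `generate()` to pre-allocate a static KV cache instead of growing a DynamicCache.
generation_config = GenerationConfig(max_new_tokens=32, cache_implementation="static")
print(generation_config.cache_implementation)  # "static"
```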
19 changes: 18 additions & 1 deletion src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -92,6 +92,10 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}


@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
@@ -1398,6 +1402,19 @@ def generate(
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# if `past_key_values` is not passed and a `cache_implementation` is specified
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
"past_key_values", False
):
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# 7. determine generation mode
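End to end, the wiring above could be exercised roughly as follows. This is a sketch under assumptions: the checkpoint name is a placeholder, the `torch.compile` call is optional and assumes a recent PyTorch, and only models that implement `_setup_cache` (e.g. Llama in this PR) will accept `cache_implementation="static"`.

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Optional: with fixed-shape cache buffers, compiling the forward pass avoids
# shape-driven recompilations during decoding.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Static KV caches", return_tensors="pt").to(model.device)

# `generate()` sees `cache_implementation="static"`, looks it up in
# NEED_SETUP_CACHE_CLASSES_MAPPING and calls `self._setup_cache(StaticCache, ...)`.
outputs = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```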
@@ -63,7 +63,7 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

4 changes: 2 additions & 2 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -88,7 +88,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -130,7 +130,7 @@ def _get_unpad_data(attention_mask):
)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
4 changes: 2 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -527,7 +527,7 @@ def attention_mask_func(attention_scores, ltor_mask):


class GPTNeoXRotaryEmbedding(nn.Module):
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

@@ -617,7 +617,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -235,7 +235,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

2 changes: 1 addition & 1 deletion src/transformers/models/idefics/modeling_idefics.py
@@ -513,7 +513,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
