Cache: dynamic cache with cross attention and UMT5 Cache support #28185

Closed · wants to merge 3 commits
15 changes: 15 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -373,15 +373,30 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] Cache
- update
- update_cross_attention
- has_cached_cross_attentions
- get_seq_length
- get_max_length
- get_usable_length

[[autodoc]] DynamicCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] DynamicCacheWithCrossAttention
- update_cross_attention
- has_cached_cross_attentions
- to_legacy_cache
- from_legacy_cache

[[autodoc]] SinkCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
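
For context, a minimal usage sketch of the cache classes documented above. It assumes the UMT5 modeling changes in this PR (not shown in this part of the diff) accept a `Cache` instance through `past_key_values` and return it updated on `outputs.past_key_values`; the checkpoint name and the manual decoding loop are illustrative, not taken from this diff.

import torch
from transformers import AutoTokenizer, UMT5ForConditionalGeneration, DynamicCacheWithCrossAttention

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")

inputs = tokenizer("A long input sentence to summarize.", return_tensors="pt")
encoder_outputs = model.get_encoder()(**inputs)  # run the encoder once
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

cache = DynamicCacheWithCrossAttention()
for _ in range(5):
    outputs = model(
        encoder_outputs=encoder_outputs,
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=decoder_input_ids,
        past_key_values=cache,
        use_cache=True,
    )
    cache = outputs.past_key_values  # cross attention entries are filled on the first step, then reused
    decoder_input_ids = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # feed only the new token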
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1308,7 +1308,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "DynamicCacheWithCrossAttention", "SinkCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5966,7 +5966,7 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache
from .cache_utils import Cache, DynamicCache, DynamicCacheWithCrossAttention, SinkCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
129 changes: 118 additions & 11 deletions src/transformers/cache_utils.py
@@ -34,6 +34,25 @@ def update(
"""
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
"""
Updates the cross attention cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
"""
raise NotImplementedError(
"Make sure to implement `update_cross_attention` in a subclass to use in encoder-decoder models."
)

def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> bool:
"""Returns whether the cache holds cross attention states. A layer index can be optionally passed."""
return False

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
@@ -56,7 +75,7 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -

class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
A [~Cache] that grows dynamically as more tokens are generated. This is the default for generative models.

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
@@ -136,27 +155,29 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
"""Returns the maximum sequence length of the cached states. [~DynamicCache] does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
"""Reorders the self attention cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""Converts the [~DynamicCache] instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
legacy_cache += (self[layer_idx],)
return legacy_cache

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None
) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent [~DynamicCache]."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
@@ -165,11 +186,97 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
return cache
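
As a reference for the conversion helpers above, the legacy `past_key_values` format is a tuple holding one `(key, value)` pair per layer; a short illustrative round trip with dummy tensors:

import torch
from transformers import DynamicCache

legacy = tuple(
    (torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8))  # [batch, num_heads, seq_len, head_dim]
    for _ in range(2)                                    # two decoder layers
)

cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 3

restored = cache.to_legacy_cache()
assert torch.equal(restored[1][0], legacy[1][0])  # layer 1 key tensor survives the round trip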


class DynamicCacheWithCrossAttention(DynamicCache):
"""
Extends [~DynamicCache] with a cross attention cache. This is the default for encoder-decoder generative models.
It stores the cross-attention Key and Value states as a list of tensors, one for each layer. The expected shape
for each tensor is `[batch_size, num_heads, encoder_sequence_length, embed_size_per_head]`. Please refer to
[~DynamicCache] for documentation and functions related to self-attention.
"""

def __init__(self) -> None:
self.cross_attention_key_cache: List[torch.Tensor] = []
self.cross_attention_value_cache: List[torch.Tensor] = []
super().__init__()

def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx],
)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values.
"""
for layer_idx in range(len(self)):
yield (
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx],
)

def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
"""
Updates the cross attention cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
"""
# Update the cache
if len(self.cross_attention_key_cache) <= layer_idx:
self.cross_attention_key_cache.append(key_states)
self.cross_attention_value_cache.append(value_states)
else:
raise ValueError(
"Attempted to update the cross attention cache for a layer that already has a cached cross attention."
)

def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> bool:
"""Returns whether the cache holds cross attention states. A layer index can be optionally passed."""
return len(self.cross_attention_key_cache) > layer_idx

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Converts the [~DynamicCacheWithCrossAttention] instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += (self[layer_idx],)
return legacy_cache

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None
) -> "DynamicCacheWithCrossAttention":
"""Converts a cache in the legacy cache format into an equivalent [~DynamicCacheWithCrossAttention]."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states, cross_attn_key_states, cross_attn_value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
cache.update_cross_attention(cross_attn_key_states, cross_attn_value_states, layer_idx)
return cache
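
Outside the diff, a short sketch of the access pattern the new class is designed for: the cross attention key/value projections are cached once per layer, the self attention entries keep growing through `update`, and the legacy format gains two extra tensors per layer. Shapes and the layer count are dummy values.

import torch
from transformers import DynamicCacheWithCrossAttention

cache = DynamicCacheWithCrossAttention()
layer_idx = 0

# First decoding step: the cross attention states for this layer are not cached yet.
if not cache.has_cached_cross_attentions(layer_idx):
    cross_key = torch.randn(1, 4, 10, 8)    # [batch, num_heads, encoder_seq_len, head_dim]
    cross_value = torch.randn(1, 4, 10, 8)
    cache.update_cross_attention(cross_key, cross_value, layer_idx)

# Self attention states are appended every step, exactly as with DynamicCache.
cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx)

# The legacy format now carries (self_key, self_value, cross_key, cross_value) per layer.
legacy = cache.to_legacy_cache()
assert len(legacy[0]) == 4
assert DynamicCacheWithCrossAttention.from_legacy_cache(legacy).get_seq_length() == 1
# Calling update_cross_attention again for the same layer raises ValueError, since
# cross attention states do not change during decoding.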


class SinkCache(Cache):
"""
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
A [~Cache], as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model
to generate beyond the length of its context window, without losing fluency in the conversation. As it discards
past tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -291,7 +291,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead"
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -202,7 +202,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead"
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -249,7 +249,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead"
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
2 changes: 1 addition & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
@@ -185,7 +185,7 @@ def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead"
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -224,7 +224,7 @@ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead"
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)