Commit 2230679
first test passing :o
gante committed Dec 14, 2023
1 parent ac2eccd commit 2230679
Showing 3 changed files with 211 additions and 117 deletions.
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1304,7 +1304,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "DynamicCacheWithCrossAttention", "SinkCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5951,7 +5951,7 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
- from .cache_utils import Cache, DynamicCache, SinkCache
+ from .cache_utils import Cache, DynamicCache, DynamicCacheWithCrossAttention, SinkCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
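Taken together, the two `__init__.py` hunks above register the new class in both the lazy `_import_structure` and the eager import path, so it becomes importable from the package root. A minimal smoke test (a sketch, assuming an installed transformers build that contains this commit):

from transformers import Cache, DynamicCache, DynamicCacheWithCrossAttention

cache = DynamicCacheWithCrossAttention()
# The new class extends DynamicCache, which itself extends the abstract Cache.
assert isinstance(cache, DynamicCache) and isinstance(cache, Cache)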
113 changes: 109 additions & 4 deletions src/transformers/cache_utils.py
@@ -34,6 +34,24 @@ def update(
"""
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
"""
Updates the cross attention cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
"""
raise NotImplementedError("Make sure to implement `update_cross_attention` in a subclass to use in encoder-decoder models.")

def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> bool:
"""Returns whether cross-attention states have been cached. A layer index can be optionally passed."""
return False

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
@@ -65,8 +83,6 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
- self.cross_attention_key_cache: List[torch.Tensor] = []
- self.cross_attention_value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -149,15 +165,15 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

- def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

@classmethod
- def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls()
if past_key_values is not None:
@@ -167,6 +183,95 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
return cache


class DynamicCacheWithCrossAttention(DynamicCache):
"""
Expands `DynamicCache` with cross attention. This is the default for encoder-decoder generative models.
It stores the cross-attention Key and Value states as a list of tensors, one for each layer. The expected shape
for each tensor is `[batch_size, num_heads, encoder_sequence_length, embed_size_per_head]`. Please refer to
`DynamicCache` for documentation and functions related to self-attention.
"""

def __init__(self) -> None:
self.cross_attention_key_cache: List[torch.Tensor] = []
self.cross_attention_value_cache: List[torch.Tensor] = []
super().__init__()

def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx]
)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx]
)

def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
"""
Updates the cross attention cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
"""
# Update the cache
if len(self.cross_attention_key_cache) <= layer_idx:
self.cross_attention_key_cache.append(key_states)
self.cross_attention_value_cache.append(value_states)
else:
raise ValueError(
"Attempted to update the cross attention cache for a layer that already has a cached cross attention."
)

def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> bool:
"""Returns whether cross-attention states have been cached. A layer index can be optionally passed."""
return len(self.cross_attention_key_cache) > layer_idx

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Converts the `DynamicCacheWithCrossAttention` instance into its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.cross_attention_key_cache[layer_idx], self.cross_attention_value_cache[layer_idx]),)
return legacy_cache

@classmethod
def from_legacy_cache(
cls,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None
) -> "DynamicCacheWithCrossAttention":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states, cross_attn_key_states, cross_attn_value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
cache.update_cross_attention(cross_attn_key_states, cross_attn_value_states, layer_idx)
return cache
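
A quick end-to-end exercise of the class above (a sketch assuming a transformers build containing this commit; tensor sizes are arbitrary and follow the documented `[batch_size, num_heads, sequence_length, embed_size_per_head]` layout):

import torch
from transformers import DynamicCacheWithCrossAttention

batch_size, num_heads, head_dim = 2, 4, 8
cache = DynamicCacheWithCrossAttention()

# Self-attention states grow one step at a time; cross-attention states for a
# layer are written exactly once (a second write raises ValueError by design).
cache.update(torch.randn(batch_size, num_heads, 1, head_dim), torch.randn(batch_size, num_heads, 1, head_dim), layer_idx=0)
cache.update_cross_attention(torch.randn(batch_size, num_heads, 10, head_dim), torch.randn(batch_size, num_heads, 10, head_dim), layer_idx=0)

assert cache.has_cached_cross_attentions(0)
key, value, cross_key, cross_value = cache[0]  # backward-compatible 4-tuple

# Round-trip through the legacy format: one 4-tuple per layer.
restored = DynamicCacheWithCrossAttention.from_legacy_cache(cache.to_legacy_cache())
assert restored.has_cached_cross_attentions(0)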


class SinkCache(Cache):
"""
A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
(diff for the third changed file did not load)
