bart finalized
gante committed Dec 15, 2023
1 parent 2230679 commit 9d5ed42
Showing 9 changed files with 214 additions and 201 deletions.
15 changes: 15 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -373,15 +373,30 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] Cache
- update
- update_cross_attention
- has_cached_cross_attentions
- get_seq_length
- get_max_length
- get_usable_length

[[autodoc]] DynamicCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache
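
A minimal usage sketch of the `DynamicCache` methods listed above (an editorial illustration, not part of this diff; it assumes `update(key_states, value_states, layer_idx)` mirrors the `update_cross_attention` signature shown in `cache_utils.py` below):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()

# Dummy key/value states for layer 0: [batch_size, num_heads, seq_len, head_dim]
keys = torch.randn(1, 8, 4, 64)
values = torch.randn(1, 8, 4, 64)
cache.update(keys, values, layer_idx=0)

print(cache.get_seq_length(0))        # 4 tokens cached for layer 0
print(cache.get_max_length())         # None: a DynamicCache has no maximum length
print(cache.get_usable_length(1, 0))  # cached length usable when feeding 1 new token

# Round-trip through the legacy tuple-of-tuples format
legacy = cache.to_legacy_cache()
restored = DynamicCache.from_legacy_cache(legacy)
```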

[[autodoc]] DynamicCacheWithCrossAttention
- update_cross_attention
- has_cached_cross_attentions
- to_legacy_cache
- from_legacy_cache
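
For encoder-decoder models, each per-layer entry of the legacy format carries four tensors: self-attention key/value plus cross-attention key/value. A hedged sketch of converting such a tuple into the new cache object and back, assuming you are on this branch where `DynamicCacheWithCrossAttention` exists (shapes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCacheWithCrossAttention

# One legacy layer: (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
legacy_layer = (
    torch.randn(1, 8, 4, 64),   # decoder self-attention keys   [batch, heads, dec_len, head_dim]
    torch.randn(1, 8, 4, 64),   # decoder self-attention values
    torch.randn(1, 8, 10, 64),  # cross-attention keys          [batch, heads, enc_len, head_dim]
    torch.randn(1, 8, 10, 64),  # cross-attention values
)
past_key_values = (legacy_layer,)  # one tuple per decoder layer

cache = DynamicCacheWithCrossAttention.from_legacy_cache(past_key_values)
print(cache.has_cached_cross_attentions(0))  # True: layer 0 already holds encoder-derived K/V
roundtrip = cache.to_legacy_cache()          # back to one (k, v, cross_k, cross_v) tuple per layer
```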

[[autodoc]] SinkCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
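
`SinkCache` keeps a handful of initial "sink" tokens plus a sliding window of recent tokens. A usage sketch (the checkpoint name, the window sizes, and passing the cache to `generate` via `past_key_values` are assumptions based on release-era examples, not on this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import SinkCache

model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed: any model already ported to the new cache API
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Tell me a very long story about attention sinks.", return_tensors="pt")
cache = SinkCache(window_length=256, num_sink_tokens=4)  # keep 4 sink tokens plus a window of recent tokens

out = model.generate(**inputs, past_key_values=cache, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```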
48 changes: 29 additions & 19 deletions src/transformers/cache_utils.py
@@ -46,7 +46,9 @@ def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.T
layer_idx (`int`):
The index of the layer to cache the states for.
"""
raise NotImplementedError("Make sure to implement `update_cross_attention` in a subclass to use in encoder-decoder models.")
raise NotImplementedError(
"Make sure to implement `update_cross_attention` in a subclass to use in encoder-decoder models."
)

def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> int:
"""Returns whether it has cached cross attentions. A layer index can be optionally passed."""
@@ -74,7 +76,7 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -

class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
A [~Cache] that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
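
The list-of-tensors layout described here implies a simple update rule: append a layer's states on the first forward pass, then concatenate along the sequence axis on every later step. A sketch of that mechanism (an illustration, not the collapsed method body from this file):

```python
import torch

key_cache: list[torch.Tensor] = []    # one [batch_size, num_heads, seq_len, head_dim] tensor per layer
value_cache: list[torch.Tensor] = []

def update(key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
    if len(key_cache) <= layer_idx:
        # Prefill: first time this layer writes to the cache
        key_cache.append(key_states)
        value_cache.append(value_states)
    else:
        # Decoding: grow the cached sequence along dim -2 (seq_len)
        key_cache[layer_idx] = torch.cat([key_cache[layer_idx], key_states], dim=-2)
        value_cache[layer_idx] = torch.cat([value_cache[layer_idx], value_states], dim=-2)
    return key_cache[layer_idx], value_cache[layer_idx]
```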
@@ -154,27 +156,29 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
"""Returns the maximum sequence length of the cached states. [~DynamicCache] does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
"""Reorders the self attention cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
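
For example (a hypothetical beam-search step, not part of this diff), reordering a cache holding a batch of 4 beams so that beams 2, 0, 0 and 1 survive:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(4, 8, 5, 64), torch.randn(4, 8, 5, 64), layer_idx=0)  # 4 beams, 5 cached tokens

beam_idx = torch.tensor([2, 0, 0, 1])  # beam chosen for each slot after a search step
cache.reorder_cache(beam_idx)          # each cached [batch, heads, seq, dim] tensor is re-indexed along dim 0
```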

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
"""Converts the [~DynamicCache] instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None
) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent [~DynamicCache]."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
@@ -185,11 +189,11 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, t

class DynamicCacheWithCrossAttention(DynamicCache):
"""
Expands `DynamicCache` with cross attention. This is the default for encoder-decoder generative models.
Expands [~DynamicCache] with cross attention. This is the default for encoder-decoder generative models.
It stores the cross-attention Key and Value states as a list of tensors, one for each layer. The expected shape
for each tensor is `[batch_size, num_heads, encoder_sequence_length, embed_size_per_head]`. Please refer to
`DynamicCache` for documentation and functions related to self-attention.
[~DynamicCache] for documentation and functions related to self-attention.
"""

def __init__(self) -> None:
@@ -207,7 +211,7 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx]
self.cross_attention_value_cache[layer_idx],
)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
@@ -222,7 +226,7 @@ def __iter__(self):
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx]
self.cross_attention_value_cache[layer_idx],
)

def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
@@ -251,18 +255,24 @@ def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> int:
return len(self.cross_attention_key_cache) > layer_idx
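
The body of `update_cross_attention` is collapsed in this view (its signature appears just above the hunk break). Judging from `has_cached_cross_attentions` and `__getitem__`, a list-backed implementation would simply append one tensor pair per layer; a sketch, not the verbatim method:

```python
def update_cross_attention(self, key_states, value_states, layer_idx):
    # Cross-attention K/V come from the fixed encoder output, so each layer writes them exactly once.
    # Layers are assumed to be visited in order, so append() lands at position layer_idx.
    self.cross_attention_key_cache.append(key_states)
    self.cross_attention_value_cache.append(value_states)
```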

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
"""Converts the [~DynamicCacheWithCrossAttention] instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]), self.cross_attention_key_cache[layer_idx], self.cross_attention_value_cache[layer_idx])
legacy_cache += (
(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.cross_attention_key_cache[layer_idx],
self.cross_attention_value_cache[layer_idx],
),
)
return legacy_cache

@classmethod
def from_legacy_cache(
cls,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None
cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None
) -> "DynamicCacheWithCrossAttention":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
"""Converts a cache in the legacy cache format into an equivalent [~DynamicCacheWithCrossAttention]."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
@@ -274,9 +284,9 @@ def from_legacy_cache(

class SinkCache(Cache):
"""
A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
A [~Cache], as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model
to generate beyond the length of its context window, without losing fluency in the conversation. As it discards
past tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.