diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 40969f227e9dbe..e52c9b2536dbf5 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -373,15 +373,30 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] Cache - update + - update_cross_attention + - has_cached_cross_attentions + - get_seq_length + - get_max_length + - get_usable_length [[autodoc]] DynamicCache - update - get_seq_length + - get_max_length + - get_usable_length - reorder_cache - to_legacy_cache - from_legacy_cache +[[autodoc]] DynamicCacheWithCrossAttention + - update_cross_attention + - has_cached_cross_attentions + - to_legacy_cache + - from_legacy_cache + [[autodoc]] SinkCache - update - get_seq_length + - get_max_length + - get_usable_length - reorder_cache diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7d0e239a2767cb..6023528feaa38c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -46,7 +46,9 @@ def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.T layer_idx (`int`): The index of the layer to cache the states for. """ - raise NotImplementedError("Make sure to implement `update_cross_attention` in a subclass to use in encoder-decoder models.") + raise NotImplementedError( + "Make sure to implement `update_cross_attention` in a subclass to use in encoder-decoder models." + ) def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> int: """Returns whether it has cached cross attentions. A layer index can be optionally passed.""" @@ -74,7 +76,7 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. + A [~Cache] that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. @@ -154,11 +156,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self.key_cache[layer_idx].shape[-2] def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + """Returns the maximum sequence length of the cached states. 
[~DynamicCache] does not have a maximum length.""" return None def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" + """Reorders the self-attention cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): device = self.key_cache[layer_idx].device self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) @@ -166,15 +168,17 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" + """Converts the [~DynamicCache] instance into its equivalent in the legacy cache format.""" legacy_cache = () for layer_idx in range(len(self)): legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent [~DynamicCache].""" cache = cls() if past_key_values is not None: for layer_idx in range(len(past_key_values)): @@ -185,11 +189,11 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, t class DynamicCacheWithCrossAttention(DynamicCache): """ - Expands `DynamicCache` with cross attention. This is the default for encoder-decoder generative models. + Expands [~DynamicCache] with cross attention. This is the default for encoder-decoder generative models. It stores the cross-attention Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, encoder_sequence_length, embed_size_per_head]`. Please refer to - `DynamicCache` for documentation and functions related to self-attention. + [~DynamicCache] for documentation and functions related to self-attention.
""" def __init__(self) -> None: @@ -207,7 +211,7 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: self.key_cache[layer_idx], self.value_cache[layer_idx], self.cross_attention_key_cache[layer_idx], - self.cross_attention_value_cache[layer_idx] + self.cross_attention_value_cache[layer_idx], ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") @@ -222,7 +226,7 @@ def __iter__(self): self.key_cache[layer_idx], self.value_cache[layer_idx], self.cross_attention_key_cache[layer_idx], - self.cross_attention_value_cache[layer_idx] + self.cross_attention_value_cache[layer_idx], ) def update_cross_attention(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int): @@ -251,18 +255,24 @@ def has_cached_cross_attentions(self, layer_idx: Optional[int] = 0) -> int: return len(self.cross_attention_key_cache) > layer_idx def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" + """Converts the [~DynamicCacheWithCrossAttention] instance into the its equivalent in the legacy cache format.""" legacy_cache = () for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]), self.cross_attention_key_cache[layer_idx], self.cross_attention_value_cache[layer_idx]) + legacy_cache += ( + ( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.cross_attention_key_cache[layer_idx], + self.cross_attention_value_cache[layer_idx], + ), + ) return legacy_cache @classmethod def from_legacy_cache( - cls, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None + cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None ) -> "DynamicCacheWithCrossAttention": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" + """Converts a cache in the legacy cache format into an equivalent [~DynamicCacheWithCrossAttention].""" cache = cls() if past_key_values is not None: for layer_idx in range(len(past_key_values)): @@ -274,9 +284,9 @@ def from_legacy_cache( class SinkCache(Cache): """ - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + A [~Cache] that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model + to generate beyond the length of its context window, without losing fluency in the conversation. As it discards + past tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. 
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 09058da495d8f0..dad278f079edf4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCacheWithCrossAttention +from ...cache_utils import Cache, DynamicCache, DynamicCacheWithCrossAttention from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -148,7 +148,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BartConfig] = None, - layer_idx: Optional[int] = None + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -159,7 +159,7 @@ def __init__( self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) @@ -181,37 +181,26 @@ def __init__( def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( + def _prepare_key_values( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - """Input shape: Batch x Time x Channel""" - if past_key_value is not None and self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) + ) -> Tuple[torch.Tensor, torch.Tensor]: + bsz = hidden_states.shape[0] # if key_value_states are provided this layer is used as a cross-attention layer for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - - # get key, value proj if is_cross_attention: - # `past_key_value.get_seq_length(self.layer_idx) == key_value_states.shape[1]` + # `past_key_value[self.layer_idx][2].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) == key_value_states.shape[1]: + if ( + past_key_value is not None + and past_key_value.has_cached_cross_attentions(self.layer_idx) + and past_key_value[self.layer_idx][2].shape[2] == key_value_states.shape[1] + ): # reuse k,v, cross_attentions key_states = past_key_value[self.layer_idx][2] value_states = past_key_value[self.layer_idx][3] @@ -228,6 +217,32 @@ def forward( if past_key_value is not None: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + return key_states, value_states + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + """Input shape: Batch x Time x Channel""" + if past_key_value is not None and self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + # get key, value proj + key_states, value_states = self._prepare_key_values(hidden_states, key_value_states, past_key_value) + + bsz, tgt_len, _ = hidden_states.size() proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) @@ -315,8 +330,8 @@ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[Cache] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -325,84 +340,24 @@ def forward( if output_attentions: raise ValueError("BartFlashAttention2 attention does not support output_attentions") - # if past_key_value is not None and self.layer_idx is None: - # raise ValueError( - # f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - # "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - # "with a layer index." 
- # ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None + if past_key_value is not None and self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) bsz, q_len, _ = hidden_states.size() # get query proj query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # # get key, value proj - # if is_cross_attention: - # # `past_key_value.get_seq_length(self.layer_idx) == key_value_states.shape[1]` - # # is checking that the `sequence_length` of the `past_key_value` is the same as - # # the provided `key_value_states` to support prefix tuning - # if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) == key_value_states.shape[1]: - # # reuse k,v, cross_attentions - # key_states = past_key_value[self.layer_idx][2].transpose(1, 2) - # value_states = past_key_value[self.layer_idx][3].transpose(1, 2) - # else: - # # compute cross attention k and v and cache them (if there is a cache) - # key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - # value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - # if past_key_value is not None: - # past_key_value.update_cross_attention(key_states, value_states, self.layer_idx) - # else: - # # compute self attention k and v and cache them (if there is a cache) - # key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - # value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - # if past_key_value is not None: - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + key_states, value_states = self._prepare_key_values(hidden_states, key_value_states, past_key_value) + + # FA2 uses `[batch_size, seq_len, num_heads, head_dim]` key/value states + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -573,37 +528,13 @@ def forward( "with a layer index." ) - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - # get query proj query_states = self.q_proj(hidden_states) # get key, value proj - if is_cross_attention: - # `past_key_value[self.layer_idx][2].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if past_key_value is not None and past_key_value.has_cached_cross_attentions(self.layer_idx) and past_key_value[self.layer_idx][2].shape[2] == key_value_states.shape[1]: - # reuse k,v, cross_attentions - key_states = past_key_value[self.layer_idx][2] - value_states = past_key_value[self.layer_idx][3] - else: - # compute cross attention k and v and cache them (if there is a cache) - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - if past_key_value is not None: - past_key_value.update_cross_attention(key_states, value_states, self.layer_idx) - else: - # compute self attention k and v and cache them (if there is a cache) - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + key_states, value_states = self._prepare_key_values(hidden_states, key_value_states, past_key_value) + bsz, tgt_len, _ = hidden_states.size() query_states = self._shape(query_states, tgt_len, bsz) # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, @@ -859,6 +790,7 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.init_std @@ -1018,13 +950,20 @@ def __init_subclass__(self): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. This is also known as the legacy + cache format. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1266,7 +1205,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BartDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] + ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" @@ -1338,21 +1279,23 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of - shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing - `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more - control over how to convert `input_ids` indices into associated vectors than the model's internal - embedding lookup matrix. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1367,6 +1310,7 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache or past_key_values is not None # If a cache is passed, assumes it is to be used return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds @@ -1382,6 +1326,16 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + if encoder_hidden_states is not None: + # Used as an encoder-decoder -- has cross attention + past_key_values = DynamicCacheWithCrossAttention.from_legacy_cache(past_key_values) + else: + # Used as a decoder-only -- doesn't have cross attention + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # past_key_values_length seq_length = input_shape[-1] past_key_values_length = past_key_values.get_usable_length(seq_length) if past_key_values is not None else 0 @@ -1446,7 +1400,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1505,6 +1458,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) + if use_cache and use_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1604,11 +1560,6 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCacheWithCrossAttention.from_legacy_cache(past_key_values) - if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, @@ -1646,14 +1597,9 @@ def forward( if not return_dict: return decoder_outputs + encoder_outputs - if use_cache: - next_cache = decoder_outputs.past_key_values - if use_legacy_cache: - next_cache = decoder_outputs.past_key_values.to_legacy_cache() - return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=next_cache, + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, @@ -1803,18 +1749,36 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past_key_values is used + # Omit tokens covered by past_key_values if past_key_values is not None: - past_length = past_key_values.get_seq_length() - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the decoder_attention_mask exceeds the length of decoder_input_ids, then we are in a + # setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing + # input_embeds as input) + if decoder_attention_mask is not None and decoder_attention_mask.shape[1] > decoder_input_ids.shape[1]: + decoder_input_ids = decoder_input_ids[:, -(decoder_attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # decoder_input_ids based on the past_length. + elif past_length < decoder_input_ids.shape[1]: + decoder_input_ids = decoder_input_ids[:, past_length:] + # 3 - Otherwise (past_length >= decoder_input_ids.shape[1]), let's assume decoder_input_ids only has + # unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and decoder_attention_mask is not None + and cache_length + decoder_input_ids.shape[1] > max_cache_length + ): + decoder_attention_mask = decoder_attention_mask[:, -max_cache_length:] return { "input_ids": None, # encoder_outputs is defined. 
input_ids not needed @@ -2304,17 +2268,36 @@ def prepare_inputs_for_generation( if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) - if past_key_values: - past_length = past_key_values.get_seq_length() - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] - input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index bba2680f5732e4..e8ceffda7bcc8c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -291,7 +291,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 29af7c0e88e979..2dcfec63475ad6 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -202,7 +202,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" "to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` " "when creating this class." ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 7268673441fe87..2d2d4124ef990f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -252,7 +252,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 17163dcd8edf9b..bd31e7f89dc483 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -185,7 +185,7 @@ def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c73d5b942e6d4f..aa5472d5dfcf2a 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -224,7 +224,7 @@ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." 
) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6e11818f69a134..b9a721f6658d5e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -54,7 +54,7 @@ SpeechEncoderDecoderModel, top_k_top_p_filtering, ) - from transformers.cache_utils import DynamicCache + from transformers.cache_utils import DynamicCache, DynamicCacheWithCrossAttention from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1919,6 +1919,11 @@ def test_new_cache_format(self, num_beams, do_sample): config.use_cache = True config.is_decoder = True + if config.is_encoder_decoder: + cache_cls = DynamicCacheWithCrossAttention + else: + cache_cls = DynamicCache + model = model_class(config).to(torch_device).eval() generation_kwargs = { "max_new_tokens": 5, @@ -1934,14 +1939,14 @@ def test_new_cache_format(self, num_beams, do_sample): legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) set_seed(seed) new_results = model.generate( - input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs + input_ids, attention_mask=attention_mask, past_key_values=cache_cls(), **generation_kwargs ) # The two sets of generated sequences must match, despite the cache format between forward passes being # different self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) - self.assertTrue(isinstance(new_results.past_key_values, DynamicCache)) + self.assertTrue(isinstance(new_results.past_key_values, cache_cls)) # The contents of the two caches, when converted to the same format (in both directions!), must match legacy_cache = legacy_results.past_key_values @@ -1956,7 +1961,7 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values) + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): self.assertTrue(
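For context on what the updated `test_new_cache_format` exercises for encoder-decoder models, the snippet below sketches the equivalent user-facing call. It is not part of the PR: the checkpoint name (`facebook/bart-base`) and generation settings are illustrative, and it assumes this branch, where BART declares `_supports_cache_class = True` and `generate` forwards a `Cache` instance through `past_key_values`.

```python
# Hedged sketch: opting into the new cache format for an encoder-decoder model.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.cache_utils import DynamicCacheWithCrossAttention

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

inputs = tokenizer("Caches can now be objects instead of tuples.", return_tensors="pt")

# Passing a Cache instance opts into the new format; omitting it keeps the
# legacy tuple-of-tuples behaviour that the test compares against.
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    past_key_values=DynamicCacheWithCrossAttention(),
    return_dict_in_generate=True,
)

# The cache is expected to come back in the same format it was passed in.
assert isinstance(outputs.past_key_values, DynamicCacheWithCrossAttention)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```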