From 2aa4df35746c17c2273690ea5f5cdc86094f646e Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 3 May 2024 17:36:30 +0200 Subject: [PATCH 01/29] first version --- .../models/mistral/modeling_mistral.py | 400 +++++++++++------- tests/models/mistral/test_modeling_mistral.py | 51 ++- 2 files changed, 292 insertions(+), 159 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 665e95a8fd7856..ca96e778c4e43f 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,8 +30,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -99,31 +99,46 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._sin_cached - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. 
It is not used in the `LlamaAttention` class" ) + return self._cos_cached + + @torch.no_grad + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -136,7 +151,7 @@ def rotate_half(x): # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb # TODO @Arthur no longer copied from LLama after static cache -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -144,9 +159,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -157,8 +171,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -227,7 +241,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.rotary_emb = MistralRotaryEmbedding( self.head_dim, @@ -246,6 +260,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: @@ -262,20 +277,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position" : cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -284,19 +292,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -346,13 +344,14 @@ def forward( use_cache: bool = False, **kwargs, ): - if "padding_mask" in kwargs: - warnings.warn( - 
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") + output_attentions = False + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -363,21 +362,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + past_key_value = getattr(self, "past_key_value", past_key_value) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_seq_length() - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_sliding_windows = ( _flash_supports_window_size @@ -632,6 +623,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -658,41 +650,44 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # In case static cache is used, it is an instance attribute. 
+ past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -729,12 +724,9 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -749,6 +741,11 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -761,6 +758,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -934,12 +933,13 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -950,74 +950,35 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" - ) + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, use_cache) + hidden_states = inputs_embeds # decoder layers @@ -1033,20 +994,22 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1065,8 +1028,11 @@ def forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - + next_cache = ( + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1075,6 +1041,106 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) + + # copied from Llama implementation + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + + if self._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + + if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4): + # assume signed int tensor for cache_position + exclude_mask |= torch.arange(target_length, device=device) <= (cache_position.reshape(-1,1) - self.config.sliding_window) + + causal_mask *= exclude_mask + + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. 
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input lenght is deprecated and will be " + "removed in transformers v4.42.0" + ) + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask class MistralForCausalLM(MistralPreTrainedModel): @@ -1114,13 +1180,14 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1165,6 +1232,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1197,14 +1265,29 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + # cache_length = past_key_values.get_seq_length() + # past_length = past_key_values.seen_tokens + # max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1241,13 +1324,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + 
"use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 3500024b3ea173..1a45991dbefa7b 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -20,6 +20,8 @@ import unittest import pytest +from packaging import version +from parameterized import parameterized from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -469,7 +471,12 @@ def test_flash_attn_2_generate_use_cache(self): @slow def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - + + # copied from Llama tests to supress errors for now + @unittest.skip("TODO @gante fix this for Mistral") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): @@ -627,3 +634,45 @@ def test_speculative_generation(self): del model backend_empty_cache(torch_device) gc.collect() + + @slow + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = { + 8: ['My favourite condiment is 100% ketchup. I love it on everything. ' + 'I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles'], + # 7: [], + } + + prompts = ["My favourite condiment is "] + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + tokenizer.pad_token = tokenizer.eos_token + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) \ No newline at end of file From 1ddd617cc016d4775b89289193fe36cb9d8a7d8b Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 4 May 2024 07:23:42 +0200 Subject: [PATCH 02/29] fix sliding window --- src/transformers/models/mistral/modeling_mistral.py | 8 +------- 1 file changed, 1 
insertion(+), 7 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ca96e778c4e43f..d2f711f143a02b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1094,16 +1094,10 @@ def _update_causal_mask( ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4): - # assume signed int tensor for cache_position - exclude_mask |= torch.arange(target_length, device=device) <= (cache_position.reshape(-1,1) - self.config.sliding_window) - + exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window) causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) From 2f5c7ca50808c6a5282d7045b349d3ad74bd7472 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 7 May 2024 06:23:21 +0200 Subject: [PATCH 03/29] fix style --- docs/source/en/llm_optims.md | 2 +- .../models/mistral/modeling_mistral.py | 89 +++++++++---------- tests/models/mistral/test_modeling_mistral.py | 23 +++-- 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 4b44c1d78c81f0..3273f5dac41dfe 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up. > [!WARNING] -> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile. +> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma), [Llama](./model_doc/llama2) and [Mistral](./model_doc/mistral.md) models support static kv-cache and torch.compile. For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model. diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d2f711f143a02b..f77312e80fa83d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -17,10 +17,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
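
[Note] The `llm_optims.md` change in this commit adds Mistral to the models that support the static kv-cache together with `torch.compile`. For context, below is a minimal usage sketch of what that enables, modelled on the `test_compile_static_cache` test added in this patch; the model id, dtype and generation settings are only examples, and a GPU with enough memory plus `accelerate` installed are assumed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)

# "static" pre-allocates the kv-cache, so tensor shapes stay fixed across decode steps,
# which is what allows compiling the forward pass with fullgraph=True.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
out = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
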
-""" PyTorch Mistral model.""" +"""PyTorch Mistral model.""" + import inspect import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -88,8 +88,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral -# TODO @Arthur no longer copied from LLama after static cache +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -123,7 +122,7 @@ def cos_cached(self): "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._cos_cached - + @torch.no_grad def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] @@ -149,8 +148,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -# TODO @Arthur no longer copied from LLama after static cache +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -261,12 +259,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -278,12 +271,12 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) - + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position" : cache_position} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -342,7 +335,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, ): if isinstance(past_key_value, StaticCache): raise ValueError( @@ -365,7 +358,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length() + kv_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -605,8 +598,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral -# TODO @Arthur no longer copied from LLama after static cache +# Copied from 
transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral class MistralSdpaAttention(MistralAttention): """ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -669,7 +661,6 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: @@ -680,14 +671,14 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` is_causal = True if causal_mask is None and q_len > 1 else False - + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -741,11 +732,6 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -954,7 +940,7 @@ def forward( raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - + if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -972,13 +958,14 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - + if position_ids is None: position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, use_cache + ) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, use_cache) - hidden_states = inputs_embeds # decoder layers @@ -1041,21 +1028,20 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - - # copied from Llama implementation + def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - use_cache: bool + use_cache: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 - + if self._attn_implementation == "flash_attention_2": if attention_mask is not None and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] @@ -1065,9 +1051,10 @@ def _update_causal_mask( " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if attention_mask is not None and 0.0 in attention_mask: return attention_mask + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask return None - + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. @@ -1076,8 +1063,11 @@ def _update_causal_mask( using_static_cache = isinstance(past_key_values, StaticCache) if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, is_training=self.training, + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, ): return None @@ -1096,8 +1086,10 @@ def _update_causal_mask( causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4): - exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window) + if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2): + exclude_mask |= torch.arange(target_length, device=device) < ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) @@ -1120,9 +1112,9 @@ def _update_causal_mask( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -1258,6 +1250,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1272,9 +1265,6 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - # cache_length = past_key_values.get_seq_length() - # past_length = past_key_values.seen_tokens - # max_cache_length = past_key_values.get_max_length() past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) @@ -1282,6 +1272,7 @@ def prepare_inputs_for_generation( else None ) cache_length = past_length if 
max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1320,6 +1311,12 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + model_inputs.update( { "position_ids": position_ids, diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 1a45991dbefa7b..25c2cbc1f658a4 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Testing suite for the PyTorch Mistral model. """ - +"""Testing suite for the PyTorch Mistral model.""" import gc import tempfile @@ -471,13 +470,14 @@ def test_flash_attn_2_generate_use_cache(self): @slow def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - - # copied from Llama tests to supress errors for now + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format @unittest.skip("TODO @gante fix this for Mistral") @parameterized.expand([(1, False), (1, True), (4, False)]) def test_new_cache_format(self, num_beams, do_sample): pass + @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) @@ -634,7 +634,7 @@ def test_speculative_generation(self): del model backend_empty_cache(torch_device) gc.collect() - + @slow def test_compile_static_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 @@ -644,9 +644,14 @@ def test_compile_static_cache(self): NUM_TOKENS_TO_GENERATE = 40 EXPECTED_TEXT_COMPLETION = { - 8: ['My favourite condiment is 100% ketchup. I love it on everything. ' - 'I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles'], - # 7: [], + 8: [ + "My favourite condiment is 100% ketchup. I love it on everything. " + "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles" + ], + 7: [ + "My favourite condiment is 100% ketchup. I love it on everything. " + "I’m not a big fan of mustard, mayo, or relish. 
I’m not a fan of pickles" + ], } prompts = ["My favourite condiment is "] @@ -675,4 +680,4 @@ def test_compile_static_cache(self): **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) \ No newline at end of file + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) From e1173238e344f14f8b8c2ca035e102a9064d9e1e Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 9 May 2024 05:56:25 +0200 Subject: [PATCH 04/29] add sliding window cache --- src/transformers/cache_utils.py | 112 ++++++++++++++++++ src/transformers/generation/utils.py | 16 ++- .../open_llama/modeling_open_llama.py | 6 +- .../models/falcon/modeling_falcon.py | 6 +- .../models/gpt_neox/modeling_gpt_neox.py | 6 +- .../modeling_gpt_neox_japanese.py | 3 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/mistral/modeling_mistral.py | 64 +++++++--- .../models/mixtral/modeling_mixtral.py | 18 ++- .../models/persimmon/modeling_persimmon.py | 6 +- src/transformers/models/phi/modeling_phi.py | 6 +- .../models/qwen2/modeling_qwen2.py | 9 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 9 +- .../models/stablelm/modeling_stablelm.py | 6 +- .../models/starcoder2/modeling_starcoder2.py | 15 ++- tests/models/mistral/test_modeling_mistral.py | 48 +++++++- tests/models/mixtral/test_modeling_mixtral.py | 6 + tests/models/qwen2/test_modeling_qwen2.py | 6 + .../qwen2_moe/test_modeling_qwen2_moe.py | 6 + 19 files changed, 300 insertions(+), 51 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2e29e19ade46a4..da8fc9ebc24b8b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -448,3 +448,115 @@ def reset(self): # In-place ops prevent breaking the static address self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(Cache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. 
+ """ + + def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + super().__init__() + self.max_batch_size = max_batch_size + # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory + # when we do short-sentence generation + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.model_sliding_window_size = config.sliding_window + self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + cache_shape = ( + config.num_hidden_layers, + max_batch_size, + self.num_key_value_heads, + self.sliding_window_size, + self.head_dim, + ) + + self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + + torch._dynamo.mark_static_address(self.key_cache) + torch._dynamo.mark_static_address(self.value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Dict[str, Any] | None = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + # assume this only happens in prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > self.sliding_window_size: + k_out = key_states[:, :, -self.sliding_window_size :, :] + v_out = value_states[:, :, -self.sliding_window_size :, :] + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, self.sliding_window_size - 1) + to_shift = cache_position >= self.sliding_window_size - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size + + k_out, v_out = k_out, v_out + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + # assume this will be called only in the first generation step + # `cache_postion` will be used in other cases + return 0 + + def get_max_length(self) -> int | None: + # in theory there is no limit because the sliding window size is fixed + # no matter how long the sentence is + return None + + def need_new_cache(self, max_batch_size: int, new_max_cache_len: int) -> bool: + # this is used by model.generate, when we reuse model between generations, + # we need to be careful because the new `max_cache_len` may become + # larger and 
`self.sliding_window_size` might change accordingly + return max_batch_size > self.max_batch_size or ( + self.sliding_window_size < self.model_sliding_window_size and new_max_cache_len > self.max_cache_len + ) + + def reset(self): + self.key_cache.zero_() + self.value_cache.zero_() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3bff8eea50f0c6..934c186c73a0cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,7 @@ import torch.distributed as dist from torch import nn -from ..cache_utils import Cache, DynamicCache, StaticCache +from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -95,9 +95,7 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module -NEED_SETUP_CACHE_CLASSES_MAPPING = { - "static": StaticCache, -} +NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} @dataclass @@ -1594,6 +1592,16 @@ def generate( ) if generation_config.cache_implementation == "static": model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) + elif generation_config.cache_implementation == "sliding_window": + if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + model_kwargs["past_key_values"] = self._get_sliding_window_cache( + batch_size, generation_config.max_length + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 098f8c7da50d5e..e4f7677a671ef2 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,7 +63,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama +# TODO @longjie no longer copied from Mistral after static cache class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -154,7 +155,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
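
[Note] As an aside on the new `SlidingWindowCache` added in `cache_utils.py` above: once the window is full, its `update` method rolls the cached keys/values one slot to the left and writes the new token into the last slot. Below is a toy, self-contained sketch of that index-rolling trick; the window size and values are made up for readability, and the real cache operates on `[batch, heads, seq, head_dim]` tensors rather than a flat vector.

import torch

W = 4                                               # toy sliding window size
cache = torch.tensor([10.0, 11.0, 12.0, 13.0])      # tokens currently held in the window
new_token = torch.tensor([14.0])
cache_position = torch.tensor([4]).clamp(0, W - 1)  # position 4 is past the window, clamped to W - 1

slicing = torch.ones(W, dtype=torch.long).cumsum(0)  # [1, 2, 3, 4]
to_shift = cache_position >= W - 1                   # tensor([True]): the window is full
indices = (slicing + to_shift[-1].int() - 1) % W     # [1, 2, 3, 0]: roll everything left by one

cache = cache[indices]                               # [11., 12., 13., 10.]
cache[cache_position] = new_token                    # [11., 12., 13., 14.]: oldest token evicted
print(cache)
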
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 76ca4110e81848..ede15ca427c6a1 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -82,7 +82,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -124,7 +125,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +# TODO @longjie no longer copied from Mistral after static cache class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e0b2309fc9658b..815933e13ac279 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -522,7 +522,8 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # TODO @longjie no longer copied from Mistral after static cache def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -614,7 +615,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index ea934581aa4f21..b5dccf6ea1d6be 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -230,7 +230,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # TODO @longjie no longer copied from Mistral after static cache def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 622e336fe4034e..57eb1e14bb055c 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -478,7 +478,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f77312e80fa83d..1d50ee0999f918 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -88,7 +88,6 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -630,6 +629,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -644,7 +644,7 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # In case static cache is used, it is an instance attribute. past_key_value = getattr(self, "past_key_value", past_key_value) @@ -1059,9 +1059,12 @@ def _update_causal_mask( # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # cache_position must be valid here no matter which cache we use + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1074,8 +1077,13 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: target_length = past_key_values.get_max_length() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1086,12 +1094,30 @@ def _update_causal_mask( causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2): - exclude_mask |= torch.arange(target_length, device=device) < ( - cache_position.reshape(-1, 1) - self.config.sliding_window - ) - causal_mask *= exclude_mask + if self.config.sliding_window is not None: + if attention_mask is not None and attention_mask.dim() == 4: + logger.warning_once( + "Sliding window will not take effect when passing 4d custom masks" + "you may get unexpected results, use attention mask generated by tokenizer" + "or set model.config.sliding_window to None if you don't want sliding window" + ) + + # can only happen in prefill phase, when the prompt length > sliding window length, we need to do this + # manually because we are returning the whole prompt token sequence in `SlidingWindowCache`, maybe a better + # way is to support chunked prefill instead + if sequence_length > self.config.sliding_window and using_sliding_window_cache: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + + # not using `SlidingWindowCache` and attention mask supports sliding window + if (attention_mask is None or attention_mask.dim() == 2) and not using_sliding_window_cache: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + + causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit @@ -1112,9 +1138,9 @@ def _update_causal_mask( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( 
self.config._attn_implementation == "sdpa" @@ -1250,7 +1276,6 @@ def forward( attentions=outputs.attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1305,6 +1330,15 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache + if ( + past_length > 0 + and attention_mask is not None + and isinstance(past_key_values, SlidingWindowCache) + and attention_mask.shape[1] > past_key_values.sliding_window_size + ): + attention_mask = attention_mask[:, -past_key_values.sliding_window_size :] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e5a81c4c9083ed..9c1ab1c07d6230 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -181,7 +181,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -226,7 +227,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -268,7 +270,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -397,7 +400,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralFlashAttention2(MixtralAttention): """ Mixtral flash attention module. 
This module inherits from `MixtralAttention` as the weights of the module stays @@ -692,7 +696,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralSdpaAttention(MixtralAttention): """ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -1074,7 +1079,8 @@ def _init_weights(self, module): "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d4ad532074f19..2e05a0ecea51f0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,7 +40,8 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon +# TODO @longjie no longer copied from Mistral after static cache class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -132,7 +133,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 795ff18e5bcd1f..0fb3d1a63ec1a3 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -76,7 +76,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi +# TODO @longjie no longer copied from Mistral after static cache class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -168,7 +169,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b5a1370ae1fc8f..9e803da066aac6 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -90,7 +90,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 +# TODO @longjie no longer copied from Mistral after static cache class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -135,7 +136,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -620,7 +622,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 +# TODO @longjie no longer copied from Mistral after static cache class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 838425505b3b1a..9e0218c2a2383e 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -166,7 +166,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe +# TODO @longjie no longer copied from Mistral after static cache class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -211,7 +212,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -698,7 +700,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe +# TODO @longjie no longer copied from Mistral after static cache class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index bc133ffb3d7329..6dd2c58e5e91e6 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -66,7 +66,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm +# TODO @longjie no longer copied from Mistral after static cache class StableLmRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -158,7 +159,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 61e8518d659cae..8660ee06ad9b42 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,7 +70,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -115,7 +116,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -599,7 +601,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2SdpaAttention(Starcoder2Attention): """ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -708,7 +711,8 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - # Copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward + # copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward + # TODO @longjie no longer copied from Mistral after static cache def forward( self, hidden_states: torch.Tensor, @@ -1067,7 +1071,8 @@ def forward( ) -# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode +# copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 25c2cbc1f658a4..c3c685112a4c0c 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -471,7 +471,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format @unittest.skip("TODO @gante fix this for Mistral") @parameterized.expand([(1, False), (1, True), (4, False)]) def 
test_new_cache_format(self, num_beams, do_sample): @@ -681,3 +680,50 @@ def test_compile_static_cache(self): ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) + + @slow + def test_compile_sliding_window_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = { + 8: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in a vacuum, " + "and 2) the laws of physics are the same for all observers in uniform motion.\n\nThe first part of the theory is" + ], + 7: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in a vacuum, " + "and 2) the laws of physics are the same for all observers in uniform motion.\n\nThe first part of the theory is" + ], + } + + prompts = ["Simply put, the theory of relativity states that "] + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + tokenizer.pad_token = tokenizer.eos_token + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text) + + # Sliding Window Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Sliding Window Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8cfa85..d3ec1a7e18db80 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,6 +19,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( @@ -505,6 +506,11 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + @unittest.skip("TODO @gante fix this for Mixtral") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + 
pass + @require_torch class MixtralIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index f4e88a97f06a53..2e5de17ffc5433 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -20,6 +20,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed from transformers.testing_utils import ( @@ -481,6 +482,11 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") + @unittest.skip("TODO @gante fix this for Qwen2") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + @require_torch class Qwen2IntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index f0818e680d3da8..8620ddb56575a9 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -20,6 +20,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoTokenizer, Qwen2MoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -545,6 +546,11 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + @unittest.skip("TODO @gante fix this for Qwen2Moe") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + @require_torch class Qwen2MoeIntegrationTest(unittest.TestCase): From dec4904e2299f7cd7ecea3d93422928ea620980d Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 9 May 2024 06:02:30 +0200 Subject: [PATCH 05/29] fix style --- docs/source/en/llm_optims.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 3273f5dac41dfe..5e49f0e1ebd3ab 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up. > [!WARNING] -> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma), [Llama](./model_doc/llama2) and [Mistral](./model_doc/mistral.md) models support static kv-cache and torch.compile. +> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list. For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model. 
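For reference, the sliding-window cache wired up in this series is exercised end to end by `test_compile_sliding_window_cache` above. The following is a minimal sketch of that flow, assuming a Mistral checkpoint; the prompt, dtype, device map and token budget are taken from the test and are illustrative rather than required:

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# `cache_implementation="sliding_window"` makes `generate` allocate a SlidingWindowCache whose length is
# min(max_cache_len, config.sliding_window) instead of a full-length static cache.
generated_ids = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

# The pre-allocated cache tensors keep their addresses across decode steps, so the forward pass can be compiled.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

Because the cache tensors are registered with `torch._dynamo.mark_static_address`, the compiled decode step should not trigger cudagraph re-capture; as in the test above, torch >= 2.3 is assumed for the compiled variant.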
From 900615b346d0a5f52799024f2e30f5a1cb7563a1 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 02:28:01 +0200 Subject: [PATCH 06/29] address comments --- .../object-detection/run_object_detection.py | 4 +- setup.py | 10 +-- src/transformers/cache_utils.py | 11 --- src/transformers/generation/utils.py | 52 ++++++------ .../open_llama/modeling_open_llama.py | 6 +- .../models/falcon/modeling_falcon.py | 6 +- .../models/gemma/modeling_gemma.py | 13 +-- .../models/gpt_neox/modeling_gpt_neox.py | 6 +- .../modeling_gpt_neox_japanese.py | 3 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/llama/modeling_llama.py | 13 +-- .../models/mistral/modeling_mistral.py | 81 +++++++------------ src/transformers/models/olmo/modeling_olmo.py | 11 +-- .../models/persimmon/modeling_persimmon.py | 6 +- src/transformers/models/phi/modeling_phi.py | 6 +- .../models/qwen2/modeling_qwen2.py | 9 +-- .../models/qwen2_moe/modeling_qwen2_moe.py | 9 +-- .../models/stablelm/modeling_stablelm.py | 6 +- .../models/starcoder2/modeling_starcoder2.py | 15 ++-- tests/models/mistral/test_modeling_mistral.py | 6 -- tests/models/mixtral/test_modeling_mixtral.py | 6 -- tests/models/qwen2/test_modeling_qwen2.py | 6 -- .../qwen2_moe/test_modeling_qwen2_moe.py | 6 -- 23 files changed, 96 insertions(+), 198 deletions(-) diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index ba6ee1e55a481a..3f0769568f981a 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -244,7 +244,9 @@ class DataTrainingArguments: ) image_square_size: Optional[int] = field( default=600, - metadata={"help": "Image longest size will be resized to this value, then image will be padded to square."}, + metadata={ + "help": "Image longest size will be resized to this value, then image will be padded to square." 
+        },
     )
     max_train_samples: Optional[int] = field(
         default=None,
diff --git a/setup.py b/setup.py
index 3061127768db9b..89d334464c2480 100644
--- a/setup.py
+++ b/setup.py
@@ -260,15 +260,7 @@ def run(self):
 extras["sklearn"] = deps_list("scikit-learn")
 
 extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
-extras["tf-cpu"] = deps_list(
-    "keras",
-    "tensorflow-cpu",
-    "onnxconverter-common",
-    "tf2onnx",
-    "tensorflow-text",
-    "keras-nlp",
-    "tensorflow-probability",
-)
+extras["tf-cpu"] = deps_list("keras", "tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp", "tensorflow-probability")
 extras["torch"] = deps_list("torch", "accelerate")
 extras["accelerate"] = deps_list("accelerate")
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index da8fc9ebc24b8b..fca928a9adf0c5 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -485,8 +485,6 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
 
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
         cache_shape = (
             config.num_hidden_layers,
             max_batch_size,
@@ -527,7 +525,6 @@ def update(
         to_shift = cache_position >= self.sliding_window_size - 1
         indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
 
-        k_out, v_out = k_out, v_out
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]
 
@@ -549,14 +546,6 @@ def get_max_length(self) -> int | None:
         # in theory there is no limit because the sliding window size is fixed
         # no matter how long the sentence is
         return None
 
-    def need_new_cache(self, max_batch_size: int, new_max_cache_len: int) -> bool:
-        # this is used by model.generate, when we reuse model between generations,
-        # we need to be careful because the new `max_cache_len` may become
-        larger and `self.sliding_window_size` might change accordingly
-        return max_batch_size > self.max_batch_size or (
-            self.sliding_window_size < self.model_sliding_window_size and new_max_cache_len > self.max_cache_len
-        )
-
     def reset(self):
         self.key_cache.zero_()
         self.value_cache.zero_()
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 934c186c73a0cd..d2446ae68ee712 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1309,24 +1309,39 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
         return model_kwargs
 
-    def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
+    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
         """
-        Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
+        Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
         new `generate` call requires a larger cache.
 
-        Returns the resulting static cache object.
+        Returns the resulting cache object.
""" - needs_new_cache = ( - not hasattr(self, "_static_cache") - or self._static_cache.max_batch_size < max_batch_size - or self._static_cache.max_cache_len < max_cache_len + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] + need_new_cache = ( + not hasattr(self, "_cache") + or (not isinstance(self._cache, cache_cls)) + or self._cache.max_batch_size < max_batch_size ) - if needs_new_cache: + if cache_implementation == "sliding_window": + if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + need_new_cache = need_new_cache or ( + self._cache.sliding_window_size < self._cache.model_sliding_window_size + and max_cache_len > self._cache.max_cache_len + ) + elif cache_implementation == "static": + need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len + + if need_new_cache: if hasattr(self.config, "_pre_quantization_dtype"): cache_dtype = self.config._pre_quantization_dtype else: cache_dtype = self.dtype - self._static_cache = StaticCache( + self._cache = cache_cls( config=self.config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, @@ -1334,8 +1349,8 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa dtype=cache_dtype, ) else: - self._static_cache.reset() # reset the cache for a new generation - return self._static_cache + self._cache.reset() + return self._cache def _prepare_special_tokens( self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None @@ -1590,18 +1605,9 @@ def generate( "This model does not support the `cache_implementation` argument. Please check the following " "issue: https://github.com/huggingface/transformers/issues/28981." ) - if generation_config.cache_implementation == "static": - model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) - elif generation_config.cache_implementation == "sliding_window": - if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - model_kwargs["past_key_values"] = self._get_sliding_window_cache( - batch_size, generation_config.max_length - ) + model_kwargs["past_key_values"] = self._get_cache( + generation_config.cache_implementation, batch_size, generation_config.max_length + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index e4f7677a671ef2..4e42d716e895f9 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,8 +63,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -155,8 +154,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index ede15ca427c6a1..1e715f8482ae47 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -82,8 +82,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -125,8 +124,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8f7893704780d1..3d5aecad4c37df 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -16,7 +16,6 @@ """ PyTorch Gemma model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -250,7 +249,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -262,8 +260,9 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -617,7 +616,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -633,10 +631,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -651,7 +645,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 815933e13ac279..4980f7c636175a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -522,8 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ - # TODO @longjie no longer copied from Mistral after static cache + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -615,8 +614,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index b5dccf6ea1d6be..24d211317887cf 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -230,8 +230,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ - # TODO @longjie no longer copied from Mistral after static cache + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 57eb1e14bb055c..83a5cb65106383 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -478,8 +478,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d840b03faf71fb..58a2bab727e5b7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -20,7 +20,6 @@ """PyTorch LLaMA model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -115,7 +114,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s @property def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._sin_cached @@ -123,7 +122,7 @@ def sin_cached(self): @property def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._cos_cached @@ -326,7 +325,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -425,7 +423,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -714,7 +711,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -730,10 +726,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -748,7 +740,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1d50ee0999f918..510ea0137bf8d5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -107,22 +107,25 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.sin_cached with Llama->Mistral def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. 
It is not used in the `LlamaAttention` class" + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " + "the forward method of RoPE from now on instead. It is not used in the `MistralAttention` class" ) return self._sin_cached @property + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.cos_cached with Llama->Mistral def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " + "the forward method of RoPE from now on instead. It is not used in the `MistralAttention` class" ) return self._cos_cached - @torch.no_grad + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @@ -220,15 +223,15 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): "when creating this class." ) + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -246,9 +249,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): base=self.rope_theta, ) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral def forward( self, hidden_states: torch.Tensor, @@ -270,21 +271,20 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: + if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + 
causal_mask @@ -300,8 +300,8 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -696,12 +696,13 @@ def forward( } +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -712,17 +713,17 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -745,7 +746,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states @@ -963,7 +963,7 @@ def forward( position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, use_cache + attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions ) hidden_states = inputs_embeds @@ -1036,6 +1036,7 @@ def _update_causal_mask( cache_position: torch.Tensor, past_key_values: Cache, use_cache: bool, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. 
@@ -1102,17 +1103,7 @@ def _update_causal_mask( "you may get unexpected results, use attention mask generated by tokenizer" "or set model.config.sliding_window to None if you don't want sliding window" ) - - # can only happen in prefill phase, when the prompt length > sliding window length, we need to do this - # manually because we are returning the whole prompt token sequence in `SlidingWindowCache`, maybe a better - # way is to support chunked prefill instead - if sequence_length > self.config.sliding_window and using_sliding_window_cache: - exclude_mask |= torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - self.config.sliding_window - ) - - # not using `SlidingWindowCache` and attention mask supports sliding window - if (attention_mask is None or attention_mask.dim() == 2) and not using_sliding_window_cache: + elif not using_sliding_window_cache or sequence_length > self.config.sliding_window: exclude_mask |= torch.arange(target_length, device=device) <= ( cache_position.reshape(-1, 1) - self.config.sliding_window ) @@ -1123,29 +1114,17 @@ def _update_causal_mask( causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input lenght is deprecated and will be " - "removed in transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 5009ac84be2ea7..bcad5fb0f6566e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -20,7 +20,6 @@ """PyTorch OLMo model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -112,7 +111,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s @property def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. 
Use " "the forward method of RoPE from now on instead. It is not used in the `OlmoAttention` class" ) return self._sin_cached @@ -120,7 +119,7 @@ def sin_cached(self): @property def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " "the forward method of RoPE from now on instead. It is not used in the `OlmoAttention` class" ) return self._cos_cached @@ -690,7 +689,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -706,10 +704,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -724,7 +718,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 2e05a0ecea51f0..01d124fb9873fe 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,8 +40,7 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -133,8 +132,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0fb3d1a63ec1a3..7b79643a17ba8c 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -76,8 +76,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -169,8 +168,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 9e803da066aac6..a930a4bdcf7190 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -90,8 +90,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2 class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -136,8 +135,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -622,8 +620,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2 class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 9e0218c2a2383e..26a04d710bdc3d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -166,8 +166,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -212,8 +211,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -700,8 +698,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 6dd2c58e5e91e6..f6a8a8a2be2be3 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -66,8 +66,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm class StableLmRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -159,8 +158,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
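# A standalone sketch (toy sizes, plain tensors) of how the shared apply_rotary_pos_emb helper,
# re-pointed by the "Copied from" comments above, broadcasts its inputs: cos/sin of shape
# (batch, seq_len, head_dim) are unsqueezed on the head axis (unsqueeze_dim=1) so they broadcast
# against q/k of shape (batch, num_heads, seq_len, head_dim).
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, num_heads, seq_len, head_dim = 1, 4, 6, 8
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)

inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(seq_len, dtype=torch.float32)[None, :]  # (batch, seq_len)
freqs = position_ids[..., None] * inv_freq                          # (batch, seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                             # (batch, seq_len, head_dim)
cos, sin = emb.cos()[:, None], emb.sin()[:, None]                   # unsqueeze_dim=1

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 4, 6, 8]) torch.Size([1, 4, 6, 8])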
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 8660ee06ad9b42..5004f698417ff8 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,8 +70,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2 class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -116,8 +115,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -601,8 +599,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2 class Starcoder2SdpaAttention(Starcoder2Attention): """ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -711,8 +708,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - # copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward - # TODO @longjie no longer copied from Mistral after static cache + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward def forward( self, hidden_states: torch.Tensor, @@ -725,7 +721,8 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
" + "Please make sure use `attention_mask` instead.`" ) """ Args: diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index c3c685112a4c0c..7a08785af5a4fd 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -20,7 +20,6 @@ import pytest from packaging import version -from parameterized import parameterized from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -471,11 +470,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - @unittest.skip("TODO @gante fix this for Mistral") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index d3ec1a7e18db80..0d92595d8cfa85 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,7 +19,6 @@ import unittest import pytest -from parameterized import parameterized from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( @@ -506,11 +505,6 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) - @unittest.skip("TODO @gante fix this for Mixtral") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class MixtralIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 2e5de17ffc5433..f4e88a97f06a53 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -20,7 +20,6 @@ import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed from transformers.testing_utils import ( @@ -482,11 +481,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") - @unittest.skip("TODO @gante fix this for Qwen2") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class Qwen2IntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index 8620ddb56575a9..f0818e680d3da8 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -20,7 +20,6 @@ import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, Qwen2MoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -546,11 +545,6 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) - @unittest.skip("TODO @gante fix this for Qwen2Moe") - @parameterized.expand([(1, 
False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class Qwen2MoeIntegrationTest(unittest.TestCase): From e04d68b937601a8c302d7e3ce1a408a2931e7cbd Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 03:21:18 +0200 Subject: [PATCH 07/29] fix test --- .../models/mistral/modeling_mistral.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 510ea0137bf8d5..75952b80f10a55 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -226,7 +226,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim + self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings @@ -949,9 +949,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1013,13 +1015,10 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, DynamicCache) - else next_decoder_cache - ) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( From bb0811b172e69a48b6e75e492056e8cf41d1f94d Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 03:44:50 +0200 Subject: [PATCH 08/29] fix style --- src/transformers/models/gemma/modeling_gemma.py | 1 - src/transformers/models/mistral/modeling_mistral.py | 7 +------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 3d5aecad4c37df..2584c6f823701e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -260,7 +260,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 75952b80f10a55..9ccb71b8b22c96 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -270,7 +270,6 @@ def forward( key_states = 
key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -354,7 +353,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += cache_position[0] @@ -646,9 +644,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # In case static cache is used, it is an instance attribute. - past_key_value = getattr(self, "past_key_value", past_key_value) - if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} @@ -949,7 +944,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) From dd7ff33ead79621014f8ec877e1851ab2c8b686b Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 06:20:29 +0200 Subject: [PATCH 09/29] move sliding window check inside cache init --- src/transformers/cache_utils.py | 7 +++++++ src/transformers/generation/utils.py | 6 ------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index fca928a9adf0c5..d015af080ad33e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -468,6 +468,13 @@ class SlidingWindowCache(Cache): """ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + super().__init__() self.max_batch_size = max_batch_size # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d2446ae68ee712..d0c204e8be98b6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1323,12 +1323,6 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l or self._cache.max_batch_size < max_batch_size ) if cache_implementation == "sliding_window": - if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) need_new_cache = need_new_cache or ( self._cache.sliding_window_size < self._cache.model_sliding_window_size and max_cache_len > self._cache.max_cache_len From 3e08b7e29e68b4bfaa8406f6ab601e796e93a301 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 13 May 2024 04:35:37 +0200 Subject: [PATCH 10/29] add compile for mixtral --- .../object-detection/run_object_detection.py | 4 +- .../open_llama/modeling_open_llama.py | 4 +- .../models/falcon/modeling_falcon.py | 4 +- .../models/gpt_neox/modeling_gpt_neox.py | 4 +- .../modeling_gpt_neox_japanese.py | 2 +- .../models/idefics/modeling_idefics.py | 2 +- .../models/mixtral/modeling_mixtral.py | 445 ++++++++++-------- .../models/persimmon/modeling_persimmon.py | 4 +- src/transformers/models/phi/modeling_phi.py | 4 +- .../models/qwen2/modeling_qwen2.py | 9 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 6 +- .../models/stablelm/modeling_stablelm.py | 4 +- .../models/starcoder2/modeling_starcoder2.py | 8 +- 13 files changed, 279 insertions(+), 221 deletions(-) diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index 3f0769568f981a..ba6ee1e55a481a 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -244,9 +244,7 @@ class DataTrainingArguments: ) image_square_size: Optional[int] = field( default=600, - metadata={ - "help": "Image longest size will be resized to this value, then image will be padded to square." - }, + metadata={"help": "Image longest size will be resized to this value, then image will be padded to square."}, ) max_train_samples: Optional[int] = field( default=None, diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 4e42d716e895f9..7f39c65cbbe2c5 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,7 +63,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->OpenLlama class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -154,7 +154,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 1e715f8482ae47..86ce1be53fcaff 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -82,7 +82,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
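# A tiny standalone sketch of the check that the "move sliding window check inside cache init"
# patch above relocates into SlidingWindowCache.__init__: the cache now refuses to be built for
# a config whose sliding_window field is missing or None. ToyConfig is a made-up stand-in, not
# the real transformers config class.
class ToyConfig:
    sliding_window = None  # e.g. a checkpoint that does not use sliding-window attention

def validate_sliding_window(config):
    if not hasattr(config, "sliding_window") or config.sliding_window is None:
        raise ValueError(
            "Setting `cache_implementation` to 'sliding_window' requires the model config "
            "supporting sliding window attention, please check if there is a `sliding_window` "
            "field in the model config and it's not set to None."
        )

try:
    validate_sliding_window(ToyConfig())
except ValueError as err:
    print(err)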
@@ -124,7 +124,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Mixtral->Qwen2 class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 4980f7c636175a..83b562b1e85e45 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -522,7 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -614,7 +614,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 24d211317887cf..6e3612a09c7e52 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -230,7 +230,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 83a5cb65106383..9e9f70c4f025ed 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -478,7 +478,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
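# A standalone sketch (toy sizes, plain tensors rather than the library classes) of the
# position_ids-based rotary forward that the Mixtral diff below adopts: cos/sin are recomputed
# from position_ids in float32 with autocast disabled, instead of being sliced out of
# precomputed cos_cached/sin_cached buffers.
import torch

head_dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim))

x = torch.randn(2, 4, 5, head_dim, dtype=torch.bfloat16)  # (bs, num_heads, seq_len, head_dim)
position_ids = torch.arange(5)[None, :].expand(2, -1)     # (bs, seq_len)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
with torch.autocast(device_type="cpu", enabled=False):    # keep float32 for long-context precision
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (bs, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos().to(x.dtype), emb.sin().to(x.dtype)

print(cos.shape, sin.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])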
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9c1ab1c07d6230..db12c1aa545e2e 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -20,7 +20,6 @@ """ PyTorch Mixtral model.""" import inspect import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -30,18 +29,14 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -50,7 +45,6 @@ logging, replace_return_docstrings, ) -from ...utils.import_utils import is_torch_fx_available from .configuration_mixtral import MixtralConfig @@ -60,14 +54,6 @@ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - logger = logging.get_logger(__name__) @@ -181,8 +167,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -192,31 +177,49 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + @property + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.sin_cached with Llama->Mixtral + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " + "the forward method of RoPE from now on instead. It is not used in the `MixtralAttention` class" + ) + return self._sin_cached - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + @property + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.cos_cached with Llama->Mixtral + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " + "the forward method of RoPE from now on instead. It is not used in the `MixtralAttention` class" ) + return self._cos_cached + + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -227,9 +230,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
Args: @@ -237,9 +239,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -250,8 +251,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -270,8 +271,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral class MixtralAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -289,6 +289,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): "when creating this class." 
) + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -297,7 +298,6 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -307,7 +307,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.rotary_emb = MixtralRotaryEmbedding( self.head_dim, @@ -315,9 +315,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): base=self.rope_theta, ) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mixtral def forward( self, hidden_states: torch.Tensor, @@ -326,12 +324,8 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -342,41 +336,22 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -390,8 +365,8 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -426,15 +401,14 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -447,18 +421,9 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. 
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + kv_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) use_sliding_windows = ( @@ -696,8 +661,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral class MixtralSdpaAttention(MixtralAttention): """ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -714,6 +678,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -728,6 +693,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -740,41 +706,40 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. 
- if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -880,8 +845,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + # the `top_x` tensor here. Use non-inplace version to make cudagraphs happy. + final_hidden_states.index_add(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits @@ -906,12 +871,8 @@ def forward( output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -941,6 +902,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = residual + hidden_states @@ -1079,8 +1041,7 @@ def _init_weights(self, module): "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`] @@ -1125,6 +1086,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -1138,74 +1100,38 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = 0 - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions + ) + hidden_states = inputs_embeds # decoder layers @@ -1222,22 +1148,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1257,9 +1185,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple( @@ -1275,6 +1203,110 @@ def forward( router_logits=all_router_logits, ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + # cache_position must be valid here no matter which cache we use + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: + target_length = past_key_values.get_max_length() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + + if self.config.sliding_window is not None: + if attention_mask is not None and attention_mask.dim() == 4: + logger.warning_once( + "Sliding window will not take effect when passing 4d custom masks" + "you may get unexpected results, use attention mask generated by tokenizer" + "or set model.config.sliding_window to None if you don't want sliding window" + ) + elif not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + class MixtralForCausalLM(MixtralPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1324,6 +1356,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1373,6 +1406,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1426,14 +1460,21 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, output_router_logits=False, + cache_position=None, + use_cache=True, **kwargs, ): # Omit tokens covered by past_key_values if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1466,17 +1507,33 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache + if ( + past_length > 0 + and attention_mask is not None + and isinstance(past_key_values, SlidingWindowCache) + and attention_mask.shape[1] > past_key_values.sliding_window_size + ): + attention_mask = attention_mask[:, -past_key_values.sliding_window_size :] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, } diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 01d124fb9873fe..ad563f1112e5fa 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,7 +40,7 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon +# Copied from 
transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Mixtral->Qwen2 class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -132,7 +132,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 7b79643a17ba8c..6836bfc9c68b68 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -76,7 +76,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->Phi class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -168,7 +168,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index a930a4bdcf7190..c0f5b2b6bd63f1 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -90,7 +90,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2 +# copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2 +# TODO: @longjie no longer copied from after static cache class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -135,7 +136,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# TODO: @longjie no longer copied from after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -620,7 +622,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2 +# copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2 +# TODO: @longjie no longer copied from after static cache class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 26a04d710bdc3d..bf5e3187dc61cc 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -166,7 +166,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->Qwen2Moe class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -211,7 +211,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -698,7 +698,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index f6a8a8a2be2be3..193a6325321e53 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -66,7 +66,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->StableLm class StableLmRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -158,7 +158,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
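
The `# Copied from` retargeting above points these rotary-embedding classes at the position-ids-based implementation introduced earlier in this series, whose `forward` builds `cos`/`sin` directly from `position_ids` instead of slicing pre-computed caches. A minimal, self-contained sketch of that computation (the helper name `rotary_cos_sin` is illustrative and not part of the patch; the real module also handles autocast, dtype, and device placement):

```python
import torch

def rotary_cos_sin(position_ids: torch.Tensor, dim: int, base: float = 10000.0):
    """Toy version of the rotary forward: cos/sin built from absolute position ids."""
    # position_ids: [batch, seq_len] -> cos/sin: [batch, seq_len, dim]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # [1, dim//2, 1] @ [batch, 1, seq_len] -> [batch, dim//2, seq_len]
    freqs = (inv_freq[None, :, None] @ position_ids[:, None, :].float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # duplicate to cover the full head dimension
    return emb.cos(), emb.sin()

# Decoding step: a single new token at absolute position 5 in each of two sequences.
cos, sin = rotary_cos_sin(torch.tensor([[5], [5]]), dim=8)
print(cos.shape, sin.shape)  # torch.Size([2, 1, 8]) torch.Size([2, 1, 8])
```
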
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5004f698417ff8..3ac1d59fba1793 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,7 +70,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2 +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Mixtral->Qwen2 class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -115,7 +115,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# Copied from transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1068,8 +1068,7 @@ def forward( ) -# copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2 class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1102,6 +1101,7 @@ def get_decoder(self): @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy def forward( self, input_ids: torch.LongTensor = None, From dcac131d79a6937a8b4428ab9fcf267b997faa9c Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 3 May 2024 17:36:30 +0200 Subject: [PATCH 11/29] first version --- .../models/mistral/modeling_mistral.py | 400 +++++++++++------- tests/models/mistral/test_modeling_mistral.py | 51 ++- 2 files changed, 292 insertions(+), 159 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 665e95a8fd7856..ca96e778c4e43f 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,8 +30,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -99,31 +99,46 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._sin_cached - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) + return self._cos_cached + + @torch.no_grad + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -136,7 +151,7 @@ def rotate_half(x): # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb # TODO @Arthur no longer copied from LLama after static cache -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -144,9 +159,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -157,8 +171,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -227,7 +241,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.rotary_emb = MistralRotaryEmbedding( self.head_dim, @@ -246,6 +260,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: @@ -262,20 +277,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position" : cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -284,19 +292,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -346,13 +344,14 @@ def forward( use_cache: bool = False, **kwargs, ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") + output_attentions = False + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -363,21 +362,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + past_key_value = getattr(self, "past_key_value", past_key_value) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_seq_length() - # Because the input can be padded, the absolute sequence length depends on the max position id. 
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_sliding_windows = ( _flash_supports_window_size @@ -632,6 +623,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -658,41 +650,44 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # In case static cache is used, it is an instance attribute. + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -729,12 +724,9 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -749,6 +741,11 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -761,6 +758,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -934,12 +933,13 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -950,74 +950,35 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, use_cache) + hidden_states = inputs_embeds # decoder layers @@ -1033,20 +994,22 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1065,8 +1028,11 @@ def forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - + next_cache = ( + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1075,6 +1041,106 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) + + # copied from Llama implementation + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + + if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4): + # assume signed int tensor for cache_position + exclude_mask |= torch.arange(target_length, device=device) <= (cache_position.reshape(-1,1) - self.config.sliding_window) + + causal_mask *= exclude_mask + + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input lenght is deprecated and will be " + "removed in transformers v4.42.0" + ) + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask class MistralForCausalLM(MistralPreTrainedModel): @@ -1114,13 +1180,14 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1165,6 +1232,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1197,14 +1265,29 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + # cache_length = past_key_values.get_seq_length() + # past_length = past_key_values.seen_tokens + # max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1241,13 +1324,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index bbc36c050e23f0..1e2979f195ec8e 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -20,6 +20,8 @@ import unittest import pytest +from packaging import version +from parameterized import parameterized from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -469,7 +471,12 @@ def test_flash_attn_2_generate_use_cache(self): @slow def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - + + # copied from Llama tests to supress errors for now + 
@unittest.skip("TODO @gante fix this for Mistral") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): @@ -628,6 +635,48 @@ def test_speculative_generation(self): backend_empty_cache(torch_device) gc.collect() + @slow + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = { + 8: ['My favourite condiment is 100% ketchup. I love it on everything. ' + 'I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles'], + # 7: [], + } + + prompts = ["My favourite condiment is "] + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + tokenizer.pad_token = tokenizer.eos_token + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) + @slow @require_torch_gpu From 9afc73b949d1b846fb42c4f2affbd5d952043cf3 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 4 May 2024 07:23:42 +0200 Subject: [PATCH 12/29] fix sliding window --- src/transformers/models/mistral/modeling_mistral.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ca96e778c4e43f..d2f711f143a02b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1094,16 +1094,10 @@ def _update_causal_mask( ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4): - # assume signed int tensor for cache_position - exclude_mask |= torch.arange(target_length, 
device=device) <= (cache_position.reshape(-1,1) - self.config.sliding_window) - + exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window) causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) From 3fa9285c54cc55847db0a230060339a44a8b34b2 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 7 May 2024 06:23:21 +0200 Subject: [PATCH 13/29] fix style --- docs/source/en/llm_optims.md | 2 +- .../models/mistral/modeling_mistral.py | 89 +++++++++---------- tests/models/mistral/test_modeling_mistral.py | 19 ++-- 3 files changed, 56 insertions(+), 54 deletions(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 4b44c1d78c81f0..3273f5dac41dfe 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up. > [!WARNING] -> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile. +> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma), [Llama](./model_doc/llama2) and [Mistral](./model_doc/mistral.md) models support static kv-cache and torch.compile. For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model. diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d2f711f143a02b..f77312e80fa83d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -17,10 +17,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Mistral model.""" +"""PyTorch Mistral model.""" + import inspect import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -88,8 +88,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral -# TODO @Arthur no longer copied from LLama after static cache +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -123,7 +122,7 @@ def cos_cached(self): "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._cos_cached - + @torch.no_grad def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] @@ -149,8 +148,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -# TODO @Arthur no longer copied from LLama after static cache +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
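
With `cos`/`sin` produced per call, `apply_rotary_pos_emb` no longer indexes them with `position_ids`; it only unsqueezes them so they broadcast over the head dimension. A short sketch of the end-to-end application (helper names and shapes are illustrative; the rotation formula mirrors the patched function):

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: [batch, seq_len, head_dim] -> [batch, 1, seq_len, head_dim] so they
    # broadcast against q/k of shape [batch, num_heads, seq_len, head_dim].
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

batch, heads, seq, head_dim = 2, 4, 3, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, seq, head_dim)  # would come from the rotary embedding forward
sin = torch.randn(batch, seq, head_dim)
q_rot, k_rot = apply_rotary(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 3, 8]) twice
```
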
@@ -261,12 +259,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -278,12 +271,12 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) - + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position" : cache_position} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -342,7 +335,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, ): if isinstance(past_key_value, StaticCache): raise ValueError( @@ -365,7 +358,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length() + kv_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -605,8 +598,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral -# TODO @Arthur no longer copied from LLama after static cache +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral class MistralSdpaAttention(MistralAttention): """ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -669,7 +661,6 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. 
if query_states.device.type == "cuda" and causal_mask is not None: @@ -680,14 +671,14 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` is_causal = True if causal_mask is None and q_len > 1 else False - + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -741,11 +732,6 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -954,7 +940,7 @@ def forward( raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - + if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -972,13 +958,14 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - + if position_ids is None: position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, use_cache + ) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, use_cache) - hidden_states = inputs_embeds # decoder layers @@ -1041,21 +1028,20 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - - # copied from Llama implementation + def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - use_cache: bool + use_cache: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - + if self._attn_implementation == "flash_attention_2": if attention_mask is not None and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] @@ -1065,9 +1051,10 @@ def _update_causal_mask( " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if attention_mask is not None and 0.0 in attention_mask: return attention_mask + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask return None - + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. 
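
The hunks below tidy up `_update_causal_mask`, which now derives the 4D additive mask from `cache_position` instead of a `(past_length, sequence_length)` pair. A toy reproduction of the masking rule (the helper name and the numbers are illustrative; the comparison expressions mirror the patched code):

```python
import torch

def toy_causal_mask(cache_position, target_length, sliding_window=None, dtype=torch.float32):
    """Boolean exclusion -> additive mask, as in the patched _update_causal_mask."""
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((cache_position.numel(), target_length), min_dtype, dtype=dtype)
    # future positions are excluded ...
    exclude = torch.arange(target_length) > cache_position.reshape(-1, 1)
    # ... and, with sliding-window attention, positions that fell out of the window
    if sliding_window is not None:
        exclude |= torch.arange(target_length) < (cache_position.reshape(-1, 1) - sliding_window)
    return causal_mask * exclude  # 0.0 where attention is allowed, min_dtype where masked

# Prefill of 5 tokens with a sliding window of 2: each query keeps the most recent keys only.
mask = toy_causal_mask(torch.arange(5), target_length=5, sliding_window=2)
print((mask == 0).long())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])
```
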
@@ -1076,8 +1063,11 @@ def _update_causal_mask( using_static_cache = isinstance(past_key_values, StaticCache) if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, is_training=self.training, + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, ): return None @@ -1096,8 +1086,10 @@ def _update_causal_mask( causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4): - exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window) + if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2): + exclude_mask |= torch.arange(target_length, device=device) < ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) @@ -1120,9 +1112,9 @@ def _update_causal_mask( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -1258,6 +1250,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1272,9 +1265,6 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - # cache_length = past_key_values.get_seq_length() - # past_length = past_key_values.seen_tokens - # max_cache_length = past_key_values.get_max_length() past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) @@ -1282,6 +1272,7 @@ def prepare_inputs_for_generation( else None ) cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1320,6 +1311,12 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + model_inputs.update( { "position_ids": position_ids, diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 1e2979f195ec8e..6fb87c334488b1 100644 --- a/tests/models/mistral/test_modeling_mistral.py 
+++ b/tests/models/mistral/test_modeling_mistral.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Testing suite for the PyTorch Mistral model. """ - +"""Testing suite for the PyTorch Mistral model.""" import gc import tempfile @@ -471,13 +470,14 @@ def test_flash_attn_2_generate_use_cache(self): @slow def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - - # copied from Llama tests to supress errors for now + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format @unittest.skip("TODO @gante fix this for Mistral") @parameterized.expand([(1, False), (1, True), (4, False)]) def test_new_cache_format(self, num_beams, do_sample): pass + @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) @@ -644,9 +644,14 @@ def test_compile_static_cache(self): NUM_TOKENS_TO_GENERATE = 40 EXPECTED_TEXT_COMPLETION = { - 8: ['My favourite condiment is 100% ketchup. I love it on everything. ' - 'I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles'], - # 7: [], + 8: [ + "My favourite condiment is 100% ketchup. I love it on everything. " + "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles" + ], + 7: [ + "My favourite condiment is 100% ketchup. I love it on everything. " + "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles" + ], } prompts = ["My favourite condiment is "] From 5246cedc28f83e53609230526f83fc3135c75882 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 9 May 2024 05:56:25 +0200 Subject: [PATCH 14/29] add sliding window cache --- src/transformers/cache_utils.py | 112 ++++++++++++++++++ src/transformers/generation/utils.py | 16 ++- .../open_llama/modeling_open_llama.py | 6 +- .../models/falcon/modeling_falcon.py | 6 +- .../models/gpt_neox/modeling_gpt_neox.py | 6 +- .../modeling_gpt_neox_japanese.py | 3 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/mistral/modeling_mistral.py | 64 +++++++--- .../models/mixtral/modeling_mixtral.py | 18 ++- .../models/persimmon/modeling_persimmon.py | 6 +- src/transformers/models/phi/modeling_phi.py | 6 +- .../models/qwen2/modeling_qwen2.py | 9 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 9 +- .../models/stablelm/modeling_stablelm.py | 6 +- .../models/starcoder2/modeling_starcoder2.py | 15 ++- tests/models/mistral/test_modeling_mistral.py | 48 +++++++- tests/models/mixtral/test_modeling_mixtral.py | 6 + tests/models/qwen2/test_modeling_qwen2.py | 6 + .../qwen2_moe/test_modeling_qwen2_moe.py | 6 + 19 files changed, 300 insertions(+), 51 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2e29e19ade46a4..da8fc9ebc24b8b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -448,3 +448,115 @@ def reset(self): # In-place ops prevent breaking the static address self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(Cache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. 
+ + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + """ + + def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + super().__init__() + self.max_batch_size = max_batch_size + # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory + # when we do short-sentence generation + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.model_sliding_window_size = config.sliding_window + self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + cache_shape = ( + config.num_hidden_layers, + max_batch_size, + self.num_key_value_heads, + self.sliding_window_size, + self.head_dim, + ) + + self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + + torch._dynamo.mark_static_address(self.key_cache) + torch._dynamo.mark_static_address(self.value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Dict[str, Any] | None = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + # assume this only happens in prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > self.sliding_window_size: + k_out = key_states[:, :, -self.sliding_window_size :, :] + v_out = value_states[:, :, -self.sliding_window_size :, :] + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, self.sliding_window_size - 1) + to_shift = cache_position >= self.sliding_window_size - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size + + k_out, v_out = k_out, v_out + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: 
+        # assume this will be called only in the first generation step
+        # `cache_position` will be used in other cases
+        return 0
+
+    def get_max_length(self) -> int | None:
+        # in theory there is no limit because the sliding window size is fixed
+        # no matter how long the sentence is
+        return None
+
+    def need_new_cache(self, max_batch_size: int, new_max_cache_len: int) -> bool:
+        # this is used by model.generate, when we reuse model between generations,
+        # we need to be careful because the new `max_cache_len` may become
+        # larger and `self.sliding_window_size` might change accordingly
+        return max_batch_size > self.max_batch_size or (
+            self.sliding_window_size < self.model_sliding_window_size and new_max_cache_len > self.max_cache_len
+        )
+
+    def reset(self):
+        self.key_cache.zero_()
+        self.value_cache.zero_()
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1c90fdd30753e5..85b798ef06a5db 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
 import torch.distributed as dist
 from torch import nn
 
-from ..cache_utils import Cache, DynamicCache, StaticCache
+from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from ..models.auto import (
@@ -95,9 +95,7 @@
 if is_accelerate_available():
     from accelerate.hooks import AlignDevicesHook, add_hook_to_module
 
-NEED_SETUP_CACHE_CLASSES_MAPPING = {
-    "static": StaticCache,
-}
+NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
 
 @dataclass
@@ -1603,6 +1601,16 @@ def generate(
                 )
             if generation_config.cache_implementation == "static":
                 model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
+            elif generation_config.cache_implementation == "sliding_window":
+                if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None:
+                    raise ValueError(
+                        "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+                        "sliding window attention, please check if there is a `sliding_window` field in the model "
+                        "config and it's not set to None."
+ ) + model_kwargs["past_key_values"] = self._get_sliding_window_cache( + batch_size, generation_config.max_length + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 098f8c7da50d5e..e4f7677a671ef2 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,7 +63,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama +# TODO @longjie no longer copied from Mistral after static cache class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -154,7 +155,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index a171c875dbdc0a..d56ed8747d6273 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -82,7 +82,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -124,7 +125,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +# TODO @longjie no longer copied from Mistral after static cache class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e0b2309fc9658b..815933e13ac279 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -522,7 +522,8 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # TODO @longjie no longer copied from Mistral after static cache def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -614,7 +615,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index ea934581aa4f21..b5dccf6ea1d6be 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -230,7 +230,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # TODO @longjie no longer copied from Mistral after static cache def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 622e336fe4034e..57eb1e14bb055c 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -478,7 +478,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
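Before the Mistral modeling changes below: the `SlidingWindowCache.update` added to `cache_utils.py` earlier in this patch keeps the window in chronological order by rotating it one slot at a time before writing the newest entry. A minimal 1-D sketch of that indexing (my own toy `write` helper, not the class itself):

```python
# Toy 1-D version of the shift-and-write indexing in SlidingWindowCache.update();
# `write` is a hypothetical helper, not part of the patch.
import torch

window = 4
cache = torch.zeros(window, dtype=torch.long)  # stands in for one slot dimension of the kv cache


def write(cache: torch.Tensor, pos: int, token: int) -> torch.Tensor:
    slicing = torch.ones(window, dtype=torch.long).cumsum(0)  # [1, 2, ..., window]
    clamped = min(pos, window - 1)
    to_shift = clamped >= window - 1
    indices = (slicing + int(to_shift) - 1) % window  # identity, or rotate left by one slot
    cache = cache[indices]  # rotation moves the oldest entry into the slot that is about to be overwritten
    cache[clamped] = token  # newest entry lands in the last slot once the window boundary is reached
    return cache


for step in range(7):
    cache = write(cache, step, step + 1)  # tokens 1..7 at positions 0..6
print(cache)  # tensor([4, 5, 6, 7]): only the most recent `window` tokens remain, in order
```

The real class applies the same indices to the cached 4-D key/value tensors, and during prefill it returns the full prompt keys/values when the prompt is longer than the window.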
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f77312e80fa83d..1d50ee0999f918 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -88,7 +88,6 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -630,6 +629,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -644,7 +644,7 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # In case static cache is used, it is an instance attribute. past_key_value = getattr(self, "past_key_value", past_key_value) @@ -1059,9 +1059,12 @@ def _update_causal_mask( # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. 
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        # cache_position must be valid here no matter which cache we use
+        past_seen_tokens = cache_position[0] if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+        if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1074,8 +1077,13 @@ def _update_causal_mask(
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if using_static_cache:
+        # SlidingWindowCache
+        if using_sliding_window_cache:
+            target_length = max(sequence_length, self.config.sliding_window)
+        # StaticCache
+        elif using_static_cache:
             target_length = past_key_values.get_max_length()
+        # DynamicCache or no cache
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1086,12 +1094,30 @@ def _update_causal_mask(
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
         exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2):
-            exclude_mask |= torch.arange(target_length, device=device) < (
-                cache_position.reshape(-1, 1) - self.config.sliding_window
-            )
-        causal_mask *= exclude_mask
+        if self.config.sliding_window is not None:
+            if attention_mask is not None and attention_mask.dim() == 4:
+                logger.warning_once(
+                    "Sliding window will not take effect when passing 4d custom masks; "
+                    "you may get unexpected results. Use the attention mask generated by the tokenizer, "
+                    "or set model.config.sliding_window to None if you don't want sliding window."
+                )
+
+            # can only happen in prefill phase, when the prompt length > sliding window length, we need to do this
+            # manually because we are returning the whole prompt token sequence in `SlidingWindowCache`, maybe a better
+            # way is to support chunked prefill instead
+            if sequence_length > self.config.sliding_window and using_sliding_window_cache:
+                exclude_mask |= torch.arange(target_length, device=device) <= (
+                    cache_position.reshape(-1, 1) - self.config.sliding_window
+                )
+
+            # not using `SlidingWindowCache` and attention mask supports sliding window
+            if (attention_mask is None or attention_mask.dim() == 2) and not using_sliding_window_cache:
+                exclude_mask |= torch.arange(target_length, device=device) <= (
+                    cache_position.reshape(-1, 1) - self.config.sliding_window
+                )
+
+        causal_mask *= exclude_mask
         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
@@ -1112,9 +1138,9 @@ def _update_causal_mask(
                     offset = 0
                 mask_shape = attention_mask.shape
                 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
-                    mask_slice
-                )
+                causal_mask[
+                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+                ] = mask_slice
 
         if (
             self.config._attn_implementation == "sdpa"
@@ -1250,7 +1276,6 @@ def forward(
             attentions=outputs.attentions,
         )
 
-    #
Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1305,6 +1330,15 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache + if ( + past_length > 0 + and attention_mask is not None + and isinstance(past_key_values, SlidingWindowCache) + and attention_mask.shape[1] > past_key_values.sliding_window_size + ): + attention_mask = attention_mask[:, -past_key_values.sliding_window_size :] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e5a81c4c9083ed..9c1ab1c07d6230 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -181,7 +181,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -226,7 +227,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -268,7 +270,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -397,7 +400,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralFlashAttention2(MixtralAttention): """ Mixtral flash attention module. 
This module inherits from `MixtralAttention` as the weights of the module stays @@ -692,7 +696,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralSdpaAttention(MixtralAttention): """ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -1074,7 +1079,8 @@ def _init_weights(self, module): "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d4ad532074f19..2e05a0ecea51f0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,7 +40,8 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon +# TODO @longjie no longer copied from Mistral after static cache class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -132,7 +133,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 795ff18e5bcd1f..0fb3d1a63ec1a3 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -76,7 +76,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi +# TODO @longjie no longer copied from Mistral after static cache class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -168,7 +169,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 709504aba7157c..6c9a8b4af38fcc 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -90,7 +90,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 +# TODO @longjie no longer copied from Mistral after static cache class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -135,7 +136,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -620,7 +622,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 +# TODO @longjie no longer copied from Mistral after static cache class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 838425505b3b1a..9e0218c2a2383e 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -166,7 +166,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe +# TODO @longjie no longer copied from Mistral after static cache class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -211,7 +212,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -698,7 +700,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe +# TODO @longjie no longer copied from Mistral after static cache class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index bc133ffb3d7329..6dd2c58e5e91e6 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -66,7 +66,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm +# TODO @longjie no longer copied from Mistral after static cache class StableLmRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -158,7 +159,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 61e8518d659cae..8660ee06ad9b42 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,7 +70,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -115,7 +116,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -599,7 +601,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2SdpaAttention(Starcoder2Attention): """ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -708,7 +711,8 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - # Copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward + # copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward + # TODO @longjie no longer copied from Mistral after static cache def forward( self, hidden_states: torch.Tensor, @@ -1067,7 +1071,8 @@ def forward( ) -# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode +# copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 6fb87c334488b1..11996f83d34622 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -471,7 +471,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format @unittest.skip("TODO @gante fix this for Mistral") @parameterized.expand([(1, False), (1, True), (4, False)]) def 
test_new_cache_format(self, num_beams, do_sample): @@ -681,6 +680,53 @@ def test_compile_static_cache(self): ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) + + @slow + def test_compile_sliding_window_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = { + 8: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in a vacuum, " + "and 2) the laws of physics are the same for all observers in uniform motion.\n\nThe first part of the theory is" + ], + 7: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in a vacuum, " + "and 2) the laws of physics are the same for all observers in uniform motion.\n\nThe first part of the theory is" + ], + } + + prompts = ["Simply put, the theory of relativity states that "] + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + tokenizer.pad_token = tokenizer.eos_token + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text) + + # Sliding Window Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Sliding Window Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) @slow diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8cfa85..d3ec1a7e18db80 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,6 +19,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( @@ -505,6 +506,11 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + @unittest.skip("TODO @gante fix this for Mixtral") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, 
do_sample): + pass + @require_torch class MixtralIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index f4e88a97f06a53..2e5de17ffc5433 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -20,6 +20,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed from transformers.testing_utils import ( @@ -481,6 +482,11 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") + @unittest.skip("TODO @gante fix this for Qwen2") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + @require_torch class Qwen2IntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index f0818e680d3da8..8620ddb56575a9 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -20,6 +20,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoTokenizer, Qwen2MoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -545,6 +546,11 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + @unittest.skip("TODO @gante fix this for Qwen2Moe") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + @require_torch class Qwen2MoeIntegrationTest(unittest.TestCase): From 6367154e8ab703c7a8fa32e676f0471bca1e4ae6 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 9 May 2024 06:02:30 +0200 Subject: [PATCH 15/29] fix style --- docs/source/en/llm_optims.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 3273f5dac41dfe..5e49f0e1ebd3ab 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up. > [!WARNING] -> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma), [Llama](./model_doc/llama2) and [Mistral](./model_doc/mistral.md) models support static kv-cache and torch.compile. +> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list. For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model. 
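For context between these patches: the sliding-window cache added in patch 14 is requested through `generate` in the same way as the static cache. A usage sketch mirroring the new `test_compile_sliding_window_cache` test (assumes a CUDA GPU and access to the `mistralai/Mistral-7B-v0.1` checkpoint; not part of the patch series):

```python
# Sketch of requesting the sliding-window kv cache from generate, optionally combined with torch.compile.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)

# Pre-allocated sliding-window kv cache; compile the forward pass as the test does
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```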
From c74b329b22867a781891d90fa260471db14389d8 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 02:28:01 +0200 Subject: [PATCH 16/29] address comments --- .../object-detection/run_object_detection.py | 4 +- setup.py | 10 +-- src/transformers/cache_utils.py | 11 --- src/transformers/generation/utils.py | 52 ++++++------ .../open_llama/modeling_open_llama.py | 6 +- .../models/falcon/modeling_falcon.py | 6 +- .../models/gemma/modeling_gemma.py | 13 +-- .../models/gpt_neox/modeling_gpt_neox.py | 6 +- .../modeling_gpt_neox_japanese.py | 3 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/llama/modeling_llama.py | 13 +-- .../models/mistral/modeling_mistral.py | 81 +++++++------------ src/transformers/models/olmo/modeling_olmo.py | 11 +-- .../models/persimmon/modeling_persimmon.py | 6 +- src/transformers/models/phi/modeling_phi.py | 6 +- .../models/qwen2/modeling_qwen2.py | 9 +-- .../models/qwen2_moe/modeling_qwen2_moe.py | 9 +-- .../models/stablelm/modeling_stablelm.py | 6 +- .../models/starcoder2/modeling_starcoder2.py | 15 ++-- tests/models/mistral/test_modeling_mistral.py | 6 -- tests/models/mixtral/test_modeling_mixtral.py | 6 -- tests/models/qwen2/test_modeling_qwen2.py | 6 -- .../qwen2_moe/test_modeling_qwen2_moe.py | 6 -- 23 files changed, 96 insertions(+), 198 deletions(-) diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index ba6ee1e55a481a..3f0769568f981a 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -244,7 +244,9 @@ class DataTrainingArguments: ) image_square_size: Optional[int] = field( default=600, - metadata={"help": "Image longest size will be resized to this value, then image will be padded to square."}, + metadata={ + "help": "Image longest size will be resized to this value, then image will be padded to square." 
+ }, ) max_train_samples: Optional[int] = field( default=None, diff --git a/setup.py b/setup.py index 3061127768db9b..89d334464c2480 100644 --- a/setup.py +++ b/setup.py @@ -260,15 +260,7 @@ def run(self): extras["sklearn"] = deps_list("scikit-learn") extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") -extras["tf-cpu"] = deps_list( - "keras", - "tensorflow-cpu", - "onnxconverter-common", - "tf2onnx", - "tensorflow-text", - "keras-nlp", - "tensorflow-probability", -) +extras["tf-cpu"] = deps_list("keras", "tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp", "tensorflow-probability") extras["torch"] = deps_list("torch", "accelerate") extras["accelerate"] = deps_list("accelerate") diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index da8fc9ebc24b8b..fca928a9adf0c5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -485,8 +485,6 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] cache_shape = ( config.num_hidden_layers, max_batch_size, @@ -527,7 +525,6 @@ def update( to_shift = cache_position >= self.sliding_window_size - 1 indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size - k_out, v_out = k_out, v_out k_out = k_out[:, :, indices] v_out = v_out[:, :, indices] @@ -549,14 +546,6 @@ def get_max_length(self) -> int | None: # no matter how long the sentence is return None - def need_new_cache(self, max_batch_size: int, new_max_cache_len: int) -> bool: - # this is used by model.generate, when we reuse model between generations, - # we need to be careful because the new `max_cache_len` may become - # larger and `self.sliding_window_size` might change accordingly - return max_batch_size > self.max_batch_size or ( - self.sliding_window_size < self.model_sliding_window_size and new_max_cache_len > self.max_cache_len - ) - def reset(self): self.key_cache.zero_() self.value_cache.zero_() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 85b798ef06a5db..ec149de32e14e1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1310,24 +1310,39 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) return model_kwargs - def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache: + def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache: """ - Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a + Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a new `generate` call requires a larger cache. - Returns the resulting static cache object. + Returns the resulting cache object. 
""" - needs_new_cache = ( - not hasattr(self, "_static_cache") - or self._static_cache.max_batch_size < max_batch_size - or self._static_cache.max_cache_len < max_cache_len + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] + need_new_cache = ( + not hasattr(self, "_cache") + or (not isinstance(self._cache, cache_cls)) + or self._cache.max_batch_size < max_batch_size ) - if needs_new_cache: + if cache_implementation == "sliding_window": + if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + need_new_cache = need_new_cache or ( + self._cache.sliding_window_size < self._cache.model_sliding_window_size + and max_cache_len > self._cache.max_cache_len + ) + elif cache_implementation == "static": + need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len + + if need_new_cache: if hasattr(self.config, "_pre_quantization_dtype"): cache_dtype = self.config._pre_quantization_dtype else: cache_dtype = self.dtype - self._static_cache = StaticCache( + self._cache = cache_cls( config=self.config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, @@ -1335,8 +1350,8 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa dtype=cache_dtype, ) else: - self._static_cache.reset() # reset the cache for a new generation - return self._static_cache + self._cache.reset() + return self._cache def _prepare_special_tokens( self, @@ -1599,18 +1614,9 @@ def generate( "This model does not support the `cache_implementation` argument. Please check the following " "issue: https://github.com/huggingface/transformers/issues/28981." ) - if generation_config.cache_implementation == "static": - model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) - elif generation_config.cache_implementation == "sliding_window": - if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - model_kwargs["past_key_values"] = self._get_sliding_window_cache( - batch_size, generation_config.max_length - ) + model_kwargs["past_key_values"] = self._get_cache( + generation_config.cache_implementation, batch_size, generation_config.max_length + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index e4f7677a671ef2..4e42d716e895f9 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,8 +63,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -155,8 +154,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d56ed8747d6273..bfa18af295d80c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -82,8 +82,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -125,8 +124,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 12d01a6ea04d3e..3c59724888aba1 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -16,7 +16,6 @@ """ PyTorch Gemma model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -250,7 +249,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -262,8 +260,9 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -617,7 +616,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -633,10 +631,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -651,7 +645,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 815933e13ac279..4980f7c636175a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -522,8 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ - # TODO @longjie no longer copied from Mistral after static cache + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -615,8 +614,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index b5dccf6ea1d6be..24d211317887cf 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -230,8 +230,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ - # TODO @longjie no longer copied from Mistral after static cache + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 57eb1e14bb055c..83a5cb65106383 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -478,8 +478,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c6da59fcfb3edc..f30f9d633e7db6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -20,7 +20,6 @@ """PyTorch LLaMA model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -115,7 +114,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s @property def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._sin_cached @@ -123,7 +122,7 @@ def sin_cached(self): @property def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._cos_cached @@ -326,7 +325,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -425,7 +423,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -714,7 +711,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -730,10 +726,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -748,7 +740,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1d50ee0999f918..510ea0137bf8d5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -107,22 +107,25 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.sin_cached with Llama->Mistral def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. 
It is not used in the `LlamaAttention` class" + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " + "the forward method of RoPE from now on instead. It is not used in the `MistralAttention` class" ) return self._sin_cached @property + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.cos_cached with Llama->Mistral def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " + "the forward method of RoPE from now on instead. It is not used in the `MistralAttention` class" ) return self._cos_cached - @torch.no_grad + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @@ -220,15 +223,15 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): "when creating this class." ) + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -246,9 +249,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): base=self.rope_theta, ) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral def forward( self, hidden_states: torch.Tensor, @@ -270,21 +271,20 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: + if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + 
causal_mask @@ -300,8 +300,8 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -696,12 +696,13 @@ def forward( } +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -712,17 +713,17 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -745,7 +746,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states @@ -963,7 +963,7 @@ def forward( position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, use_cache + attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions ) hidden_states = inputs_embeds @@ -1036,6 +1036,7 @@ def _update_causal_mask( cache_position: torch.Tensor, past_key_values: Cache, use_cache: bool, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. 
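# Minimal sketch (not part of the patch) of the core mask logic that the `_update_causal_mask`
# changes in this series converge on: a key position is masked (set to the dtype minimum) if it
# lies in the future of the query's absolute cache position OR if it has fallen more than
# `sliding_window` tokens behind it. The helper name and the shapes below are illustrative.
import torch

def sliding_window_causal_mask(sequence_length, target_length, cache_position, sliding_window, dtype=torch.float32):
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    # True where the key position is ahead of the query's absolute position (future tokens)
    exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
    # True where the key position has dropped out of the sliding window behind the query
    exclude_mask |= torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
    # excluded positions keep min_dtype, attendable positions become 0.0
    return causal_mask * exclude_mask

# e.g. decoding queries at absolute positions 4..7 against 8 cache slots with a window of 3
mask = sliding_window_causal_mask(4, 8, torch.arange(4, 8), sliding_window=3)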
@@ -1102,17 +1103,7 @@ def _update_causal_mask( "you may get unexpected results, use attention mask generated by tokenizer" "or set model.config.sliding_window to None if you don't want sliding window" ) - - # can only happen in prefill phase, when the prompt length > sliding window length, we need to do this - # manually because we are returning the whole prompt token sequence in `SlidingWindowCache`, maybe a better - # way is to support chunked prefill instead - if sequence_length > self.config.sliding_window and using_sliding_window_cache: - exclude_mask |= torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - self.config.sliding_window - ) - - # not using `SlidingWindowCache` and attention mask supports sliding window - if (attention_mask is None or attention_mask.dim() == 2) and not using_sliding_window_cache: + elif not using_sliding_window_cache or sequence_length > self.config.sliding_window: exclude_mask |= torch.arange(target_length, device=device) <= ( cache_position.reshape(-1, 1) - self.config.sliding_window ) @@ -1123,29 +1114,17 @@ def _update_causal_mask( causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input lenght is deprecated and will be " - "removed in transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6a7b2f748fcf03..924c16a78b43e3 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -20,7 +20,6 @@ """PyTorch OLMo model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -112,7 +111,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s @property def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The sin_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. 
Use " "the forward method of RoPE from now on instead. It is not used in the `OlmoAttention` class" ) return self._sin_cached @@ -120,7 +119,7 @@ def sin_cached(self): @property def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "The cos_cached attribute will be removed in 4.42. Bear in mind that its contents changed in v4.40. Use " "the forward method of RoPE from now on instead. It is not used in the `OlmoAttention` class" ) return self._cos_cached @@ -690,7 +689,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -706,10 +704,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -724,7 +718,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 2e05a0ecea51f0..01d124fb9873fe 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,8 +40,7 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -133,8 +132,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0fb3d1a63ec1a3..7b79643a17ba8c 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -76,8 +76,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -169,8 +168,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6c9a8b4af38fcc..76212277621843 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -90,8 +90,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2 class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -136,8 +135,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -622,8 +620,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2 class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 9e0218c2a2383e..26a04d710bdc3d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -166,8 +166,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -212,8 +211,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -700,8 +698,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 6dd2c58e5e91e6..f6a8a8a2be2be3 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -66,8 +66,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm class StableLmRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -159,8 +158,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
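# Minimal usage sketch (not part of the patch) showing how the sliding-window static cache wired up
# in this series is exercised end to end, mirroring the compile tests added in later commits of the
# same series. The checkpoint, prompt and token count are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
inputs = tokenizer(["Sliding window attention lets the model"], return_tensors="pt").to(device)

# compile the forward pass; the static sliding-window cache keeps shapes fixed across decode steps
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
generated_ids = model.generate(
    **inputs, max_new_tokens=20, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))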
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 8660ee06ad9b42..5004f698417ff8 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,8 +70,7 @@ def _get_unpad_data(attention_mask): ) -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2 class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -116,8 +115,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -601,8 +599,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2 class Starcoder2SdpaAttention(Starcoder2Attention): """ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -711,8 +708,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - # copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward - # TODO @longjie no longer copied from Mistral after static cache + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward def forward( self, hidden_states: torch.Tensor, @@ -725,7 +721,8 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
" + "Please make sure use `attention_mask` instead.`" ) """ Args: diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 11996f83d34622..654bce492cfe3e 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -20,7 +20,6 @@ import pytest from packaging import version -from parameterized import parameterized from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -471,11 +470,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - @unittest.skip("TODO @gante fix this for Mistral") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index d3ec1a7e18db80..0d92595d8cfa85 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,7 +19,6 @@ import unittest import pytest -from parameterized import parameterized from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( @@ -506,11 +505,6 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) - @unittest.skip("TODO @gante fix this for Mixtral") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class MixtralIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 2e5de17ffc5433..f4e88a97f06a53 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -20,7 +20,6 @@ import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed from transformers.testing_utils import ( @@ -482,11 +481,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") - @unittest.skip("TODO @gante fix this for Qwen2") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class Qwen2IntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index 8620ddb56575a9..f0818e680d3da8 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -20,7 +20,6 @@ import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, Qwen2MoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -546,11 +545,6 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) - @unittest.skip("TODO @gante fix this for Qwen2Moe") - @parameterized.expand([(1, 
False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class Qwen2MoeIntegrationTest(unittest.TestCase): From 1cd711c878ccb879856667b7c2b70514797154d4 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 03:21:18 +0200 Subject: [PATCH 17/29] fix test --- .../models/mistral/modeling_mistral.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 510ea0137bf8d5..75952b80f10a55 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -226,7 +226,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim + self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings @@ -949,9 +949,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1013,13 +1015,10 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, DynamicCache) - else next_decoder_cache - ) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( From 06b64ca88f70a87d7f7166896a63c84d4299d3e5 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 03:44:50 +0200 Subject: [PATCH 18/29] fix style --- src/transformers/models/gemma/modeling_gemma.py | 1 - src/transformers/models/mistral/modeling_mistral.py | 7 +------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 3c59724888aba1..bb35f7ac13eda4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -260,7 +260,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 75952b80f10a55..9ccb71b8b22c96 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -270,7 +270,6 @@ def forward( key_states = 
key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -354,7 +353,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += cache_position[0] @@ -646,9 +644,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # In case static cache is used, it is an instance attribute. - past_key_value = getattr(self, "past_key_value", past_key_value) - if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} @@ -949,7 +944,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) From d46c601c154675424e958a601f397f6cc13da576 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 10 May 2024 06:20:29 +0200 Subject: [PATCH 19/29] move sliding window check inside cache init --- src/transformers/cache_utils.py | 7 +++++++ src/transformers/generation/utils.py | 6 ------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index fca928a9adf0c5..d015af080ad33e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -468,6 +468,13 @@ class SlidingWindowCache(Cache): """ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + super().__init__() self.max_batch_size = max_batch_size # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ec149de32e14e1..4e2c5b4b1a3111 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1324,12 +1324,6 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l or self._cache.max_batch_size < max_batch_size ) if cache_implementation == "sliding_window": - if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) need_new_cache = need_new_cache or ( self._cache.sliding_window_size < self._cache.model_sliding_window_size and max_cache_len > self._cache.max_cache_len From ec8f338cd1e1ef62b4617637b6e5d7f61f1ee0d5 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 13 May 2024 20:47:21 +0200 Subject: [PATCH 20/29] revert changes on irrelevant files & add comment on SlidingWindowCache --- .../pytorch/object-detection/run_object_detection.py | 4 +--- setup.py | 10 +++++++++- src/transformers/cache_utils.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index 3f0769568f981a..ba6ee1e55a481a 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -244,9 +244,7 @@ class DataTrainingArguments: ) image_square_size: Optional[int] = field( default=600, - metadata={ - "help": "Image longest size will be resized to this value, then image will be padded to square." - }, + metadata={"help": "Image longest size will be resized to this value, then image will be padded to square."}, ) max_train_samples: Optional[int] = field( default=None, diff --git a/setup.py b/setup.py index 89d334464c2480..3061127768db9b 100644 --- a/setup.py +++ b/setup.py @@ -260,7 +260,15 @@ def run(self): extras["sklearn"] = deps_list("scikit-learn") extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") -extras["tf-cpu"] = deps_list("keras", "tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp", "tensorflow-probability") +extras["tf-cpu"] = deps_list( + "keras", + "tensorflow-cpu", + "onnxconverter-common", + "tf2onnx", + "tensorflow-text", + "keras-nlp", + "tensorflow-probability", +) extras["torch"] = deps_list("torch", "accelerate") extras["accelerate"] = deps_list("accelerate") diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d015af080ad33e..ae6d30857241d6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -453,6 +453,18 @@ def reset(self): class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`, + if true we need to do a cycle shift on the current cache to replace the old states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window_size. 
Thus with `sliding_window_size==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`) Parameters: config (`PretrainedConfig): From a51b44f9df01c3abf28717528af0d2100ca82dae Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 13 May 2024 21:16:47 +0200 Subject: [PATCH 21/29] address comments & fix style fix style --- src/transformers/cache_utils.py | 5 +++-- src/transformers/models/starcoder2/modeling_starcoder2.py | 4 ++-- tests/models/mistral/test_modeling_mistral.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ae6d30857241d6..657dc48762bdb5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -454,8 +454,9 @@ class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`, - if true we need to do a cycle shift on the current cache to replace the old states by the new key value states passed in. - + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + The `to_shift` is only true once we are above sliding_window_size. 
Thus with `sliding_window_size==64`: indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5004f698417ff8..d6ee644f2f0735 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1068,8 +1068,7 @@ def forward( ) -# copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode -# TODO @longjie no longer copied from Mistral after static cache +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2 class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1102,6 +1101,7 @@ def get_decoder(self): @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy def forward( self, input_ids: torch.LongTensor = None, diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 654bce492cfe3e..fa629b6dac3a44 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -674,7 +674,7 @@ def test_compile_static_cache(self): ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) - + @slow def test_compile_sliding_window_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 From c1fca1aae44e7fbe57ce696c27faac68ec4f8378 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 13 May 2024 22:29:16 +0200 Subject: [PATCH 22/29] update causal mask --- .../models/mistral/modeling_mistral.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 9ccb71b8b22c96..ffe8302b049812 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1086,33 +1086,32 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - - exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - - if self.config.sliding_window is not None: - if attention_mask is not None and attention_mask.dim() == 4: - logger.warning_once( - "Sliding window will not take effect when passing 4d custom masks" - "you may get unexpected results, use attention mask generated by tokenizer" - "or set model.config.sliding_window to None if you don't want sliding window" - ) - elif not using_sliding_window_cache or sequence_length > self.config.sliding_window: - exclude_mask |= torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - self.config.sliding_window - ) - - causal_mask *= exclude_mask - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if 
attention_mask.dim() == 2: - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if self.config.sliding_window is not None: + if not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) if ( self.config._attn_implementation == "sdpa" From 6d0bf35842bdb04a399049368d7a4de058d74511 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 14 May 2024 01:53:48 +0200 Subject: [PATCH 23/29] fix style --- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d701b042631c13..8d043e70c6c8c1 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -124,7 +124,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Mixtral->Qwen2 +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->Falcon class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index ad563f1112e5fa..d343f1cb7190ef 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,7 +40,7 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Mixtral->Qwen2 +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->Persimmon class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py 
b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3ac1d59fba1793..98f028408b4089 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,7 +70,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Mixtral->Qwen2 +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding with Qwen2->Starcoder2 class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -599,7 +599,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2 +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Starcoder2 class Starcoder2SdpaAttention(Starcoder2Attention): """ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from From 66de109b6cdd49a249ae63433c5dbb9933fdf18d Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 14 May 2024 02:10:28 +0200 Subject: [PATCH 24/29] revert setup.py --- setup.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 89d334464c2480..3061127768db9b 100644 --- a/setup.py +++ b/setup.py @@ -260,7 +260,15 @@ def run(self): extras["sklearn"] = deps_list("scikit-learn") extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") -extras["tf-cpu"] = deps_list("keras", "tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp", "tensorflow-probability") +extras["tf-cpu"] = deps_list( + "keras", + "tensorflow-cpu", + "onnxconverter-common", + "tf2onnx", + "tensorflow-text", + "keras-nlp", + "tensorflow-probability", +) extras["torch"] = deps_list("torch", "accelerate") extras["accelerate"] = deps_list("accelerate") From 210179b8d79f6ab7a3cfa9f4ae870a690030e858 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 14 May 2024 04:33:25 +0200 Subject: [PATCH 25/29] fix some bug --- .../models/mixtral/modeling_mixtral.py | 57 ++++++++++--------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index baf033f10732e0..0cb2e9768bd687 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -849,7 +849,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. Use non-inplace version to make cudagraphs happy. 
- final_hidden_states.index_add(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits @@ -1207,6 +1207,7 @@ def forward( router_logits=all_router_logits, ) + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._update_causal_mask with Mistral->Mixtral def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1270,33 +1271,32 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - - exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - - if self.config.sliding_window is not None: - if attention_mask is not None and attention_mask.dim() == 4: - logger.warning_once( - "Sliding window will not take effect when passing 4d custom masks" - "you may get unexpected results, use attention mask generated by tokenizer" - "or set model.config.sliding_window to None if you don't want sliding window" - ) - elif not using_sliding_window_cache or sequence_length > self.config.sliding_window: - exclude_mask |= torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - self.config.sliding_window - ) - - causal_mask *= exclude_mask - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if self.config.sliding_window is not None: + if not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) if ( self.config._attn_implementation == "sdpa" @@ -1469,6 +1469,7 @@ def prepare_inputs_for_generation( **kwargs, ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if 
isinstance(past_key_values, Cache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() From c6bf7a124332e60a2e6664bd5e1ecb6f1b386f22 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 15 May 2024 04:12:53 +0200 Subject: [PATCH 26/29] attempt --- src/transformers/models/mixtral/modeling_mixtral.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 0cb2e9768bd687..1e663d7e7f7917 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -848,7 +848,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. Use non-inplace version to make cudagraphs happy. + # the `top_x` tensor here. + # final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + # still suffers from `skipping cudagrahs due to ['incompatible ops']` final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits From 93680eaebd5f92fee527b9781e3f84157106b8a3 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 15 May 2024 04:13:44 +0200 Subject: [PATCH 27/29] attempt --- src/transformers/models/mixtral/modeling_mixtral.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1e663d7e7f7917..c5f85b26af89a0 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -848,10 +848,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. + # the `top_x` tensor here. 
this will give `skipping cudagraphs due to index put with accumulate` + # in compile # final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - # still suffers from `skipping cudagrahs due to ['incompatible ops']` + # still suffers from `skipping cudagraphs due to ['incompatible ops']` final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits From 18fa18692527319bc07926a508322db0b43ce42f Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 16 May 2024 03:20:06 +0200 Subject: [PATCH 28/29] attempt --- .../models/mixtral/modeling_mixtral.py | 88 ++++++++++++++++-- tests/models/mixtral/test_modeling_mixtral.py | 93 +++++++++++++++++++ 2 files changed, 171 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 4889f5cff36719..2b468a8c442b95 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -786,6 +786,80 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +class MixtralBlockTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.num_experts = config.num_local_experts + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim)) + self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.ffn_dim)) + self.w3 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim)) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor + ) -> torch.Tensor: + """_summary_ + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_dim) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + + Returns: + torch.Tensor: _description_ + """ + + ts, tk = hidden_states.size(0), selected_experts.size(-1) + + w1 = self.w1[selected_experts] # (batch_size * token_num, top_k, ffn_dim, hidden_dim) + w2 = self.w2[selected_experts] # (batch_size * token_num, top_k, hidden_dim, ffn_dim) + w3 = self.w3[selected_experts] # (batch_size * token_num, ffn_dim, hidden_dim) + + x1 = torch.matmul(w1, hidden_states[:, None, :, None]) + x3 = torch.matmul(w3, hidden_states[:, None, :, None]) + x1 = self.act_fn(x1) + final_hidden_states = torch.matmul(w2, x1 * x3).reshape(ts, tk, self.hidden_dim) + final_hidden_states = final_hidden_states * routing_weights[:, :, None] + final_hidden_states = final_hidden_states.sum(dim=1) + return final_hidden_states + + +class MixtralMoeBlock(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.experts = MixtralBlockTop2MLP(config) + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + 
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + class MixtralSparseMoeBlock(nn.Module): """ This implementation is @@ -840,20 +914,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) - # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. this will give `skipping cudagraphs due to index put with accumulate` - # in compile - # final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - - # still suffers from `skipping cudagraphs due to ['incompatible ops']` - final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype) + final_hidden_states[top_x].index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits @@ -866,6 +932,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) + # self.block_sparse_moe = MixtralMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -954,7 +1021,7 @@ def forward( "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +# copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Mixtral class MixtralPreTrainedModel(PreTrainedModel): config_class = MixtralConfig base_model_prefix = "model" @@ -963,6 +1030,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8cfa85..59097d660e4fe4 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ 
b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,6 +19,7 @@ import unittest import pytest +from packaging import version from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( @@ -604,3 +605,95 @@ def test_small_model_logits_batched(self): atol=1e-3, rtol=1e-3, ) + + @slow + @require_torch_gpu + def test_compile_sliding_window_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + NUM_TOKENS_TO_GENERATE = 20 + EXPECTED_TOKEN_COMPLETION = { + 8: [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 18261, + 21705, + 9341, + 20302, + 31111, + 5535, + 10439, + 3799, + 12334, + 28929, + 15688, + 8388, + 15592, + 4507, + 12986, + 13895, + 14997, + 30984, + 23273, + 17094, + ], + 7: [], + } + + model_id = "hf-internal-testing/Mixtral-tiny" + dummy_input = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]]).to(torch_device) + attention_mask = dummy_input.ne(0).to(torch.long) + + model = MixtralForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to( + torch_device + ) + + # ugly hack + with torch.no_grad(): + from transformers.models.mixtral.modeling_mixtral import MixtralMoeBlock, MixtralSparseMoeBlock + + for decode_layer in model.model.layers: + original_block: MixtralSparseMoeBlock = getattr(decode_layer, "block_sparse_moe") + new_block: MixtralMoeBlock = MixtralMoeBlock(model.config).to(torch_device) + new_block.gate = original_block.gate + for i in range(model.config.num_local_experts): + new_block.experts.w1[i].copy_(original_block.experts[i].w1.weight) + new_block.experts.w2[i].copy_(original_block.experts[i].w2.weight) + new_block.experts.w3[i].copy_(original_block.experts[i].w3.weight) + new_block.experts.w1.data = new_block.experts.w1.data.to(original_block.experts[i].w1.weight.dtype) + new_block.experts.w2.data = new_block.experts.w2.data.to(original_block.experts[i].w2.weight.dtype) + new_block.experts.w3.data = new_block.experts.w3.data.to(original_block.experts[i].w3.weight.dtype) + setattr(decode_layer, "block_sparse_moe", new_block) + del original_block + + inputs = {"input_ids": dummy_input, "attention_mask": attention_mask} + with torch.no_grad(): + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + + self.assertEqual( + EXPECTED_TOKEN_COMPLETION[self.cuda_compute_capability_major_version], generated_ids.tolist()[0] + ) + + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + self.assertEqual( + EXPECTED_TOKEN_COMPLETION[self.cuda_compute_capability_major_version], generated_ids.tolist()[0] + ) + + model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + self.assertEqual( + EXPECTED_TOKEN_COMPLETION[self.cuda_compute_capability_major_version], generated_ids.tolist()[0] + ) From 9b2c1049756978081b99588231d4cb0245c733d4 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 16 May 2024 06:03:34 +0200 Subject: [PATCH 29/29] fix some bug --- src/transformers/models/mixtral/modeling_mixtral.py | 8 ++++++-- 1 file changed, 6 
insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 2b468a8c442b95..d9e58efaa9165f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -919,7 +919,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - final_hidden_states[top_x].index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits @@ -1315,7 +1315,11 @@ def _update_causal_mask( using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache): + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor,