diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 4b44c1d78c81f0..3273f5dac41dfe 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste
 The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.
 
 > [!WARNING]
-> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
+> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma), [Llama](./model_doc/llama2) and [Mistral](./model_doc/mistral.md) models support static kv-cache and torch.compile.
 
 For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index d2f711f143a02b..f77312e80fa83d 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -17,10 +17,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Mistral model."""
+"""PyTorch Mistral model."""
+
 import inspect
 import math
-import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -88,8 +88,7 @@ def forward(self, hidden_states):
         return self.weight * hidden_states.to(input_dtype)
 
 
-# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
 class MistralRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -123,7 +122,7 @@ def cos_cached(self):
             "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
         )
         return self._cos_cached
-
+
     @torch.no_grad
     def forward(self, x, position_ids):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -149,8 +148,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
@@ -261,12 +259,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-            )
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -278,12 +271,12 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
-
+
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position" : cache_position}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
@@ -342,7 +335,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        **kwargs,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -365,7 +358,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length()
+            kv_seq_len += cache_position[0]
 
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -605,8 +598,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         )
 
 
-# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
 class MistralSdpaAttention(MistralAttention):
     """
     Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -669,7 +661,6 @@ def forward(
 
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
@@ -680,14 +671,14 @@ def forward(
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
         # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
         is_causal = True if causal_mask is None and q_len > 1 else False
-
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -741,11 +732,6 @@ def forward(
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
         """
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-            )
-
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -954,7 +940,7 @@ def forward(
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
-
+
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
@@ -972,13 +958,14 @@ def forward(
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
-
+
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, use_cache
+        )
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, use_cache)
-
        hidden_states = inputs_embeds
 
        # decoder layers
@@ -1041,21 +1028,20 @@ def forward(
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
-
-    # copied from Llama implementation
+
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
-        use_cache: bool
+        use_cache: bool,
    ):
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
+
        if self._attn_implementation == "flash_attention_2":
            if attention_mask is not None and use_cache:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
@@ -1065,9 +1051,10 @@ def _update_causal_mask(
                        " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
-            if attention_mask is not None and 0.0 in attention_mask: return attention_mask
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
            return None
-
+
        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
@@ -1076,8 +1063,11 @@ def _update_causal_mask(
        using_static_cache = isinstance(past_key_values, StaticCache)
        if self.config._attn_implementation == "sdpa" and not using_static_cache:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens,
-                sliding_window=self.config.sliding_window, is_training=self.training,
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
+                is_training=self.training,
            ):
                return None
@@ -1096,8 +1086,10 @@
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4):
-            exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window)
+        if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2):
+            exclude_mask |= torch.arange(target_length, device=device) < (
+                cache_position.reshape(-1, 1) - self.config.sliding_window
+            )
        causal_mask *= exclude_mask
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
@@ -1120,9 +1112,9 @@
                offset = 0
            mask_shape = attention_mask.shape
            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
        if (
            self.config._attn_implementation == "sdpa"
@@ -1258,6 +1250,7 @@ def forward(
            attentions=outputs.attentions,
        )
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
@@ -1272,9 +1265,6 @@ def prepare_inputs_for_generation(
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
-                # cache_length = past_key_values.get_seq_length()
-                # past_length = past_key_values.seen_tokens
-                # max_cache_length = past_key_values.get_max_length()
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1282,6 +1272,7 @@ def prepare_inputs_for_generation(
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
@@ -1320,6 +1311,12 @@ def prepare_inputs_for_generation(
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}
 
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        elif use_cache:
+            cache_position = cache_position[-input_length:]
+
        model_inputs.update(
            {
                "position_ids": position_ids,
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 1a45991dbefa7b..25c2cbc1f658a4 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Testing suite for the PyTorch Mistral model. """
-
+"""Testing suite for the PyTorch Mistral model."""
 
 import gc
 import tempfile
@@ -471,13 +470,14 @@ def test_flash_attn_2_generate_use_cache(self):
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        self.skipTest("Mistral flash attention does not support right padding")
-
-    # copied from Llama tests to supress errors for now
+
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format
    @unittest.skip("TODO @gante fix this for Mistral")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass
 
+
 @require_torch_gpu
 class MistralIntegrationTest(unittest.TestCase):
    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
@@ -634,7 +634,7 @@ def test_speculative_generation(self):
        del model
        backend_empty_cache(torch_device)
        gc.collect()
-
+
    @slow
    def test_compile_static_cache(self):
        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
@@ -644,9 +644,14 @@ def test_compile_static_cache(self):
        NUM_TOKENS_TO_GENERATE = 40
        EXPECTED_TEXT_COMPLETION = {
-            8: ['My favourite condiment is 100% ketchup. I love it on everything. '
-                'I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles'],
-            # 7: [],
+            8: [
+                "My favourite condiment is 100% ketchup. I love it on everything. "
+                "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
+            ],
+            7: [
+                "My favourite condiment is 100% ketchup. I love it on everything. "
+                "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
+            ],
        }
 
        prompts = ["My favourite condiment is "]
@@ -675,4 +680,4 @@ def test_compile_static_cache(self):
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
\ No newline at end of file
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
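
For anyone who wants to try the change end to end, below is a minimal usage sketch of what this diff enables: greedy generation with the static kv-cache and a compiled forward pass on a Mistral checkpoint, along the lines of the new `test_compile_static_cache` test above. The checkpoint id, dtype, device placement and prompt are illustrative assumptions, not values taken from the diff.

```python
# Sketch only (not part of the patch): static kv-cache + torch.compile with Mistral.
# Assumes a transformers build that includes this change; the checkpoint id, dtype
# and device handling below are placeholders chosen for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)

# The static cache pre-allocates the key/value tensors to a fixed size, so every
# decode step sees constant shapes and the compiled graph is not re-captured
# token by token.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

generated_ids = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="static"
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```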