From 19bd6cd27adc07e4d6e5d3a10172b4a76aa427c2 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 9 May 2024 05:56:25 +0200 Subject: [PATCH] add sliding window cache --- setup.py | 19 +-- src/transformers/cache_utils.py | 112 ++++++++++++++++++ src/transformers/generation/utils.py | 43 ++++++- .../open_llama/modeling_open_llama.py | 6 +- .../models/falcon/modeling_falcon.py | 6 +- .../models/gpt_neox/modeling_gpt_neox.py | 6 +- .../modeling_gpt_neox_japanese.py | 3 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/mistral/modeling_mistral.py | 64 +++++++--- .../models/mixtral/modeling_mixtral.py | 18 ++- .../models/persimmon/modeling_persimmon.py | 6 +- src/transformers/models/phi/modeling_phi.py | 6 +- .../models/qwen2/modeling_qwen2.py | 9 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 9 +- .../models/stablelm/modeling_stablelm.py | 6 +- .../models/starcoder2/modeling_starcoder2.py | 15 ++- tests/models/mistral/test_modeling_mistral.py | 48 +++++++- tests/models/mixtral/test_modeling_mixtral.py | 6 + tests/models/qwen2/test_modeling_qwen2.py | 6 + .../qwen2_moe/test_modeling_qwen2_moe.py | 6 + 20 files changed, 338 insertions(+), 59 deletions(-) diff --git a/setup.py b/setup.py index 23021e8affb0a8..3061127768db9b 100644 --- a/setup.py +++ b/setup.py @@ -260,7 +260,15 @@ def run(self): extras["sklearn"] = deps_list("scikit-learn") extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") -extras["tf-cpu"] = deps_list("keras", "tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp", "tensorflow-probability") +extras["tf-cpu"] = deps_list( + "keras", + "tensorflow-cpu", + "onnxconverter-common", + "tf2onnx", + "tensorflow-text", + "keras-nlp", + "tensorflow-probability", +) extras["torch"] = deps_list("torch", "accelerate") extras["accelerate"] = deps_list("accelerate") @@ -380,12 +388,7 @@ def run(self): + extras["tf-speech"] ) extras["dev"] = ( - extras["all"] - + extras["testing"] - + extras["quality"] - + extras["ja"] - + extras["sklearn"] - + extras["modelcreation"] + extras["all"] + extras["testing"] + extras["quality"] + extras["ja"] + extras["sklearn"] + extras["modelcreation"] ) extras["torchhub"] = deps_list( @@ -470,4 +473,4 @@ def run(self): extras["tests_examples_tf"] = deps_list() extras["tests_custom_tokenizers"] = deps_list() extras["tests_exotic_models"] = deps_list() -extras["consistency"] = deps_list() \ No newline at end of file +extras["consistency"] = deps_list() diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2e29e19ade46a4..da8fc9ebc24b8b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -448,3 +448,115 @@ def reset(self): # In-place ops prevent breaking the static address self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(Cache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. 
+ dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + """ + + def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + super().__init__() + self.max_batch_size = max_batch_size + # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory + # when we do short-sentence generation + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.model_sliding_window_size = config.sliding_window + self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + cache_shape = ( + config.num_hidden_layers, + max_batch_size, + self.num_key_value_heads, + self.sliding_window_size, + self.head_dim, + ) + + self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + + torch._dynamo.mark_static_address(self.key_cache) + torch._dynamo.mark_static_address(self.value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Dict[str, Any] | None = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + # assume this only happens in prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > self.sliding_window_size: + k_out = key_states[:, :, -self.sliding_window_size :, :] + v_out = value_states[:, :, -self.sliding_window_size :, :] + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, self.sliding_window_size - 1) + to_shift = cache_position >= self.sliding_window_size - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size + + k_out, v_out = k_out, v_out + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + # assume this will be called only in the first generation step + # `cache_postion` will be used in other cases + return 0 + + def get_max_length(self) -> int | None: + # in theory there is no limit because the sliding window size is fixed + # no matter how long the sentence is + return None + + def need_new_cache(self, max_batch_size: int, new_max_cache_len: int) -> bool: + # this is used by model.generate, when we reuse model between 
generations, + # we need to be careful because the new `max_cache_len` may become + # larger and `self.sliding_window_size` might change accordingly + return max_batch_size > self.max_batch_size or ( + self.sliding_window_size < self.model_sliding_window_size and new_max_cache_len > self.max_cache_len + ) + + def reset(self): + self.key_cache.zero_() + self.value_cache.zero_() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7f4caf26aeac7d..6521e121880ae9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,7 @@ import torch.distributed as dist from torch import nn -from ..cache_utils import Cache, DynamicCache, StaticCache +from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -94,9 +94,7 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module -NEED_SETUP_CACHE_CLASSES_MAPPING = { - "static": StaticCache, -} +NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} @dataclass @@ -1342,6 +1340,33 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa self._static_cache.reset() # reset the cache for a new generation return self._static_cache + # maybe a better way is to use a single factory function to set caches for models ? + def _get_sliding_window_cache(self, max_batch_size: int, max_cache_len: int) -> SlidingWindowCache: + """ + Sets a sliding window cache for `generate`, that will persist across calls. A new cache will only be initialized a + new `generate` call requires a larger cache. + + Returns the resulting sliding window cache object. + """ + needs_new_cache = not hasattr(self, "_sliding_window_cache") or self._sliding_window_cache.need_new_cache( + max_batch_size, max_cache_len + ) + if needs_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + self._sliding_window_cache = SlidingWindowCache( + config=self.config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.device, + dtype=cache_dtype, + ) + else: + self._sliding_window_cache.reset() # reset the cache for a new generation + return self._sliding_window_cache + @torch.no_grad() def generate( self, @@ -1557,6 +1582,16 @@ def generate( ) if generation_config.cache_implementation == "static": model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) + elif generation_config.cache_implementation == "sliding_window": + if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." 
+ ) + model_kwargs["past_key_values"] = self._get_sliding_window_cache( + batch_size, generation_config.max_length + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 098f8c7da50d5e..e4f7677a671ef2 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,7 +63,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama +# TODO @longjie no longer copied from Mistral after static cache class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -154,7 +155,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 1f4fd41afa2e89..d48d97c64f193e 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -84,7 +84,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -126,7 +127,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +# TODO @longjie no longer copied from Mistral after static cache class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e338c529abf293..e560e55f5dbea2 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -525,7 +525,8 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # TODO @longjie no longer copied from Mistral after static cache def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -617,7 +618,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 9fdff2c8387006..2ae01f4fecd8ab 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -233,7 +233,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ + # TODO @longjie no longer copied from Mistral after static cache def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index a01c2279c15586..a5d5c7e2f28878 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -481,7 +481,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f77312e80fa83d..1d50ee0999f918 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -88,7 +88,6 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -630,6 +629,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -644,7 +644,7 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # In case static cache is used, it is an instance attribute. past_key_value = getattr(self, "past_key_value", past_key_value) @@ -1059,9 +1059,12 @@ def _update_causal_mask( # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # cache_position must be valid here no matter which cache we use + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1074,8 +1077,13 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: target_length = past_key_values.get_max_length() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1086,12 +1094,30 @@ def _update_causal_mask( causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2): - exclude_mask |= torch.arange(target_length, device=device) < ( - cache_position.reshape(-1, 1) - self.config.sliding_window - ) - causal_mask *= exclude_mask + if self.config.sliding_window is not None: + if attention_mask is not None and attention_mask.dim() == 4: + logger.warning_once( + "Sliding window will not take effect when passing 4d custom masks" + "you may get unexpected results, use attention mask generated by tokenizer" + "or set model.config.sliding_window to None if you don't want sliding window" + ) + + # can only happen in prefill phase, when the prompt length > sliding window length, we need to do this + # manually because we are returning the whole prompt token sequence in `SlidingWindowCache`, maybe a better + # way is to support chunked prefill instead + if sequence_length > self.config.sliding_window and using_sliding_window_cache: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + + # not using `SlidingWindowCache` and attention mask supports sliding window + if (attention_mask is None or attention_mask.dim() == 2) and not using_sliding_window_cache: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + + causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit @@ -1112,9 +1138,9 @@ def _update_causal_mask( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" @@ -1250,7 +1276,6 @@ def forward( attentions=outputs.attentions, ) - # 
Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1305,6 +1330,15 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache + if ( + past_length > 0 + and attention_mask is not None + and isinstance(past_key_values, SlidingWindowCache) + and attention_mask.shape[1] > past_key_values.sliding_window_size + ): + attention_mask = attention_mask[:, -past_key_values.sliding_window_size :] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e5a81c4c9083ed..9c1ab1c07d6230 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -181,7 +181,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -226,7 +227,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -268,7 +270,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -397,7 +400,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralFlashAttention2(MixtralAttention): """ Mixtral flash attention module. 
This module inherits from `MixtralAttention` as the weights of the module stays @@ -692,7 +696,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralSdpaAttention(MixtralAttention): """ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -1074,7 +1079,8 @@ def _init_weights(self, module): "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d4ad532074f19..2e05a0ecea51f0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,7 +40,8 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon +# TODO @longjie no longer copied from Mistral after static cache class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -132,7 +133,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index b23073d332e4d5..133ab90c0fe5fd 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -79,7 +79,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi +# TODO @longjie no longer copied from Mistral after static cache class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -171,7 +172,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b5a1370ae1fc8f..9e803da066aac6 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -90,7 +90,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 +# TODO @longjie no longer copied from Mistral after static cache class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -135,7 +136,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -620,7 +622,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 +# TODO @longjie no longer copied from Mistral after static cache class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index ca349dca1c1bc3..5da4282dcf828a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -170,7 +170,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe +# TODO @longjie no longer copied from Mistral after static cache class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -215,7 +216,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -702,7 +704,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe +# TODO @longjie no longer copied from Mistral after static cache class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index bc133ffb3d7329..6dd2c58e5e91e6 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -66,7 +66,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm +# TODO @longjie no longer copied from Mistral after static cache class StableLmRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -158,7 +159,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 61e8518d659cae..8660ee06ad9b42 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -70,7 +70,8 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2 +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -115,7 +116,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -599,7 +601,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2 +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2SdpaAttention(Starcoder2Attention): """ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -708,7 +711,8 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - # Copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward + # copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward + # TODO @longjie no longer copied from Mistral after static cache def forward( self, hidden_states: torch.Tensor, @@ -1067,7 +1071,8 @@ def forward( ) -# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode +# copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode +# TODO @longjie no longer copied from Mistral after static cache class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 25c2cbc1f658a4..c3c685112a4c0c 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -471,7 +471,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format @unittest.skip("TODO @gante fix this for Mistral") @parameterized.expand([(1, False), (1, True), (4, False)]) def 
test_new_cache_format(self, num_beams, do_sample): @@ -681,3 +680,50 @@ def test_compile_static_cache(self): ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) + + @slow + def test_compile_sliding_window_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = { + 8: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in a vacuum, " + "and 2) the laws of physics are the same for all observers in uniform motion.\n\nThe first part of the theory is" + ], + 7: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in a vacuum, " + "and 2) the laws of physics are the same for all observers in uniform motion.\n\nThe first part of the theory is" + ], + } + + prompts = ["Simply put, the theory of relativity states that "] + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + tokenizer.pad_token = tokenizer.eos_token + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text) + + # Sliding Window Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Sliding Window Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8cfa85..d3ec1a7e18db80 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,6 +19,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( @@ -505,6 +506,11 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + @unittest.skip("TODO @gante fix this for Mixtral") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + 
pass + @require_torch class MixtralIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index f4e88a97f06a53..2e5de17ffc5433 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -20,6 +20,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed from transformers.testing_utils import ( @@ -481,6 +482,11 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") + @unittest.skip("TODO @gante fix this for Qwen2") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + @require_torch class Qwen2IntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index f0818e680d3da8..8620ddb56575a9 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -20,6 +20,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoTokenizer, Qwen2MoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -545,6 +546,11 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + @unittest.skip("TODO @gante fix this for Qwen2Moe") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + @require_torch class Qwen2MoeIntegrationTest(unittest.TestCase):
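
Reviewer note: below is a minimal end-to-end sketch of the new cache path, mirroring test_compile_sliding_window_cache above. It assumes access to the mistralai/Mistral-7B-v0.1 checkpoint, enough GPU memory for the float16 weights, and torch >= 2.3 for the compiled variant; it only illustrates how generate() picks up the new "sliding_window" entry in NEED_SETUP_CACHE_CLASSES_MAPPING and is not part of the patch itself.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="sequential"
)
inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# generate() builds a SlidingWindowCache via _get_sliding_window_cache and reuses it across
# calls; it raises a ValueError if the model config has no (non-None) `sliding_window` field.
generated = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.decode(generated[0], skip_special_tokens=True))

# Optionally compile the forward pass, as in test_compile_sliding_window_cache (torch >= 2.3).
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)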
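
The decode-phase indexing in SlidingWindowCache.update is the least obvious part of the diff, so here is a toy sketch of the same roll-and-overwrite arithmetic on a one-dimensional cache. The helper name roll_and_write and the window size of 4 are made up for illustration, and the cache is seeded as if a four-token prompt had exactly filled the window; real shapes are (batch, num_kv_heads, window, head_dim).

import torch

sliding_window_size = 4

# Seed the window with the keys of tokens 1..4 (one scalar per slot for clarity).
cache = torch.tensor([1.0, 2.0, 3.0, 4.0])

def roll_and_write(cache, new_key, cache_position):
    # Same arithmetic as the decode branch of SlidingWindowCache.update.
    slicing = torch.ones(sliding_window_size, dtype=torch.long).cumsum(0)
    cache_position = cache_position.clamp(0, sliding_window_size - 1)
    to_shift = cache_position >= sliding_window_size - 1
    indices = (slicing + to_shift[-1].int() - 1) % sliding_window_size
    cache = cache[indices]           # roll the window left by one slot once it is full
    cache[cache_position] = new_key  # overwrite the freed last slot with the new key
    return cache

for position in range(4, 8):  # decode four more tokens
    cache = roll_and_write(cache, torch.tensor([float(position + 1)]), torch.tensor([position]))
    print(position, cache.tolist())
# 4 [2.0, 3.0, 4.0, 5.0]
# 5 [3.0, 4.0, 5.0, 6.0]
# 6 [4.0, 5.0, 6.0, 7.0]
# 7 [5.0, 6.0, 7.0, 8.0]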
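
Similarly, a small sketch of the sliding-window exclusion added to _update_causal_mask: starting from a fully masked matrix, position i keeps attention only to positions j with i - sliding_window < j <= i, and everything else stays at the minimum dtype value. The window size of 3 and prompt length of 6 are arbitrary toy values.

import torch

sliding_window = 3
sequence_length = 6            # prefill with a 6-token prompt
target_length = sequence_length
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

cache_position = torch.arange(sequence_length)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)                        # future tokens
exclude_mask |= torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)   # outside the window
causal_mask *= exclude_mask    # excluded entries keep min_dtype, allowed entries become 0

print((causal_mask == 0).int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])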