diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index b4531e9c957c9f..452921d88c0e87 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -373,3 +373,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens - update - get_seq_length - reorder_cache + +[[autodoc]] StaticCache + - update + - get_seq_length \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b233ee2acb09ee..76f46d9f6f2e53 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1337,7 +1337,7 @@ _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] - _import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"] + _import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"] _import_structure["data.datasets"] = [ "GlueDataset", "GlueDataTrainingArguments", @@ -6073,7 +6073,7 @@ # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments - from .cache_utils import Cache, DynamicCache, SinkCache + from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache from .data.datasets import ( GlueDataset, GlueDataTrainingArguments, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b298a7bdd0f5d6..8ac6619bf6a8e6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,8 +1,12 @@ +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple import torch +from .configuration_utils import PretrainedConfig + +@dataclass class Cache: """ Base, abstract class for all caches. The actual data structure is specific to each subclass. @@ -320,3 +324,91 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) device = self.value_cache[layer_idx].device self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)`. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads` + required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. 
+ """ + + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32 + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype + + cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim) + self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.seen_tokens = 0 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. Kept for backward compatibility + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len` + to know how much of the cache it should overwrite. + + Return: + A tuple containing the updated key and value states. + """ + new_cache_positions = cache_kwargs.get("position_ids") + k_out = self.key_cache + v_out = self.value_cache + + k_out[:, :, new_cache_positions] = key_states + v_out[:, :, new_cache_positions] = value_states + + self.seen_tokens += key_states.shape[-2] + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" + return self.seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + return self.max_cache_len + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + device = self.key_cache.device + self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) + device = self.value_cache.device + self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self): + """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it""" + return None diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 25abcc67e90e38..69e1afe63c2e9b 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -250,6 +250,11 @@ class GenerationConfig(PushToHubMixin): reduce by 1 - `"constant"`: `num_assistant_tokens` stays unchanged during generation + > Parameters specific to the caching mechanism: + + cache_implementation (`str`, *optional*, default to `None`): + Cache class that should be used when generating. 
+ > Wild card generation_kwargs: @@ -321,6 +326,9 @@ def __init__(self, **kwargs): self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") + # Cache implementation + self.cache_implementation = kwargs.pop("cache_implementation", None) + # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0b8102c353da87..1405425e623827 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,7 @@ import torch.distributed as dist from torch import nn -from ..cache_utils import Cache, DynamicCache +from ..cache_utils import Cache, DynamicCache, StaticCache from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -92,6 +92,10 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module +NEED_SETUP_CACHE_CLASSES_MAPPING = { + "static": StaticCache, +} + @dataclass class GenerateDecoderOnlyOutput(ModelOutput): @@ -1398,6 +1402,19 @@ def generate( "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + # if we don't pass `past_key_values` and a cache_implementation is specified + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get( + "past_key_values", False + ): + cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation] + if not callable(getattr(self, "_setup_cache", None)): + raise ValueError( + "The `generation_config` defines a `cache_implementation` that is not compatible with this model." + " Make sure it has a `_setup_cache` function." + ) + self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. determine generation mode diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 4bf11dd1b41bc4..d2ea931a44f1f1 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -63,7 +63,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama class OpenLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -154,7 +154,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
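
A minimal standalone sketch (not part of the patch) of what the new `StaticCache` added in `cache_utils.py` above does for a single layer: it pre-allocates `(max_batch_size, num_heads, max_cache_len, head_dim)` key/value buffers once and writes new states into them at the positions passed through `cache_kwargs["position_ids"]`, mirroring how `_setup_cache` attaches one cache instance per attention layer. The config values and shapes below are illustrative only.

import torch
from transformers import LlamaConfig, StaticCache

config = LlamaConfig()  # default config: hidden_size=4096, num_attention_heads=32
num_heads = config.num_attention_heads
head_dim = config.hidden_size // num_heads

# One StaticCache per layer, with fixed-shape buffers allocated up front.
past_key_values = StaticCache(config, max_batch_size=2, max_cache_len=128, device="cpu", dtype=torch.float32)

# Prefill with 5 tokens: keys/values are written in place at cache positions 0..4 of the fixed buffer.
key_states = torch.randn(2, num_heads, 5, head_dim)
value_states = torch.randn(2, num_heads, 5, head_dim)
k_out, v_out = past_key_values.update(
    key_states, value_states, layer_idx=0, cache_kwargs={"position_ids": torch.arange(5)}
)

print(k_out.shape)                       # torch.Size([2, 32, 128, 128]): the full pre-allocated buffer
print(past_key_values.get_seq_length())  # 5
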
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 8a850012a5dd36..5fb295bbf0c585 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -88,7 +88,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -130,7 +130,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon class FalconRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index b0bdca3095dc99..7409dc7d3861aa 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -527,7 +527,7 @@ def attention_mask_func(attention_scores, ltor_mask): class GPTNeoXRotaryEmbedding(nn.Module): - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ + # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -617,7 +617,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
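
For context on how the pieces above fit together from the user side, here is a hedged end-to-end sketch mirroring the new tests in `test_cache_utils.py` at the bottom of this diff: setting `generation_config.cache_implementation = "static"` makes `generate` look the string up in `NEED_SETUP_CACHE_CLASSES_MAPPING` and call the model's `_setup_cache` before decoding. The checkpoint name, dtype, and device are only examples.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to("cuda")

# Opt into the pre-allocated cache; any model defining `_setup_cache` (Llama in this patch) qualifies.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("The best color is", return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
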
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index c0d4e010c1ecf3..4ac7c4d4e0025f 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -235,7 +235,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ + # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index d5613a8254bcb6..bdd915c1bd8d59 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -513,7 +513,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4c8579fce24d76..c657562ef1cebc 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -30,12 +30,6 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_attention_mask, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -43,7 +37,7 @@ SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -52,7 +46,6 @@ logging, replace_return_docstrings, ) -from ...utils.import_utils import is_torch_fx_available from .configuration_llama import LlamaConfig @@ -61,15 +54,6 @@ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" @@ -87,24 +71,6 @@ def _get_unpad_data(attention_mask): ) -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - warnings.warn( - "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. 
Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask" - ) - return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - warnings.warn( - "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" - ) - return AttentionMaskConverter._make_causal_mask( - input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length - ) - - class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -135,30 +101,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): + def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) + freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t() + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): @@ -234,8 +181,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -320,7 +265,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) self._init_rope() def _init_rope(self): @@ -350,9 +295,6 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -363,11 +305,6 @@ def forward( use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -397,19 +334,20 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_seen_tokens = 0 + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen + kv_seq_len += past_seen_tokens + + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device) + position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -417,18 +355,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[..., past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -483,15 +412,6 @@ def forward( use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -508,13 +428,19 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_seen_tokens = 0 + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen + kv_seq_len += past_seen_tokens + + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device) + position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -704,28 +630,32 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_seen_tokens = 0 + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen + kv_seq_len += past_seen_tokens + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device) + position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = None if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # 
Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -734,14 +664,13 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=causal_mask is None, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -854,7 +783,7 @@ class LlamaPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -870,6 +799,20 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): + if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device: + causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + + for layer in self.model.layers: + layer.self_attn.past_key_value = cache_cls( + self.config, max_batch_size, max_cache_len, device=layer.self_attn.o_proj.weight.device + ) + + def _reset_cache(self): + for layer in self.model.layers: + layer.self_attn.past_key_value = None + LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -962,11 +905,12 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList( [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._use_sdpa = config._attn_implementation == "sdpa" - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False + + # register a causal mask to separate causal and padding mask creation. 
Merging happends in the attention class + causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) # Initialize weights and apply final processing self.post_init() @@ -994,60 +938,26 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) - position_ids = position_ids.unsqueeze(0) + use_cache = False + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) # embed positions hidden_states = inputs_embeds @@ -1065,7 +975,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, @@ -1074,7 +984,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -1097,7 +1007,9 @@ def forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1107,6 +1019,49 @@ def forward( attentions=all_self_attns, ) + def _update_causal_mask(self, attention_mask, input_tensor): + if self.config._attn_implementation == "flash_attention_2": + causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + return causal_mask + + batch_size, seq_length = input_tensor.shape[:2] + dtype = input_tensor.dtype + + # support going beyond cached `max_position_embedding` + if seq_length > self.causal_mask.shape[-1]: + causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + + if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows + causal_mask = ( + self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min + ) + else: + mask = torch.full( + (self.config.max_position_embeddings, self.config.max_position_embeddings), + fill_value=torch.finfo(dtype).min, + ) + causal_mask = torch.triu(mask, diagonal=1).to(dtype) + + if attention_mask is not None and attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( + padding_mask, torch.finfo(dtype).min + ) + + if self.config._attn_implementation == "sdpa": + if attention_mask is None: + return None + is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) + if not is_tracing and (torch.all(attention_mask == 1)): + return None + if is_tracing and seq_length == 1: + return None + causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype) + + return causal_mask + class LlamaForCausalLM(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1271,6 +1226,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): + # generation with 
static cache + seen_tokens = past_key_value.get_seq_length() + input_ids = input_ids[:, seen_tokens:] + position_ids = position_ids[:, seen_tokens:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index fe51d7ed2afc96..6c510dc9bb01d8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -88,7 +88,8 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -133,7 +134,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# TODO @Arthur no longer copied from LLama after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -612,7 +614,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache class MistralSdpaAttention(MistralAttention): """ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -656,28 +659,34 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_seen_tokens = kv_seq_len - key_states.shape[-2] + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + if ( + attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1 + ): # user defined causal mask + causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]] + # this one liner is equivalent to the pad_unpad function + causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None]) + else: + causal_mask = None # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -686,14 +695,13 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 5c347b38bb1e86..f1e53dd0889711 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -181,7 +181,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -226,7 +226,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -692,7 +692,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral +# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral class MixtralSdpaAttention(MixtralAttention): """ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -736,28 +736,34 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_seen_tokens = kv_seq_len - key_states.shape[-2] + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + if ( + attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1 + ): # user defined causal mask + causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]] + # this one liner is equivalent to the pad_unpad function + causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None]) + else: + causal_mask = None # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -766,14 +772,13 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a936a7f89f06d0..592d3e914106d0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -40,7 +40,7 @@ _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -132,7 +132,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -864,6 +864,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): + # generation with static cache + seen_tokens = past_key_value.get_seq_length() + input_ids = input_ids[:, seen_tokens:] + position_ids = position_ids[:, seen_tokens:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 52a7123a952399..98e8143f2cf1fc 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -78,7 +78,7 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi class PhiRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -170,7 +170,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
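
The attention changes above (Llama, Mistral, Mixtral, and Qwen2 below) all replace the per-call 4D mask construction with slicing of a pre-registered causal mask buffer. A small illustrative sketch of that slicing, using arbitrary sizes rather than values from the patch:

import torch

max_positions, dtype = 16, torch.float32
# Same layout as the registered `causal_mask` buffer after scaling: 0 on and below the
# diagonal, the most negative finite value above it.
causal_mask = torch.triu(
    torch.full((max_positions, max_positions), torch.finfo(dtype).min), diagonal=1
)

# Decoding one new token with 7 positions already cached: slice rows for the query block
# and columns for the keys currently in the cache.
q_len, past_seen_tokens, kv_len = 1, 7, 8
sliced = causal_mask[past_seen_tokens : past_seen_tokens + q_len, :kv_len]
print(sliced)  # one row of zeros: the new token attends to every cached position
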
@@ -1125,6 +1125,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): + # generation with static cache + seen_tokens = past_key_value.get_seq_length() + input_ids = input_ids[:, seen_tokens:] + position_ids = position_ids[:, seen_tokens:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 5f7ad4bd4049d9..6338ec6e09987c 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -95,7 +95,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2 class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -140,7 +140,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -625,7 +625,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 +# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2 class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -669,28 +669,34 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_seen_tokens = kv_seq_len - key_states.shape[-2] + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + if ( + attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1 + ): # user defined causal mask + causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]] + # this one liner is equivalent to the pad_unpad function + causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None]) + else: + causal_mask = None # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -699,14 +705,13 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c766f3f522b124..b756306c0c5dcb 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class StaticCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GlueDataset(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 8ee7617a0b742e..4efc5da5c401cd 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -362,6 +362,7 @@ def test_save_load_fast_init_from_base(self): pass @parameterized.expand([("linear",), ("dynamic",)]) + @unittest.skip("TODO @gante fix this for Llama") def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -507,9 +508,19 @@ def test_eager_matches_sdpa_generate(self): inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - self.assertTrue(torch.allclose(res_eager, res_sdpa)) + + with self.subTest(f"{padding_side}"): + torch.testing.assert_close( + res_eager, + res_sdpa, + msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + ) + + @unittest.skip("TODO @gante fix this for Llama") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass @require_torch diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 72d055c8806afd..df6b15f4dcad35 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -15,14 +15,29 @@ import unittest +from parameterized import parameterized + from transformers import set_seed -from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow +from transformers.testing_utils import ( + is_torch_available, + require_auto_gptq, + require_torch, + require_torch_gpu, + slow, + torch_device, +) if is_torch_available(): import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DynamicCache, + LlamaForCausalLM, + SinkCache, + ) @require_torch @@ -229,3 +244,100 @@ def test_sink_cache_iterative_prompts(self): "was visiting the historic district of Honolulu. 
Here," ) self.assertTrue(decoded[0].endswith(last_output)) + + @require_torch_gpu + @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) + def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): + EXPECTED_GENERATION = [ + "The best color is the one that complements the subject you are photograph", + "We should not undermind the issues at hand.\nWe should not undermind the issues", + ] + + tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" + ) + model = AutoModelForCausalLM.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", + torch_dtype=torch.bfloat16, + attn_implementation=attn_implementation, + ).to(torch_device) + inputs = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" + ).to(model.device) + + set_seed(0) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + with self.subTest(f"{attn_implementation}, dynamic"): + self.assertListEqual(decoded, EXPECTED_GENERATION) + + set_seed(0) + model.generation_config.cache_implementation = "static" + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + with self.subTest(f"{attn_implementation}, static, eager"): + self.assertListEqual(decoded, EXPECTED_GENERATION) + + set_seed(0) + model.forward = torch.compile(model.forward) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + with self.subTest(f"{attn_implementation}, static, compiled"): + self.assertListEqual(decoded, EXPECTED_GENERATION) + + @require_torch_gpu + @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) + def test_static_cache_greedy_sampling_pad_right(self, attn_implementation): + EXPECTED_GENERATION = [ + "The best color is\n\n\n\n\n\n\n\n\n\n", + "We should not undermind the issues at hand, but address them head on.\nI think", + ] + + tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" + ) + model = AutoModelForCausalLM.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", + torch_dtype=torch.bfloat16, + attn_implementation=attn_implementation, + ).to("cuda:1") + inputs = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" + ).to(model.device) + + set_seed(0) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + with self.subTest(f"{attn_implementation}, dynamic"): + self.assertListEqual(decoded, EXPECTED_GENERATION) + + set_seed(0) + model.generation_config.cache_implementation = "static" + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + with self.subTest(f"{attn_implementation}, static, eager"): + self.assertListEqual(decoded, EXPECTED_GENERATION) + + set_seed(0) + model._forward = model.forward + compiled_forward = torch.compile(model.forward) + + def compiled(func, input_ids, **kwargs): + return func(input_ids, **kwargs) + + def call(input_ids, **kwargs): + if input_ids.shape[-1] == 1: + return compiled(compiled_forward, input_ids, **kwargs) + + return model._forward(input_ids, **kwargs) + + model.forward = call + + gen_out = 
model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        with self.subTest(f"{attn_implementation}, static, compiled"):
+            self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+    @unittest.skip("TODO @gante static cache does not support beam search yet")
+    def test_static_cache_beam_search(self):
+        pass
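
As a closing usage note, the pad-right test above compiles only the fixed-shape single-token decoding step and leaves the variable-length prefill uncompiled, which is the usage pattern the static cache targets. A condensed sketch of that dispatch, assuming `model` and `inputs` are prepared as in the earlier generation sketch (with `cache_implementation = "static"` already set):

import torch

eager_forward = model.forward
compiled_forward = torch.compile(model.forward)

def forward(input_ids, **kwargs):
    # After the prefill, `generate` feeds one token at a time; only that fixed-shape call is
    # compiled, so the static cache avoids recompilation for every generated token.
    if input_ids.shape[-1] == 1:
        return compiled_forward(input_ids, **kwargs)
    return eager_forward(input_ids, **kwargs)

model.forward = forward
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
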