From f14637a7b540ac1a1b21f6ea919c80a51e287dc6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 28 Nov 2024 08:01:54 +0100 Subject: [PATCH 01/40] refactor LlamaAttention --- .../models/llama/modeling_llama.py | 420 ++++++------------ 1 file changed, 131 insertions(+), 289 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0408bb73c7f2da..803a0003868018 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -168,31 +168,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -255,6 +230,126 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **kwargs): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = 
query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + **kwargs + ) + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None + + +LLAMA_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -284,9 +379,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = LlamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -297,7 +389,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -310,16 +402,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -327,151 +410,17 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaFlashAttention2(LlamaAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + attention_type = "eager" else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + attention_type = self.config._attn_implementation - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - **kwargs, + attn_output, attn_weights = LLAMA_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions, **kwargs ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) if not output_attentions: @@ -480,119 +429,12 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaSdpaAttention(LlamaAttention): - """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from LlamaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -LLAMA_ATTENTION_CLASSES = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, - "sdpa": LlamaSdpaAttention, -} - - class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -608,7 +450,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: From f446bd4c004d10e8a10d5f686455204e2befa09b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 12:20:51 +0100 Subject: [PATCH 02/40] only change lLlama --- .../modeling_flash_attention_utils.py | 46 ++- src/transformers/modeling_utils.py | 206 +++++++----- .../models/llama/modeling_llama.py | 303 ++++++------------ src/transformers/utils/generic.py | 4 + 4 files changed, 267 insertions(+), 292 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 1b9274e21f5205..52051e3f52d140 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -231,7 +231,7 @@ def _flash_attention_forward( if not use_top_left_mask: causal = is_causal else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.mistral.modeling_mistral.MistralFlashAttention2.__init__. causal = is_causal and query_length != 1 # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). @@ -336,3 +336,47 @@ class FlashAttentionKwargs(TypedDict, total=False): cu_seq_lens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] + + +class TransformersKwargs(TypedDict, total=False): + output_attentions: Optional[bool] + output_hidden_states: Optional[bool] + use_cache: Optional[bool] + return_dict: Optional[bool] + + + +from functools import wraps +from typing import Callable, TypedDict, Optional + + + +def validate_config_kwargs(config): + """ + A decorator to validate and initialize kwargs based on a config object. 
+ """ + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + # Default values from the config + default_kwargs = { + "output_attentions": config.output_attentions, + "output_hidden_states": config.output_hidden_states, + "use_cache": config.use_cache, + "return_dict": config.use_return_dict, + } + + # Merge provided kwargs with defaults + validated_kwargs = {**default_kwargs, **kwargs} + + # Validate kwargs against TypedDict + for key in validated_kwargs: + if key not in TransformersKwargs.__annotations__: + raise ValueError(f"Invalid keyword argument: {key}") + + # Pass the validated kwargs to the function + return func(*args, **validated_kwargs) + + return wrapper + + return decorator \ No newline at end of file diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4703c415e42fbb..3bf6465b666606 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2534,92 +2534,6 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) - def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - """ - Activates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - - We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of - the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 - - Args: - gradient_checkpointing_kwargs (dict, *optional*): - Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. - """ - if not self.supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {"use_reentrant": True} - - gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) - - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` method - _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters - - if not _is_using_old_format: - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) - else: - self.apply(partial(self._set_gradient_checkpointing, value=True)) - logger.warning( - "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." - "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." - ) - - if getattr(self, "_hf_peft_config_loaded", False): - # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True - # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 - # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate - # the gradients to make sure the gradient flows. 
- self.enable_input_require_grads() - - def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): - is_gradient_checkpointing_set = False - - # Apply it on the top-level module in case the top-level modules supports it - # for example, LongT5Stack inherits from `PreTrainedModel`. - if hasattr(self, "gradient_checkpointing"): - self._gradient_checkpointing_func = gradient_checkpointing_func - self.gradient_checkpointing = enable - is_gradient_checkpointing_set = True - - for module in self.modules(): - if hasattr(module, "gradient_checkpointing"): - module._gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = enable - is_gradient_checkpointing_set = True - - if not is_gradient_checkpointing_set: - raise ValueError( - f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" - " `gradient_checkpointing` to modules of the model that uses checkpointing." - ) - - def gradient_checkpointing_disable(self): - """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if self.supports_gradient_checkpointing: - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` methid - _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters - if not _is_using_old_format: - self._set_gradient_checkpointing(enable=False) - else: - logger.warning( - "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." - "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." - ) - self.apply(partial(self._set_gradient_checkpointing, value=False)) - - if getattr(self, "_hf_peft_config_loaded", False): - self.disable_input_require_grads() @property def is_gradient_checkpointing(self) -> bool: @@ -5568,3 +5482,123 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): files_content[filename].append(device_map[weight_name]) return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] + + +class GradientCheckpointLayer(torch.nn.Module): + def __call__(self, *args, **kwargs): + """ + Adjust the behavior of the inherited class by overriding `__call__`. + + Automatically handles gradient checkpointing based on flags in the provided arguments. + """ + # Extract necessary flags and arguments + gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) + training = self.training + + if gradient_checkpointing and training: + # Use gradient checkpointing + return self._apply_gradient_checkpointing(*args, **kwargs) + else: + # Default behavior: call the original `forward` method + return super().__call__(*args, **kwargs) + + def _apply_gradient_checkpointing(self, *args, **kwargs): + """ + Apply gradient checkpointing using the appropriate function. + + By default, uses `torch.utils.checkpoint.checkpoint`. 
+ """ + # Assume `self.forward` is compatible with checkpointing + return checkpoint(self.__call__, *args, **kwargs) + + + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." 
+ ) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=False) + else: + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + if getattr(self, "_hf_peft_config_loaded", False): + self.disable_input_require_grads() + + +ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, function]] = {} diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 803a0003868018..aa261f557c8676 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,7 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import List, Optional, Tuple, Union import torch @@ -36,16 +35,16 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...utils.generic import validate_config_kwargs from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( - LossKwargs, + KwargsForCausalLM, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -230,59 +229,37 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - -def eager_attention_forward(config, query, key, value, mask, **_kwargs): - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling - - if config.attn_logit_softcapping is not None: - attn_weights = attn_weights / config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * config.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = 
attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **kwargs): +def flash_attention_forward( + config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs +): if mask is not None: seq_len = mask.shape[1] query = query[:, :, :seq_len] value = value[:, :, :seq_len] + else: + seq_len = query.shape[1] - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 + dropout_rate = config.attention_dropout if training else 0.0 - input_dtype = query_states.dtype + input_dtype = query.dtype if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + query, + key, + value, mask, seq_len, dropout=dropout_rate, - softmax_scale=config.scaling, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - **kwargs + softmax_scale=getattr(config, "scaling", 1.0), + is_causal=getattr(config, "is_causal", False), + sliding_window=getattr(config, "sliding_window", None), + use_top_left_mask=getattr(config, "_flash_attn_uses_top_left_mask", False), + layer_idx=layer_idx, + **kwargs, ) return attn_output, None @@ -342,37 +319,40 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): return attn_output, None -LLAMA_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} +ALL_ATTENTION_FUNCTIONS.update( + { + "llama.flash_attention_2": flash_attention_forward, + "llama.flex_attention": flex_attention_forward, + "llama.sdpa": sdpa_attention_forward, + } +) +def eager_attention_forward(attention_class:nn.Module, query, key, value, mask, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) @@ -382,54 +362,49 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], # will become mandatory in v4.46 attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape, _ = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - attention_type = "eager" - else: - attention_type = self.config._attn_implementation + attention_interface: function = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = LLAMA_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions, **kwargs + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) - + attn_output = 
attn_output.reshape(*hidden_shape, -1) attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): +class LlamaDecoderLayer(GradientCheckpointLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -443,49 +418,22 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + output_attentions: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, @@ -503,9 +451,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -654,9 +599,6 @@ def __init__(self, config: LlamaConfig): self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") - # Initialize weights and apply final processing self.post_init() @@ -666,7 +608,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @validate_config_kwargs def forward( self, input_ids: torch.LongTensor = None, @@ -681,12 +623,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -700,31 +637,20 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -733,42 +659,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -778,19 +687,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + if not return_dict: + return output.to_tuple() + return output + def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -913,9 +820,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- - class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1258,12 +1162,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) + @validate_config_kwargs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1272,10 +1171,7 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + **kwargs, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1291,10 +1187,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 26ec82b20fd40e..5173c962da5b32 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -29,6 +29,7 @@ import numpy as np from packaging import version +from ..modeling_flash_attention_utils import FlashAttentionKwargs from .import_utils import ( get_torch_version, is_flax_available, @@ -867,3 +868,6 @@ class LossKwargs(TypedDict, total=False): """ num_items_in_batch: Optional[int] + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
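The core of the refactor in the two patches above is the move from per-implementation attention subclasses (LlamaFlashAttention2, LlamaSdpaAttention) to a registry of plain attention functions that a single LlamaAttention.forward looks up by config._attn_implementation, falling back to an eager implementation. Before the next patch relocates those functions into integrations/, here is a minimal, self-contained sketch of that dispatch pattern. It uses toy attention functions and a toy module, not the real transformers classes; only the registry-plus-fallback shape mirrors the diff, everything else (TinyAttention, the shapes, the simplified math) is illustrative.

    # Sketch of the attention-dispatch pattern introduced above (assumed shapes and
    # toy functions; not the actual transformers implementation).
    from typing import Callable, Dict

    import torch

    ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}


    def eager_attention_forward(module, query, key, value, mask, **kwargs):
        # Plain matmul + softmax attention; materializes the weights so
        # output_attentions-style use cases still work.
        scale = query.shape[-1] ** -0.5
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
        if mask is not None:
            attn_weights = attn_weights + mask
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


    def sdpa_attention_forward(module, query, key, value, mask, **kwargs):
        # Fused-kernel path; attention weights are not materialized, so None is
        # returned in their place, matching the registry contract in the patch.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=mask
        )
        return attn_output, None


    ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward


    class TinyAttention(torch.nn.Module):
        def __init__(self, attn_implementation: str = "sdpa"):
            super().__init__()
            self.attn_implementation = attn_implementation

        def forward(self, query, key, value, mask=None):
            # Same dispatch shape as the refactored LlamaAttention.forward:
            # default to eager, otherwise look the callable up in the registry.
            attention_interface = eager_attention_forward
            if self.attn_implementation != "eager":
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.attn_implementation]
            return attention_interface(self, query, key, value, mask)


    q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
    out, weights = TinyAttention("sdpa")(q, k, v)
    print(out.shape, weights)  # torch.Size([1, 2, 4, 8]) None

Because the registry maps plain strings to plain callables, adding a new backend is a dictionary update rather than a new subclass, which is what the following patch exploits when it moves flash/flex/sdpa into integrations/ and registers them in ALL_ATTENTION_FUNCTIONS.
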
From 0384db9c0c54ca6dfd45cf1ba9d315a1efa71189 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 12:39:07 +0100 Subject: [PATCH 03/40] more refactoring --- .../integrations/flash_attention.py | 37 ++ .../integrations/flex_attention.py | 23 + .../integrations/sdpa_attention.py | 32 ++ src/transformers/modeling_utils.py | 19 +- src/transformers/models/auto/modeling_task.py | 296 ++++++++++ .../models/llama/modeling_llama.py | 505 +----------------- 6 files changed, 407 insertions(+), 505 deletions(-) create mode 100644 src/transformers/integrations/flash_attention.py create mode 100644 src/transformers/integrations/flex_attention.py create mode 100644 src/transformers/integrations/sdpa_attention.py create mode 100644 src/transformers/models/auto/modeling_task.py diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py new file mode 100644 index 00000000000000..892b9bb8fe572a --- /dev/null +++ b/src/transformers/integrations/flash_attention.py @@ -0,0 +1,37 @@ +import torch + + +def flash_attention_forward( + config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs +): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + else: + seq_len = query.shape[1] + + dropout_rate = config.attention_dropout if training else 0.0 + + input_dtype = query.dtype + if input_dtype == torch.float32: + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_forward( + query, + key, + value, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=getattr(config, "scaling", 1.0), + is_causal=getattr(config, "is_causal", False), + sliding_window=getattr(config, "sliding_window", None), + use_top_left_mask=getattr(config, "_flash_attn_uses_top_left_mask", False), + layer_idx=layer_idx, + **kwargs, + ) + + return attn_output, None \ No newline at end of file diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py new file mode 100644 index 00000000000000..97bf10f9c41fbd --- /dev/null +++ b/src/transformers/integrations/flex_attention.py @@ -0,0 +1,23 @@ +import torch + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py new file mode 100644 index 00000000000000..96b7556e784d1a --- /dev/null +++ b/src/transformers/integrations/sdpa_attention.py @@ -0,0 +1,32 @@ +import torch + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # 
Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3bf6465b666606..44d367d3940f77 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -45,6 +45,9 @@ from .dynamic_module_utils import custom_object_save from .generation import GenerationConfig, GenerationMixin from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from .integrations.flash_attention import * +from .integrations.flex_attention import * +from .integrations.sdpa_attention import * from .loss.loss_utils import LOSS_MAPPING from .pytorch_utils import ( # noqa: F401 Conv1D, @@ -5484,6 +5487,19 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] + + +ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, function]] = {} + +ALL_ATTENTION_FUNCTIONS.update( + { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "sdpa": sdpa_attention_forward, + } +) + + class GradientCheckpointLayer(torch.nn.Module): def __call__(self, *args, **kwargs): """ @@ -5599,6 +5615,3 @@ def gradient_checkpointing_disable(self): if getattr(self, "_hf_peft_config_loaded", False): self.disable_input_require_grads() - - -ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, function]] = {} diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py new file mode 100644 index 00000000000000..279321f1ef4bc0 --- /dev/null +++ b/src/transformers/models/auto/modeling_task.py @@ -0,0 +1,296 @@ +import torch + +class AutoForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _embeding_layer = "model.embed_tokens" + _output_embedding = "lm_head" + + def __init__(self, config): + super().__init__(config) + self.model = AutoModel.from_config(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + 
num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AutoForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, 
self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + + +class AutoForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AutoForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 
+ self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @validate_config_kwargs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index aa261f557c8676..2fa00783bb3fd4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -37,7 +37,7 @@ ) from ...utils.generic import validate_config_kwargs from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, GradientCheckpointLayer from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -53,7 +53,6 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" _CONFIG_FOR_DOC = "LlamaConfig" @@ -229,105 +228,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def flash_attention_forward( - config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs -): - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - else: - seq_len = query.shape[1] - - dropout_rate = config.attention_dropout if training else 0.0 - - input_dtype = query.dtype - if input_dtype == torch.float32: - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - - attn_output = _flash_attention_forward( - query, - key, - value, - mask, - seq_len, - dropout=dropout_rate, - softmax_scale=getattr(config, "scaling", 1.0), - is_causal=getattr(config, "is_causal", 
False), - sliding_window=getattr(config, "sliding_window", None), - use_top_left_mask=getattr(config, "_flash_attn_uses_top_left_mask", False), - layer_idx=layer_idx, - **kwargs, - ) - - return attn_output, None - - -def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score - - attn_output = flex_attention( - query, - key, - value, - score_mod=tanh_softcap, - enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, - ) - if not output_attentions: - return attn_output, None - else: - return attn_output[0], attn_output[1] - - -def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - scale=config.scaling, - ) - return attn_output, None - - -ALL_ATTENTION_FUNCTIONS.update( - { - "llama.flash_attention_2": flash_attention_forward, - "llama.flex_attention": flex_attention_forward, - "llama.sdpa": sdpa_attention_forward, - } -) - - def eager_attention_forward(attention_class:nn.Module, query, key, value, mask, **_kwargs): config = attention_class.config key_states = repeat_kv(key, config.num_key_value_groups) @@ -585,6 +485,7 @@ class LlamaModel(LlamaPreTrainedModel): Args: config: LlamaConfig """ + _input_embedding = "embed_tokens" # no need for set and get, take then from PreTrainedModel def __init__(self, config: LlamaConfig): super().__init__(config) @@ -598,16 +499,9 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @validate_config_kwargs def forward( self, @@ -623,8 +517,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: - - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -693,10 +585,7 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) 
- - if not return_dict: - return output.to_tuple() - return output + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -819,391 +708,3 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - -class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Llama Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
- """, - LLAMA_START_DOCSTRING, -) -class LlamaForQuestionAnswering(LlamaPreTrainedModel): - base_model_prefix = "transformer" - - # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama - def __init__(self, config): - super().__init__(config) - self.transformer = LlamaModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.transformer.embed_tokens - - def set_input_embeddings(self, value): - self.transformer.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. 
- """, - LLAMA_START_DOCSTRING, -) -class LlamaForTokenClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @validate_config_kwargs - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - **kwargs - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) From 4e681b9c72288b1fc7e7c924760a0a77554801b1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 13:38:19 +0100 Subject: [PATCH 04/40] nits --- src/transformers/integrations/flash_attention.py | 2 +- src/transformers/models/llama/modeling_llama.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 892b9bb8fe572a..58606f7e021463 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,5 +1,5 @@ import torch - +from ..modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward def flash_attention_forward( config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2fa00783bb3fd4..c375c71e99ebcb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -301,9 +301,6 @@ def forward( return attn_output, attn_weights - - - class LlamaDecoderLayer(GradientCheckpointLayer): 
def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() From 893ef382c42742ebb47d91beb056558d119523f2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 13:47:59 +0100 Subject: [PATCH 05/40] nits --- src/transformers/integrations/flash_attention.py | 5 +---- src/transformers/modeling_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 58606f7e021463..c1ba6a7720bc22 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -25,11 +25,8 @@ def flash_attention_forward( value, mask, seq_len, + config=config, dropout=dropout_rate, - softmax_scale=getattr(config, "scaling", 1.0), - is_causal=getattr(config, "is_causal", False), - sliding_window=getattr(config, "sliding_window", None), - use_top_left_mask=getattr(config, "_flash_attn_uses_top_left_mask", False), layer_idx=layer_idx, **kwargs, ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 44d367d3940f77..d0c5e92623ebd8 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -45,9 +45,9 @@ from .dynamic_module_utils import custom_object_save from .generation import GenerationConfig, GenerationMixin from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled -from .integrations.flash_attention import * -from .integrations.flex_attention import * -from .integrations.sdpa_attention import * +from .integrations.flash_attention import flash_attention_forward +from .integrations.flex_attention import flex_attention_forward +from .integrations.sdpa_attention import sdpa_attention_forward from .loss.loss_utils import LOSS_MAPPING from .pytorch_utils import ( # noqa: F401 Conv1D, From 13a195a7bba4f365eaa06038354ad84b48455947 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 13:53:35 +0100 Subject: [PATCH 06/40] _output_embedding and _input_embeding --- src/transformers/modeling_utils.py | 6 +++--- src/transformers/models/auto/modeling_task.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d0c5e92623ebd8..1495b9cd87f282 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1830,7 +1830,7 @@ def get_input_embeddings(self) -> nn.Module: if base_model is not self: return base_model.get_input_embeddings() else: - raise NotImplementedError + return getattr(self, self._embeding_layer) def set_input_embeddings(self, value: nn.Module): """ @@ -1843,7 +1843,7 @@ def set_input_embeddings(self, value: nn.Module): if base_model is not self: base_model.set_input_embeddings(value) else: - raise NotImplementedError + raise setattr(self, self._embeding_layer, value) def get_output_embeddings(self) -> nn.Module: """ @@ -1852,7 +1852,7 @@ def get_output_embeddings(self) -> nn.Module: Returns: `nn.Module`: A torch module mapping hidden states to vocabulary. 
""" - return None # Overwrite for models with output embeddings + return getattr(self, self._output_embedding, None) def _init_weights(self, module): """ diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 279321f1ef4bc0..751751a44d289e 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -1,6 +1,7 @@ import torch +from transformers import PreTrainedModel -class AutoForCausalLM(LlamaPreTrainedModel, GenerationMixin): +class AutoForCausalLM(PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _embeding_layer = "model.embed_tokens" From 39ab8b757b391f39a2c947e5c0b9e403d2f0f324 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 13:58:09 +0100 Subject: [PATCH 07/40] oupts --- .../modeling_flash_attention_utils.py | 17 ++++++++++++----- src/transformers/models/llama/modeling_llama.py | 12 ++++-------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 52051e3f52d140..b9e2214d504ab6 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -348,7 +348,7 @@ class TransformersKwargs(TypedDict, total=False): from functools import wraps from typing import Callable, TypedDict, Optional - +from logging import logger def validate_config_kwargs(config): @@ -358,12 +358,13 @@ def validate_config_kwargs(config): def decorator(func: Callable): @wraps(func) def wrapper(*args, **kwargs): + self = args[0] # Default values from the config default_kwargs = { - "output_attentions": config.output_attentions, - "output_hidden_states": config.output_hidden_states, - "use_cache": config.use_cache, - "return_dict": config.use_return_dict, + "output_attentions": self.config.output_attentions, + "output_hidden_states": self.config.output_hidden_states, + "use_cache": self.config.use_cache, + "return_dict": self.config.use_return_dict, } # Merge provided kwargs with defaults @@ -374,6 +375,12 @@ def wrapper(*args, **kwargs): if key not in TransformersKwargs.__annotations__: raise ValueError(f"Invalid keyword argument: {key}") + if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + validated_kwargs["use_cache"] = False + # Pass the validated kwargs to the function return func(*args, **validated_kwargs) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c375c71e99ebcb..5669facb90d4d4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable import torch import torch.utils.checkpoint @@ -283,7 +283,7 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: function = eager_attention_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -482,7 +482,7 @@ class LlamaModel(LlamaPreTrainedModel): Args: config: LlamaConfig """ - _input_embedding = "embed_tokens" # no need for set and get, take then from PreTrainedModel + _input_embedding = "embed_tokens" def __init__(self, config: LlamaConfig): super().__init__(config) @@ -517,11 +517,7 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) From 0418f97553fca6f1425070a0049b5ecad7a9ed3b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 14:18:41 +0100 Subject: [PATCH 08/40] make auto for causal lm work --- src/transformers/__init__.py | 3 ++ .../modeling_flash_attention_utils.py | 49 ++----------------- src/transformers/modeling_utils.py | 2 +- src/transformers/models/auto/__init__.py | 4 ++ src/transformers/models/auto/modeling_auto.py | 2 +- src/transformers/models/auto/modeling_task.py | 33 +++++++++---- .../models/llama/modeling_llama.py | 15 +++--- src/transformers/utils/generic.py | 40 ++++++++++++++- 8 files changed, 83 insertions(+), 65 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fa54ced6a13486..b67362924d6cdb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -185,6 +185,7 @@ "AutoImageProcessor", "AutoProcessor", "AutoTokenizer", + "AutoForCausalLM", ], "models.autoformer": ["AutoformerConfig"], "models.bark": [ @@ -5039,6 +5040,7 @@ AutoImageProcessor, AutoProcessor, AutoTokenizer, + AutoForCausalLM, ) from .models.autoformer import ( AutoformerConfig, @@ -6200,6 +6202,7 @@ from .utils.dummy_pt_objects import * else: # Benchmarks + from .models.auto.modeling_task import AutoForCausalLM from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments from .cache_utils import ( diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index b9e2214d504ab6..0e8d83778d813b 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -20,7 +20,11 @@ import torch import torch.nn.functional as F -from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal +from functools import wraps +from typing import Callable, TypedDict, Optional +import logging + +from .utils.import_utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal if is_flash_attn_2_available(): @@ -344,46 +348,3 @@ class TransformersKwargs(TypedDict, total=False): use_cache: Optional[bool] return_dict: Optional[bool] - - -from functools import wraps -from 
typing import Callable, TypedDict, Optional -from logging import logger - - -def validate_config_kwargs(config): - """ - A decorator to validate and initialize kwargs based on a config object. - """ - def decorator(func: Callable): - @wraps(func) - def wrapper(*args, **kwargs): - self = args[0] - # Default values from the config - default_kwargs = { - "output_attentions": self.config.output_attentions, - "output_hidden_states": self.config.output_hidden_states, - "use_cache": self.config.use_cache, - "return_dict": self.config.use_return_dict, - } - - # Merge provided kwargs with defaults - validated_kwargs = {**default_kwargs, **kwargs} - - # Validate kwargs against TypedDict - for key in validated_kwargs: - if key not in TransformersKwargs.__annotations__: - raise ValueError(f"Invalid keyword argument: {key}") - - if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - validated_kwargs["use_cache"] = False - - # Pass the validated kwargs to the function - return func(*args, **validated_kwargs) - - return wrapper - - return decorator \ No newline at end of file diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1495b9cd87f282..729e45d7c8dd11 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5489,7 +5489,7 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): -ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, function]] = {} +ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {} ALL_ATTENTION_FUNCTIONS.update( { diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 2ee0541a1a71b8..c7e39f4da94608 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -122,6 +122,9 @@ "AutoModelForZeroShotObjectDetection", "AutoModelForImageTextToText", ] + _import_structure["modeling_task"] = [ + "AutoForCausalLM", + ] try: if not is_tf_available(): @@ -311,6 +314,7 @@ AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) + from .modeling_task import AutoForCausalLM try: if not is_tf_available(): diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2c519a7dc42ca5..406651da2a3132 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -500,7 +500,7 @@ ("granitemoe", "GraniteMoeForCausalLM"), ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), - ("llama", "LlamaForCausalLM"), + ("llama", "AutoForCausalLM"), ("mamba", "MambaForCausalLM"), ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianForCausalLM"), diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 751751a44d289e..6dbdd1d15f119c 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -1,12 +1,26 @@ import torch -from transformers import PreTrainedModel +import torch.nn as nn +from typing import Optional, List, Union, Unpack, Tuple, Dict + +from ...modeling_utils import PreTrainedModel +from ...generation import GenerationMixin +from ..auto import AutoModel +from ...cache_utils import Cache, DynamicCache, StaticCache + +from ...utils.generic import KwargsForCausalLM, validate_config_kwargs +from ...modeling_outputs import ( + QuestionAnsweringModelOutput, + 
SequenceClassifierOutputWithPast, + CausalLMOutputWithPast, + TokenClassifierOutput +) class AutoForCausalLM(PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _embeding_layer = "model.embed_tokens" _output_embedding = "lm_head" - + _no_split_modules = [] def __init__(self, config): super().__init__(config) self.model = AutoModel.from_config(config) @@ -75,11 +89,11 @@ def forward( ) -class AutoForSequenceClassification(LlamaPreTrainedModel): +class AutoForSequenceClassification(PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = LlamaModel(config) + self.model = AutoModel.from_config(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing @@ -158,13 +172,13 @@ def forward( -class AutoForQuestionAnswering(LlamaPreTrainedModel): +class AutoForQuestionAnswering(PreTrainedModel): base_model_prefix = "transformer" # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama def __init__(self, config): super().__init__(config) - self.transformer = LlamaModel(config) + self.transformer = AutoModel.from_config(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing @@ -227,11 +241,11 @@ def forward( ) -class AutoForTokenClassification(LlamaPreTrainedModel): +class AutoForTokenClassification(PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = LlamaModel(config) + self.model = AutoModel.from_config(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: @@ -259,6 +273,7 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, + return_dict = False, **kwargs, ) -> Union[Tuple, TokenClassifierOutput]: r""" @@ -267,8 +282,6 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5669facb90d4d4..aac5e9470a231c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -35,13 +35,12 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...utils.generic import validate_config_kwargs +from ...utils.generic import validate_config_kwargs, KwargsForCausalLM from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, GradientCheckpointLayer from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( - KwargsForCausalLM, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -251,13 +250,13 @@ def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) def forward( self, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 5173c962da5b32..f32497a6e1a55c 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -24,7 +24,7 @@ from dataclasses import fields, is_dataclass from enum import Enum from functools import partial, wraps -from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict +from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Callable import numpy as np from packaging import version @@ -871,3 +871,41 @@ class LossKwargs(TypedDict, total=False): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + + + + +def validate_config_kwargs(config): + """ + A decorator to validate and initialize kwargs based on a config object. 
+ """ + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + self = args[0] + # Default values from the config + default_kwargs = { + "output_attentions": self.config.output_attentions, + "output_hidden_states": self.config.output_hidden_states, + "use_cache": self.config.use_cache, + "return_dict": self.config.use_return_dict, + } + + # Merge provided kwargs with defaults + validated_kwargs = {**default_kwargs, **kwargs} + + # Validate kwargs against TypedDict + for key in validated_kwargs: + if key not in KwargsForCausalLM.__annotations__: + raise ValueError(f"Invalid keyword argument: {key}") + + if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]: + validated_kwargs["use_cache"] = False + + # Pass the validated kwargs to the function + return func(*args, **validated_kwargs) + + return wrapper + + return decorator \ No newline at end of file From 341b8ce9fa8538b08e0cfaeb7c5a8adfc8286ffa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 14:27:20 +0100 Subject: [PATCH 09/40] nits --- .../models/llama/modeling_llama.py | 3 +- src/transformers/utils/generic.py | 65 +++++++++++-------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index aac5e9470a231c..2df1852aba0ec2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -267,7 +267,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape, _ = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape) @@ -382,6 +382,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + gradient_checkpointing = False def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index f32497a6e1a55c..159b4c9d6f518c 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -24,7 +24,7 @@ from dataclasses import fields, is_dataclass from enum import Enum from functools import partial, wraps -from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Callable +from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Callable, Union import numpy as np from packaging import version @@ -39,6 +39,7 @@ is_torch_fx_proxy, ) +import torch class cached_property(property): """ @@ -870,42 +871,52 @@ class LossKwargs(TypedDict, total=False): num_items_in_batch: Optional[int] -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + input_ids: torch.LongTensor = None + attention_mask: Optional[torch.Tensor] = None + position_ids: Optional[torch.LongTensor] = None + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None + inputs_embeds: Optional[torch.FloatTensor] = None + labels: Optional[torch.LongTensor] = None + use_cache: Optional[bool] = None + output_attentions: Optional[bool] = None + output_hidden_states: Optional[bool] = None + return_dict: Optional[bool] = None + cache_position: Optional[torch.LongTensor] = None + num_logits_to_keep: int = 0 - -def validate_config_kwargs(config): +def validate_config_kwargs(func): """ A decorator to validate and initialize kwargs based on a config object. """ - def decorator(func: Callable): - @wraps(func) - def wrapper(*args, **kwargs): - self = args[0] - # Default values from the config - default_kwargs = { - "output_attentions": self.config.output_attentions, - "output_hidden_states": self.config.output_hidden_states, - "use_cache": self.config.use_cache, - "return_dict": self.config.use_return_dict, - } - # Merge provided kwargs with defaults - validated_kwargs = {**default_kwargs, **kwargs} + @wraps(func) + def wrapper(*args, **kwargs): + self = args[0] + # Default values from the config + default_kwargs = { + "output_attentions": self.config.output_attentions, + "output_hidden_states": self.config.output_hidden_states, + "use_cache": self.config.use_cache, + "return_dict": self.config.use_return_dict, + } - # Validate kwargs against TypedDict - for key in validated_kwargs: - if key not in KwargsForCausalLM.__annotations__: - raise ValueError(f"Invalid keyword argument: {key}") + # Merge provided kwargs with defaults + validated_kwargs = {**default_kwargs, **kwargs} - if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]: - validated_kwargs["use_cache"] = False + # Validate kwargs against TypedDict + for key in validated_kwargs: + if key not in KwargsForCausalLM.__annotations__: + raise ValueError(f"Invalid keyword argument: {key}") - # Pass the validated kwargs to the function - return func(*args, **validated_kwargs) + if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]: + validated_kwargs["use_cache"] = False - return wrapper + # Pass the validated kwargs to the function + return func(*args, **validated_kwargs) + + return wrapper - return decorator \ No newline at end of file From 556aa4ec2de158c4a973118fe0ecaabb23dcab57 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 14:35:53 +0100 Subject: [PATCH 10/40] updates --- src/transformers/models/auto/modeling_task.py | 19 +------------------ .../models/llama/modeling_llama.py | 12 +++++++----- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 6dbdd1d15f119c..96ee693312f3c0 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -37,39 +37,22 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: 
Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + num_logits_to_keep: int = 1, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2df1852aba0ec2..58f02ef00794cf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -215,7 +215,7 @@ def forward(self, x): return down_proj -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: +def repeat_kv(hidden_states: torch.Tensor, n_rep: int, dim=2) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) @@ -223,16 +223,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states.unsqueeze(dim).expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward(attention_class:nn.Module, query, key, value, mask, **_kwargs): config = attention_class.config - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) + query, key, value = [x.transpose(1,2) for x in (query, key, value)] + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask @@ -252,6 +253,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim ** -0.5 self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * 
self.head_dim, bias=config.attention_bias) From f61a5fec413fb79509ea07d6d6e04d4333b81448 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 14:37:12 +0100 Subject: [PATCH 11/40] pass attention --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 58f02ef00794cf..e5622b01e5308b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -297,7 +297,7 @@ def forward( **kwargs, ) - attn_output = attn_output.reshape(*hidden_shape, -1) + attn_output = attn_output.reshape(*input_shape, -1) attn_output = self.o_proj(attn_output) return attn_output, attn_weights From dcf7a37ce17e4553ddb8870d535daf4d03a97f49 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 14:38:26 +0100 Subject: [PATCH 12/40] cache concatenates on the wrong axis --- src/transformers/cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index f3f0bd6fe5458f..eafd36f28337e1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -447,8 +447,8 @@ def update( self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) return self.key_cache[layer_idx], self.value_cache[layer_idx] From 1baabd3207f9d3398fc067a21f26ae1eb89d10da Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 15:13:39 +0100 Subject: [PATCH 13/40] update --- src/transformers/modeling_utils.py | 13 +- .../models/llama/modeling_llama.py | 158 ++---------------- 2 files changed, 22 insertions(+), 149 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 729e45d7c8dd11..52770dfa2738b5 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2537,6 +2537,10 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + for layer in list(self.modules()): + if isinstance(layer, GradientCheckpointLayer): + layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs) @property def is_gradient_checkpointing(self) -> bool: @@ -5501,6 +5505,7 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): class GradientCheckpointLayer(torch.nn.Module): + def __call__(self, *args, **kwargs): """ Adjust the behavior of the inherited class by overriding `__call__`. @@ -5508,7 +5513,7 @@ def __call__(self, *args, **kwargs): Automatically handles gradient checkpointing based on flags in the provided arguments. 
""" # Extract necessary flags and arguments - gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) + gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr(self, "gradient_checkpointing", False) training = self.training if gradient_checkpointing and training: @@ -5525,8 +5530,7 @@ def _apply_gradient_checkpointing(self, *args, **kwargs): By default, uses `torch.utils.checkpoint.checkpoint`. """ # Assume `self.forward` is compatible with checkpointing - return checkpoint(self.__call__, *args, **kwargs) - + return checkpoint(self.__call__, *args, **kwargs) def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): @@ -5543,8 +5547,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): gradient_checkpointing_kwargs (dict, *optional*): Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. """ - if not self.supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.gradient_checkpointing = True if gradient_checkpointing_kwargs is None: gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e5622b01e5308b..89c05546d3a904 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -25,27 +25,23 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache -from ...generation import GenerationMixin + from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) -from ...utils.generic import validate_config_kwargs, KwargsForCausalLM +from ...utils.generic import validate_config_kwargs from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, GradientCheckpointLayer from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, + + + logging, - replace_return_docstrings, + ) from .configuration_llama import LlamaConfig @@ -227,15 +223,15 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int, dim=2) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class:nn.Module, query, key, value, mask, **_kwargs): +def eager_attention_forward(attention_class:nn.Module, query, key, value, attention_mask=None, **_kwargs): config = attention_class.config query, key, value = [x.transpose(1,2) for x in (query, key, value)] key_states = repeat_kv(key, attention_class.num_key_value_groups) value_states = repeat_kv(value, attention_class.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling - if mask is not None: - causal_mask = mask[:, :, :, : key_states.shape[-2]] + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -263,8 +259,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], # will become mandatory in v4.46 - attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -289,12 +284,7 @@ def forward( attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - **kwargs, + self, query_states, key_states, value_states, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1) @@ -316,10 +306,6 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - position_embeddings, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -328,14 +314,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + hidden_states, self_attn_weights = self.self_attn(hidden_states=hidden_states, **kwargs) hidden_states = residual + hidden_states # Fully Connected @@ -352,27 +331,6 @@ def forward( return outputs -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) class LlamaPreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" @@ -397,93 +355,7 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. 
- - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ _input_embedding = "embed_tokens" def __init__(self, config: LlamaConfig): @@ -519,8 +391,6 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -529,7 +399,7 @@ def forward( if cache_position is None: cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) - + if position_ids is None: position_ids = cache_position.unsqueeze(0) From 38dd294dd72183713faa7515923a4c3ebb73ad7e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 16:10:40 +0100 Subject: [PATCH 14/40] fix --- src/transformers/modeling_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 52770dfa2738b5..44d0f18f730c55 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5521,7 +5521,7 @@ def __call__(self, *args, **kwargs): return self._apply_gradient_checkpointing(*args, **kwargs) else: # Default behavior: call the original `forward` method - return super().__call__(*args, **kwargs) + return self.forward(*args, **kwargs) def _apply_gradient_checkpointing(self, *args, **kwargs): """ @@ -5530,7 +5530,9 @@ def _apply_gradient_checkpointing(self, *args, **kwargs): By default, uses `torch.utils.checkpoint.checkpoint`. 
""" # Assume `self.forward` is compatible with checkpointing - return checkpoint(self.__call__, *args, **kwargs) + def wrapped_forward(): + return self.forward(*args, **kwargs) + return self._gradient_checkpointing_func(wrapped_forward) def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): From 40154815cb393f43628e171f3d10d0d8c4c97bb0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 16:23:18 +0100 Subject: [PATCH 15/40] revert some stuff --- src/transformers/cache_utils.py | 4 +- .../models/llama/modeling_llama.py | 44 ++++++------------- 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index eafd36f28337e1..f3f0bd6fe5458f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -447,8 +447,8 @@ def update( self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 89c05546d3a904..83cbc16a5e89d4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -86,31 +86,14 @@ def __init__( config: Optional[LlamaConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -211,7 +194,7 @@ def forward(self, x): return down_proj -def repeat_kv(hidden_states: torch.Tensor, n_rep: int, dim=2) -> torch.Tensor: +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) @@ -219,13 +202,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int, dim=2) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states.unsqueeze(dim).expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward(attention_class:nn.Module, query, key, value, attention_mask=None, **_kwargs): config = attention_class.config - query, key, value = [x.transpose(1,2) for x in (query, key, value)] key_states = repeat_kv(key, attention_class.num_key_value_groups) value_states = repeat_kv(value, attention_class.num_key_value_groups) @@ -267,12 +249,12 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(hidden_shape) - value_states = self.v_proj(hidden_states).view(hidden_shape) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1,2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1,2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1,2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -287,7 +269,7 @@ def forward( self, query_states, key_states, value_states, **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights From 28829d2dd68a0c473ed630113803b6bdff2ae785 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 16:44:04 +0100 Subject: [PATCH 16/40] there was an issue with tie weight keys --- src/transformers/models/auto/modeling_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 96ee693312f3c0..29ca5fbf59dbb4 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -51,7 +51,7 @@ def forward( return_dict=return_dict, **kwargs, ) - + self.lm_head.weight.data = self.model.embed_tokens.weight.data hidden_states = outputs[0] logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) From 1ef18f49a93e989b4dc97530d836aad1df7509ea Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 16:53:02 +0100 Subject: [PATCH 17/40] style --- src/transformers/__init__.py | 4 ++-- .../integrations/flash_attention.py | 6 ++++-- .../integrations/flex_attention.py | 1 + .../integrations/sdpa_attention.py | 2 +- .../modeling_flash_attention_utils.py | 4 ---- src/transformers/models/auto/modeling_task.py | 19 ++++++++++--------- .../models/llama/modeling_llama.py | 11 +++-------- src/transformers/utils/generic.py | 4 ++-- 8 files changed, 23 insertions(+), 28 deletions(-) diff --git 
a/src/transformers/__init__.py b/src/transformers/__init__.py index b67362924d6cdb..9d50333ab72e7e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5037,10 +5037,10 @@ TOKENIZER_MAPPING, AutoConfig, AutoFeatureExtractor, + AutoForCausalLM, AutoImageProcessor, AutoProcessor, AutoTokenizer, - AutoForCausalLM, ) from .models.autoformer import ( AutoformerConfig, @@ -6202,7 +6202,6 @@ from .utils.dummy_pt_objects import * else: # Benchmarks - from .models.auto.modeling_task import AutoForCausalLM from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments from .cache_utils import ( @@ -6405,6 +6404,7 @@ AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) + from .models.auto.modeling_task import AutoForCausalLM from .models.autoformer import ( AutoformerForPrediction, AutoformerModel, diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index c1ba6a7720bc22..95a31c7692c142 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,5 +1,7 @@ import torch -from ..modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward + +from ..modeling_flash_attention_utils import _flash_attention_forward + def flash_attention_forward( config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs @@ -31,4 +33,4 @@ def flash_attention_forward( **kwargs, ) - return attn_output, None \ No newline at end of file + return attn_output, None diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 97bf10f9c41fbd..b4c9ed7f261f16 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,5 +1,6 @@ import torch + def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 96b7556e784d1a..78f7ada8081eb6 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,4 +1,4 @@ -import torch +import torch def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 0e8d83778d813b..eba05e05f3191d 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -20,10 +20,6 @@ import torch import torch.nn.functional as F -from functools import wraps -from typing import Callable, TypedDict, Optional -import logging - from .utils.import_utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 29ca5fbf59dbb4..b38188a0f84571 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -1,19 +1,20 @@ +from typing import List, Optional, Tuple, Union, Unpack + import torch import torch.nn as nn -from typing import Optional, List, Union, Unpack, Tuple, Dict -from ...modeling_utils import PreTrainedModel +from ...cache_utils import Cache from ...generation import GenerationMixin 
-from ..auto import AutoModel -from ...cache_utils import Cache, DynamicCache, StaticCache - -from ...utils.generic import KwargsForCausalLM, validate_config_kwargs from ...modeling_outputs import ( + CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, - CausalLMOutputWithPast, - TokenClassifierOutput + TokenClassifierOutput, ) +from ...modeling_utils import PreTrainedModel +from ...utils.generic import KwargsForCausalLM, validate_config_kwargs +from ..auto import AutoModel + class AutoForCausalLM(PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -51,7 +52,7 @@ def forward( return_dict=return_dict, **kwargs, ) - self.lm_head.weight.data = self.model.embed_tokens.weight.data + self.lm_head.weight.data = self.model.embed_tokens.weight.data # TODO fix me! hidden_states = outputs[0] logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 83cbc16a5e89d4..3a0f942beb798b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union, Callable +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,24 +25,19 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache - from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, ) -from ...utils.generic import validate_config_kwargs from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, GradientCheckpointLayer +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( - - - logging, - ) +from ...utils.generic import validate_config_kwargs from .configuration_llama import LlamaConfig diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 159b4c9d6f518c..0ec52b5f440c40 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -24,9 +24,10 @@ from dataclasses import fields, is_dataclass from enum import Enum from functools import partial, wraps -from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Callable, Union +from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Union import numpy as np +import torch from packaging import version from ..modeling_flash_attention_utils import FlashAttentionKwargs @@ -39,7 +40,6 @@ is_torch_fx_proxy, ) -import torch class cached_property(property): """ From 4b9a429a1cca77d078b8204d330c2aef7856514c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 16:54:35 +0100 Subject: [PATCH 18/40] style --- src/transformers/integrations/flex_attention.py | 5 +++++ src/transformers/integrations/sdpa_attention.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 
b4c9ed7f261f16..d8bd8df765297c 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,5 +1,10 @@ import torch +from ...utils import is_torch_greater_or_equal + + +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): def tanh_softcap(score, b, h, q_idx, kv_idx): diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 78f7ada8081eb6..328d63e0217c38 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,6 +1,17 @@ import torch +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): key = repeat_kv(key, config.num_key_value_groups) value = repeat_kv(value, config.num_key_value_groups) From e5d60b4f23ea734f5610d1215dab619e09f46884 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 17:07:20 +0100 Subject: [PATCH 19/40] fix --- src/transformers/integrations/flash_attention.py | 8 ++++---- src/transformers/integrations/flex_attention.py | 2 +- src/transformers/integrations/sdpa_attention.py | 14 +++++++------- src/transformers/modeling_utils.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 95a31c7692c142..2189baef8158b6 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -4,10 +4,10 @@ def flash_attention_forward( - config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs + config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs ): - if mask is not None: - seq_len = mask.shape[1] + if attentions_mask is not None: + seq_len = attentions_mask.shape[1] query = query[:, :, :seq_len] value = value[:, :, :seq_len] else: @@ -25,7 +25,7 @@ def flash_attention_forward( query, key, value, - mask, + attentions_mask, seq_len, config=config, dropout=dropout_rate, diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index d8bd8df765297c..ca6977f7ac2334 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,6 +1,6 @@ import torch -from ...utils import is_torch_greater_or_equal +from ..utils import is_torch_greater_or_equal if is_torch_greater_or_equal("2.5"): diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 328d63e0217c38..65c74ca3aa5eb5 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -12,12 +12,12 @@ def 
repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) +def sdpa_attention_forward(module, query, key, value, attentions_mask=None, **_kwargs): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) - causal_mask = mask - if mask is not None: + causal_mask = attentions_mask + if attentions_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, @@ -36,8 +36,8 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): key, value, attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, + dropout_p=module.config.attention_dropout if module.training else 0.0, is_causal=is_causal, - scale=config.scaling, + scale=module.scaling, ) return attn_output, None diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 44d0f18f730c55..d43c63d630874d 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1583,10 +1583,10 @@ def _autoset_attn_implementation( ) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. - config = cls._check_and_enable_sdpa( - config, - hard_check_only=False if requested_attn_implementation is None else True, - ) + # config = cls._check_and_enable_sdpa( + # config, + # hard_check_only=False if requested_attn_implementation is None else True, + # ) if ( torch.version.hip is not None From 3bbae395394259523ae319c9aa3bd17490c351d2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 17:24:56 +0100 Subject: [PATCH 20/40] remove tanh --- .../integrations/flex_attention.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index ca6977f7ac2334..b1ea9cbf1057f6 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,29 +1,17 @@ -import torch - from ..utils import is_torch_greater_or_equal if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention -def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score +def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs): - attn_output = flex_attention( + attn_output, attention_weights = flex_attention( query, key, value, - score_mod=tanh_softcap, enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, + scale=module.scaling, + return_lse=True, ) - if not output_attentions: - return attn_output, None - else: - return attn_output[0], attn_output[1] + return attn_output, attention_weights 
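The patches up to this point converge on a single pluggable attention interface: each backend (eager, sdpa, flex, flash) becomes a plain function with the shared signature (module, query, key, value, attention_mask=None, **kwargs), registered in ALL_ATTENTION_FUNCTIONS and looked up through config._attn_implementation inside LlamaAttention.forward. The standalone sketch below illustrates that dispatch pattern under stated assumptions: the registry name, the shared signature, repeat_kv, and the scaling/num_key_value_groups attributes are taken from the hunks above, but the local registry construction, the DummyAttention module, and the demo shapes are illustrative only and are not code from this series.

# Minimal sketch of the backend-dispatch pattern the diffs above converge on.
# The DummyAttention module and this local ALL_ATTENTION_FUNCTIONS dict are
# hypothetical stand-ins; in the series the real registry lives in
# modeling_utils / transformers.integrations.
import torch
import torch.nn as nn


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def eager_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # Every backend shares this signature, so the caller stays backend-agnostic.
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * module.scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def sdpa_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)
    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        is_causal=attention_mask is None and query.shape[2] > 1,
        scale=module.scaling,
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Hypothetical local registry; keys mirror the _attn_implementation names.
ALL_ATTENTION_FUNCTIONS = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}


class DummyAttention(nn.Module):
    """Toy stand-in for LlamaAttention: just enough state for the backends above."""

    def __init__(self, num_key_value_groups=2, head_dim=8, attn_implementation="sdpa"):
        super().__init__()
        self.num_key_value_groups = num_key_value_groups
        self.scaling = head_dim ** -0.5
        self.attn_implementation = attn_implementation

    def forward(self, query, key, value, attention_mask=None, **kwargs):
        # Dispatch mirrors the LlamaAttention.forward hunk:
        # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.attn_implementation]
        return attention_interface(self, query, key, value, attention_mask=attention_mask, **kwargs)


if __name__ == "__main__":
    bsz, n_heads, n_kv_heads, seq, head_dim = 1, 4, 2, 5, 8
    q = torch.randn(bsz, n_heads, seq, head_dim)
    k = torch.randn(bsz, n_kv_heads, seq, head_dim)
    v = torch.randn(bsz, n_kv_heads, seq, head_dim)
    attn = DummyAttention(num_key_value_groups=n_heads // n_kv_heads, head_dim=head_dim)
    out, _ = attn(q, k, v)
    print(out.shape)  # torch.Size([1, 5, 4, 8]) after the (1, 2) transpose

The design point the sketch tries to make explicit: because each backend function carries no state of its own (everything it needs hangs off the module argument), the sdpa/flex/flash variants can live under transformers/integrations/ and be swapped per-config without touching the model class, which is exactly what the following patches keep iterating on.
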
From 89d32d6825053d9f700b533407f8a9b2848025f0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 17:30:26 +0100 Subject: [PATCH 21/40] fix auto set --- src/transformers/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d43c63d630874d..efa15b073b13a7 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1597,6 +1597,8 @@ def _autoset_attn_implementation( "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." ) torch.backends.cuda.enable_flash_sdp(False) + elif config._attn_implementation in ALL_ATTENTION_FUNCTIONS: + pass elif isinstance(requested_attn_implementation, dict): config._attn_implementation = None else: From 7a911efddf06c8b18569a2d4c56398f4bb9598f3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 17:39:40 +0100 Subject: [PATCH 22/40] update --- src/transformers/integrations/flex_attention.py | 11 +++++++++++ src/transformers/integrations/sdpa_attention.py | 1 + 2 files changed, 12 insertions(+) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index b1ea9cbf1057f6..1694a4f5b517f4 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -6,12 +6,23 @@ def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs): + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def causal_mod(score, b, h, q_idx, kv_idx): + if causal_mask is not None: + score += causal_mask[b][0][q_idx][kv_idx] + return score + attn_output, attention_weights = flex_attention( query, key, value, + score_mod=causal_mod, enable_gqa=True, scale=module.scaling, return_lse=True, ) + attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attention_weights diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 65c74ca3aa5eb5..33d2e85e473688 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -40,4 +40,5 @@ def sdpa_attention_forward(module, query, key, value, attentions_mask=None, **_k is_causal=is_causal, scale=module.scaling, ) + attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None From 20c512bc800173ae5d7ea88f97a9e6a552cd62fd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 18:07:53 +0100 Subject: [PATCH 23/40] clean --- src/transformers/integrations/sdpa_attention.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 33d2e85e473688..ad28411667c055 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -20,17 +20,11 @@ def sdpa_attention_forward(module, query, key, value, attentions_mask=None, **_k if attentions_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. 
- if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and query.shape[1] > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( query, key, From d9156363bf57d12eb2946c34bebcc4ffcb2753a7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 18:49:37 +0100 Subject: [PATCH 24/40] mm --- src/transformers/modeling_flash_attention_utils.py | 8 -------- src/transformers/models/auto/modeling_task.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 7 +++---- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index eba05e05f3191d..9fd8703f5e7ffb 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -336,11 +336,3 @@ class FlashAttentionKwargs(TypedDict, total=False): cu_seq_lens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] - - -class TransformersKwargs(TypedDict, total=False): - output_attentions: Optional[bool] - output_hidden_states: Optional[bool] - use_cache: Optional[bool] - return_dict: Optional[bool] - diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index b38188a0f84571..c4bc19e3f6df4e 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -22,6 +22,8 @@ class AutoForCausalLM(PreTrainedModel, GenerationMixin): _embeding_layer = "model.embed_tokens" _output_embedding = "lm_head" _no_split_modules = [] + _supports_cache_class = True + def __init__(self, config): super().__init__(config) self.model = AutoModel.from_config(config) @@ -31,7 +33,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3a0f942beb798b..d9de7225e569ab 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -249,7 +249,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1,2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -333,7 +333,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() class LlamaModel(LlamaPreTrainedModel): - _input_embedding = "embed_tokens" + _embedding_layer = "embed_tokens" def __init__(self, config: LlamaConfig): super().__init__(config) @@ -380,7 +380,6 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( attention_mask, 
inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -423,7 +422,7 @@ def forward( output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) From 60189825d7fca0471e63f8b80354788d19782790 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 19:09:09 +0100 Subject: [PATCH 25/40] fix! --- src/transformers/models/auto/modeling_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index c4bc19e3f6df4e..0a18a365a8d4ae 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -53,7 +53,7 @@ def forward( return_dict=return_dict, **kwargs, ) - self.lm_head.weight.data = self.model.embed_tokens.weight.data # TODO fix me! + # self.lm_head.weight.data = self.model.embed_tokens.weight.data # TODO fix me! hidden_states = outputs[0] logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) From e9d751abaafda02978695baba3b0f32b97ed8b8a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 19:26:39 +0100 Subject: [PATCH 26/40] fix attention_mask --- src/transformers/integrations/sdpa_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index ad28411667c055..8a0de7769fcbad 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -12,12 +12,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def sdpa_attention_forward(module, query, key, value, attentions_mask=None, **_kwargs): +def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) - causal_mask = attentions_mask - if attentions_mask is not None: + causal_mask = attention_mask + if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] query = query.contiguous() From 7a608da9f84051eeb59a2cd6fa8126e189d86ae7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 19:44:29 +0100 Subject: [PATCH 27/40] update --- src/transformers/models/auto/__init__.py | 10 ++- src/transformers/models/auto/auto_factory.py | 4 +- src/transformers/models/auto/modeling_auto.py | 8 +-- src/transformers/models/llama/__init__.py | 8 --- tests/models/llama/test_modeling_llama.py | 72 ++++++++++--------- 5 files changed, 53 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index c7e39f4da94608..7e4d34ec3e1a77 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -124,6 +124,9 @@ ] _import_structure["modeling_task"] = [ "AutoForCausalLM", + "AutoForSequenceClassification", + "AutoForQuestionAnswering", + "AutoForTokenClassification", ] try: @@ -314,7 +317,12 @@ AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) - from .modeling_task import AutoForCausalLM + from .modeling_task import ( + 
AutoForCausalLM, + AutoForQuestionAnswering, + AutoForSequenceClassification, + AutoForTokenClassification, + ) try: if not is_tf_available(): diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 7809b2a6cc2cfc..f7a80bd7d37c52 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -438,7 +438,9 @@ def from_config(cls, config, **kwargs): elif type(config) in cls._model_mapping.keys(): model_class = _get_model_class(config, cls._model_mapping) return model_class._from_config(config, **kwargs) - + else: + model_class = cls._model_mapping["auto"] + return model_class._from_config(config, **kwargs) raise ValueError( f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 406651da2a3132..e9ac5c6d44d483 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -500,7 +500,7 @@ ("granitemoe", "GraniteMoeForCausalLM"), ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), - ("llama", "AutoForCausalLM"), + ("auto", "AutoForCausalLM"), ("mamba", "MambaForCausalLM"), ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianForCausalLM"), @@ -960,7 +960,7 @@ ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), ("led", "LEDForSequenceClassification"), ("lilt", "LiltForSequenceClassification"), - ("llama", "LlamaForSequenceClassification"), + ("auto", "AutoForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), @@ -1045,7 +1045,7 @@ ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), ("led", "LEDForQuestionAnswering"), ("lilt", "LiltForQuestionAnswering"), - ("llama", "LlamaForQuestionAnswering"), + ("auto", "AutoForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), ("luke", "LukeForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"), @@ -1147,7 +1147,7 @@ ("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("lilt", "LiltForTokenClassification"), - ("llama", "LlamaForTokenClassification"), + ("auto", "AutoForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("luke", "LukeForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"), diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py index 3f6461c4c093f2..34dd29ab22f67d 100644 --- a/src/transformers/models/llama/__init__.py +++ b/src/transformers/models/llama/__init__.py @@ -50,12 +50,8 @@ pass else: _import_structure["modeling_llama"] = [ - "LlamaForCausalLM", "LlamaModel", "LlamaPreTrainedModel", - "LlamaForSequenceClassification", - "LlamaForQuestionAnswering", - "LlamaForTokenClassification", ] try: @@ -93,10 +89,6 @@ pass else: from .modeling_llama import ( - LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, ) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 9e67f4f7381e24..6d354b021cca4c 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -23,6 
+23,12 @@ from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed from transformers.generation.configuration_utils import GenerationConfig +from transformers.models.auto import ( + AutoForCausalLM, + AutoForQuestionAnswering, + AutoForSequenceClassification, + AutoForTokenClassification, +) from transformers.testing_utils import ( cleanup, require_flash_attn, @@ -44,10 +50,6 @@ import torch from transformers import ( - LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, LlamaModel, LlamaTokenizer, ) @@ -197,7 +199,7 @@ def create_and_check_for_causal_lm( encoder_hidden_states, encoder_attention_mask, ): - model = LlamaForCausalLM(config=config) + model = AutoForCausalLM(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask, labels=token_labels) @@ -217,7 +219,7 @@ def create_and_check_decoder_model_past_large_inputs( ): config.is_decoder = True config.add_cross_attention = True - model = LlamaForCausalLM(config=config) + model = AutoForCausalLM(config=config) model.to(torch_device) model.eval() @@ -285,23 +287,23 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi all_model_classes = ( ( LlamaModel, - LlamaForCausalLM, - LlamaForSequenceClassification, - LlamaForQuestionAnswering, - LlamaForTokenClassification, + AutoForCausalLM, + AutoForSequenceClassification, + AutoForQuestionAnswering, + AutoForTokenClassification, ) if is_torch_available() else () ) - all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else () + all_generative_model_classes = (AutoForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": LlamaModel, - "text-classification": LlamaForSequenceClassification, - "text-generation": LlamaForCausalLM, - "zero-shot": LlamaForSequenceClassification, - "question-answering": LlamaForQuestionAnswering, - "token-classification": LlamaForTokenClassification, + "text-classification": AutoForSequenceClassification, + "text-generation": AutoForCausalLM, + "zero-shot": AutoForSequenceClassification, + "question-answering": AutoForQuestionAnswering, + "token-classification": AutoForTokenClassification, } if is_torch_available() else {} @@ -315,7 +317,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi model_split_percents = [0.5, 0.7, 0.8] # used in `test_torch_compile_for_training` - _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None + _torch_compile_train_cls = AutoForCausalLM if is_torch_available() else None def setUp(self): self.model_tester = LlamaModelTester(self) @@ -340,7 +342,7 @@ def test_llama_sequence_classification_model(self): input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = LlamaForSequenceClassification(config) + model = AutoForSequenceClassification(config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) @@ -353,7 +355,7 @@ def test_llama_sequence_classification_model_for_single_label(self): input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = LlamaForSequenceClassification(config) + model = 
AutoForSequenceClassification(config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) @@ -368,7 +370,7 @@ def test_llama_sequence_classification_model_for_multi_label(self): sequence_labels = ids_tensor( [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size ).to(torch.float) - model = LlamaForSequenceClassification(config) + model = AutoForSequenceClassification(config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) @@ -380,7 +382,7 @@ def test_llama_token_classification_model(self): input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = LlamaForTokenClassification(config=config) + model = AutoForTokenClassification(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=token_labels) @@ -536,17 +538,17 @@ def _reinitialize_config(base_config, new_kwargs): # from untouched config -> ✅ base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() - original_model = LlamaForCausalLM(base_config).to(torch_device) + original_model = AutoForCausalLM(base_config).to(torch_device) original_model(**model_inputs) # from a config with the expected rope configuration -> ✅ config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) - original_model = LlamaForCausalLM(config).to(torch_device) + original_model = AutoForCausalLM(config).to(torch_device) original_model(**model_inputs) # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) - original_model = LlamaForCausalLM(config).to(torch_device) + original_model = AutoForCausalLM(config).to(torch_device) original_model(**model_inputs) # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) @@ -555,13 +557,13 @@ def _reinitialize_config(base_config, new_kwargs): ) self.assertTrue(config.rope_scaling["type"] == "linear") self.assertTrue(config.rope_scaling["rope_type"] == "linear") - original_model = LlamaForCausalLM(config).to(torch_device) + original_model = AutoForCausalLM(config).to(torch_device) original_model(**model_inputs) # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) - original_model = LlamaForCausalLM(config).to(torch_device) + original_model = AutoForCausalLM(config).to(torch_device) original_model(**model_inputs) self.assertEqual(len(logs.output), 1) self.assertIn("factor field", logs.output[0]) @@ -571,7 +573,7 @@ def _reinitialize_config(base_config, new_kwargs): config = _reinitialize_config( base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} ) - original_model = LlamaForCausalLM(config).to(torch_device) + original_model = AutoForCausalLM(config).to(torch_device) original_model(**model_inputs) self.assertEqual(len(logs.output), 1) self.assertIn("Unrecognized keys", logs.output[0]) @@ -594,7 +596,7 @@ def 
test_use_flash_attention_2_true(self): model = model_class(config) model.save_pretrained(tmp_dir) - new_model = LlamaForCausalLM.from_pretrained( + new_model = AutoForCausalLM.from_pretrained( tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 ).to("cuda") @@ -645,7 +647,7 @@ def test_llama_3_1_hard(self): ) tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") - model = LlamaForCausalLM.from_pretrained( + model = AutoForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 ) input_text = ["Tell me about the french revolution."] @@ -660,7 +662,7 @@ def test_llama_3_1_hard(self): def test_model_7b_logits_bf16(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - model = LlamaForCausalLM.from_pretrained( + model = AutoForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" ) @@ -704,7 +706,7 @@ def test_model_7b_logits_bf16(self): def test_model_7b_logits(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - model = LlamaForCausalLM.from_pretrained( + model = AutoForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 ) @@ -754,7 +756,7 @@ def test_model_7b_dola_generation(self): ) prompt = "Simply put, the theory of relativity states that " tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - model = LlamaForCausalLM.from_pretrained( + model = AutoForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 ) model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) @@ -791,7 +793,7 @@ def test_compile_static_cache(self): "My favorite all time favorite condiment is ketchup.", ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") - model = LlamaForCausalLM.from_pretrained( + model = AutoForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) @@ -856,7 +858,7 @@ def test_export_static_cache(self): cache_implementation = "static" attn_implementation = "sdpa" batch_size = 1 - model = LlamaForCausalLM.from_pretrained( + model = AutoForCausalLM.from_pretrained( llama_model_ckp, device_map=device, torch_dtype=dtype, @@ -896,7 +898,7 @@ def setUp(self): model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" self.model_dtype = torch.float32 self.tokenizer = LlamaTokenizer.from_pretrained(model_name) - self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + self.model = AutoForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) def get_test_data(self): template = "my favorite {}" From 6028e85990a6bd8e9c436ab4ed5cb84d21c8f46b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 11 Dec 2024 19:50:04 +0100 Subject: [PATCH 28/40] fixup --- src/transformers/__init__.py | 8 -- .../integrations/flex_attention.py | 2 +- .../integrations/sdpa_attention.py | 1 + src/transformers/modeling_utils.py | 12 +- src/transformers/models/auto/modeling_auto.py | 10 +- src/transformers/models/auto/modeling_task.py | 6 +- .../models/llama/modeling_llama.py | 35 ++++-- src/transformers/utils/generic.py | 3 - tests/models/llama/test_modeling_llama.py | 103 ------------------ 9 files changed, 38 insertions(+), 142 deletions(-) diff --git 
a/src/transformers/__init__.py b/src/transformers/__init__.py index 9d50333ab72e7e..6332034d54f327 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2572,10 +2572,6 @@ ) _import_structure["models.llama"].extend( [ - "LlamaForCausalLM", - "LlamaForQuestionAnswering", - "LlamaForSequenceClassification", - "LlamaForTokenClassification", "LlamaModel", "LlamaPreTrainedModel", ] @@ -7265,10 +7261,6 @@ LiltPreTrainedModel, ) from .models.llama import ( - LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, ) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 1694a4f5b517f4..9c309a9ad50575 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -4,8 +4,8 @@ if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention -def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs): +def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs): causal_mask = attention_mask if causal_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 8a0de7769fcbad..0cf58f035ea9c1 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -12,6 +12,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index efa15b073b13a7..a3c99f1124114e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5493,8 +5493,6 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] - - ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {} ALL_ATTENTION_FUNCTIONS.update( @@ -5507,7 +5505,6 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): class GradientCheckpointLayer(torch.nn.Module): - def __call__(self, *args, **kwargs): """ Adjust the behavior of the inherited class by overriding `__call__`. @@ -5515,7 +5512,9 @@ def __call__(self, *args, **kwargs): Automatically handles gradient checkpointing based on flags in the provided arguments. """ # Extract necessary flags and arguments - gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr(self, "gradient_checkpointing", False) + gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr( + self, "gradient_checkpointing", False + ) training = self.training if gradient_checkpointing and training: @@ -5531,11 +5530,12 @@ def _apply_gradient_checkpointing(self, *args, **kwargs): By default, uses `torch.utils.checkpoint.checkpoint`. 
""" + # Assume `self.forward` is compatible with checkpointing def wrapped_forward(): - return self.forward(*args, **kwargs) - return self._gradient_checkpointing_func(wrapped_forward) + return self.forward(*args, **kwargs) + return self._gradient_checkpointing_func(wrapped_forward) def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e9ac5c6d44d483..c80921728ede0e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -463,6 +463,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("auto", "AutoForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), ("bert-generation", "BertGenerationDecoder"), @@ -473,7 +474,6 @@ ("blenderbot-small", "BlenderbotSmallForCausalLM"), ("bloom", "BloomForCausalLM"), ("camembert", "CamembertForCausalLM"), - ("code_llama", "LlamaForCausalLM"), ("codegen", "CodeGenForCausalLM"), ("cohere", "CohereForCausalLM"), ("cpmant", "CpmAntForCausalLM"), @@ -500,7 +500,6 @@ ("granitemoe", "GraniteMoeForCausalLM"), ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), - ("auto", "AutoForCausalLM"), ("mamba", "MambaForCausalLM"), ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianForCausalLM"), @@ -920,6 +919,7 @@ [ # Model for Sequence Classification mapping ("albert", "AlbertForSequenceClassification"), + ("auto", "AutoForSequenceClassification"), ("bart", "BartForSequenceClassification"), ("bert", "BertForSequenceClassification"), ("big_bird", "BigBirdForSequenceClassification"), @@ -928,7 +928,6 @@ ("bloom", "BloomForSequenceClassification"), ("camembert", "CamembertForSequenceClassification"), ("canine", "CanineForSequenceClassification"), - ("code_llama", "LlamaForSequenceClassification"), ("convbert", "ConvBertForSequenceClassification"), ("ctrl", "CTRLForSequenceClassification"), ("data2vec-text", "Data2VecTextForSequenceClassification"), @@ -960,7 +959,6 @@ ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), ("led", "LEDForSequenceClassification"), ("lilt", "LiltForSequenceClassification"), - ("auto", "AutoForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), @@ -1017,6 +1015,7 @@ [ # Model for Question Answering mapping ("albert", "AlbertForQuestionAnswering"), + ("auto", "AutoForQuestionAnswering"), ("bart", "BartForQuestionAnswering"), ("bert", "BertForQuestionAnswering"), ("big_bird", "BigBirdForQuestionAnswering"), @@ -1045,7 +1044,6 @@ ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), ("led", "LEDForQuestionAnswering"), ("lilt", "LiltForQuestionAnswering"), - ("auto", "AutoForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), ("luke", "LukeForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"), @@ -1114,6 +1112,7 @@ [ # Model for Token Classification mapping ("albert", "AlbertForTokenClassification"), + ("auto", "AutoForTokenClassification"), ("bert", "BertForTokenClassification"), ("big_bird", "BigBirdForTokenClassification"), ("biogpt", "BioGptForTokenClassification"), @@ -1147,7 +1146,6 @@ ("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("lilt", "LiltForTokenClassification"), - ("auto", "AutoForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("luke", 
"LukeForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"), diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 0a18a365a8d4ae..3c889674ce0d4c 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -44,7 +44,6 @@ def forward( num_logits_to_keep: int = 1, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: - outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -156,7 +155,6 @@ def forward( ) - class AutoForQuestionAnswering(PreTrainedModel): base_model_prefix = "transformer" @@ -258,7 +256,7 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - return_dict = False, + return_dict=False, **kwargs, ) -> Union[Tuple, TokenClassifierOutput]: r""" @@ -273,7 +271,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - **kwargs + **kwargs, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d9de7225e569ab..b6f750a32df887 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -201,7 +201,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class:nn.Module, query, key, value, attention_mask=None, **_kwargs): +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): config = attention_class.config key_states = repeat_kv(key, attention_class.num_key_value_groups) value_states = repeat_kv(value, attention_class.num_key_value_groups) @@ -217,6 +217,7 @@ def eager_attention_forward(attention_class:nn.Module, query, key, value, attent attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -226,12 +227,20 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads 
* self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, @@ -244,9 +253,9 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1,2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1,2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1,2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -261,7 +270,11 @@ def forward( attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( - self, query_states, key_states, value_states, **kwargs, + self, + query_states, + key_states, + value_states, + **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -332,6 +345,7 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + class LlamaModel(LlamaPreTrainedModel): _embedding_layer = "embed_tokens" @@ -548,4 +562,3 @@ def _prepare_4d_causal_attention_mask_with_cache_position( ) return causal_mask - diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 0ec52b5f440c40..a3c6503cd04f16 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -886,8 +886,6 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): num_logits_to_keep: int = 0 - - def validate_config_kwargs(func): """ A decorator to validate and initialize kwargs based on a config object. 
@@ -919,4 +917,3 @@ def wrapper(*args, **kwargs): return func(*args, **validated_kwargs) return wrapper - diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 6d354b021cca4c..1cd7ef2d1d11e3 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -53,8 +53,6 @@ LlamaModel, LlamaTokenizer, ) - from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding - class LlamaModelTester: def __init__( @@ -426,107 +424,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - - # Sanity check Yarn RoPE scaling - # Scaling should be over the entire input - config.rope_scaling = {"type": "yarn", "factor": scaling_factor} - yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_short, original_cos_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - def test_rope_class_retrocompatibility(self): - # Delete me when we remove compatibility for the old API :) - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - config.rope_scaling = {"type": "linear", "factor": 10} - - # Inputs - x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called - old_api_rope = LlamaLinearScalingRotaryEmbedding( - config.hidden_size // config.num_attention_heads, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - old_cos_short, old_sin_short = old_api_rope(x, position_ids_short) - old_cos_long, old_sin_long = old_api_rope(x, position_ids_long) - - # New API - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - new_cos_short, new_sin_short = new_api_rope(x, position_ids_short) - new_cos_long, new_sin_long = new_api_rope(x, position_ids_long) - - # The results should match - torch.testing.assert_close(old_cos_short, new_cos_short) - torch.testing.assert_close(old_sin_short, new_sin_short) - torch.testing.assert_close(old_cos_long, new_cos_long) - torch.testing.assert_close(old_sin_long, 
new_sin_long) def test_model_loading_old_rope_configs(self): def _reinitialize_config(base_config, new_kwargs): From 725d00caf4170197a4e912b5d73aaaac6c455377 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 09:22:04 +0100 Subject: [PATCH 29/40] fix some stuff --- src/transformers/modeling_utils.py | 7 +++++-- src/transformers/models/auto/modeling_task.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a3c99f1124114e..53dcc7c194a8e6 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1354,6 +1354,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None + _output_embedding = None + _input_embedding = None + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ @@ -1832,7 +1835,7 @@ def get_input_embeddings(self) -> nn.Module: if base_model is not self: return base_model.get_input_embeddings() else: - return getattr(self, self._embeding_layer) + return getattr(self, self._input_embedding) def set_input_embeddings(self, value: nn.Module): """ @@ -1845,7 +1848,7 @@ def set_input_embeddings(self, value: nn.Module): if base_model is not self: base_model.set_input_embeddings(value) else: - raise setattr(self, self._embeding_layer, value) + raise setattr(self, self._input_embedding, value) def get_output_embeddings(self) -> nn.Module: """ diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 3c889674ce0d4c..7e01f8730ff2aa 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -19,7 +19,7 @@ class AutoForCausalLM(PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _embeding_layer = "model.embed_tokens" + _input_embedding = "model.embed_tokens" _output_embedding = "lm_head" _no_split_modules = [] _supports_cache_class = True diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b6f750a32df887..f87ff3650a4a2b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -347,7 +347,7 @@ def _init_weights(self, module): class LlamaModel(LlamaPreTrainedModel): - _embedding_layer = "embed_tokens" + _input_embedding = "embed_tokens" def __init__(self, config: LlamaConfig): super().__init__(config) From c224f36d108bdccb6595eac08fd29497d57ab200 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 09:39:36 +0100 Subject: [PATCH 30/40] fix some tests --- src/transformers/modeling_utils.py | 25 +++++++++++++++++-- .../models/auto/configuration_auto.py | 2 ++ src/transformers/models/auto/modeling_task.py | 16 ++++++++++-- tests/models/llama/test_modeling_llama.py | 2 +- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 53dcc7c194a8e6..b9e29e35dac127 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1356,6 +1356,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _output_embedding = None _input_embedding = None + gradient_checkpointing = False @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: @@ -1847,8 +1848,10 
@@ def set_input_embeddings(self, value: nn.Module): base_model = getattr(self, self.base_model_prefix, self) if base_model is not self: base_model.set_input_embeddings(value) + elif self._input_embedding is not None: + setattr(self, self._input_embedding, value) else: - raise setattr(self, self._input_embedding, value) + raise ValueError("No input embedding") def get_output_embeddings(self) -> nn.Module: """ @@ -1857,7 +1860,25 @@ def get_output_embeddings(self) -> nn.Module: Returns: `nn.Module`: A torch module mapping hidden states to vocabulary. """ - return getattr(self, self._output_embedding, None) + if self._output_embedding is not None: + return getattr(self, self._output_embedding, None) + else: + return None + + def set_output_embeddings(self, value: nn.Module): + """ + Set model's input embeddings. + + Args: + value (`nn.Module`): A module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + base_model.set_output_embeddings(value) + elif self._output_embedding is not None: + setattr(self, self._output_embedding, value) + else: + raise ValueError() def _init_weights(self, module): """ diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4ab6d392282657..fe8dcc88e7a034 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -900,6 +900,8 @@ class AutoConfig: This class cannot be instantiated directly using `__init__()` (throws an error). """ + model_type = "auto" + def __init__(self): raise EnvironmentError( "AutoConfig is designed to be instantiated " diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 7e01f8730ff2aa..592c48991a52ab 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -13,16 +13,16 @@ ) from ...modeling_utils import PreTrainedModel from ...utils.generic import KwargsForCausalLM, validate_config_kwargs -from ..auto import AutoModel +from ..auto import AutoConfig, AutoModel class AutoForCausalLM(PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _input_embedding = "model.embed_tokens" _output_embedding = "lm_head" _no_split_modules = [] _supports_cache_class = True + config_class = AutoConfig def __init__(self, config): super().__init__(config) @@ -33,6 +33,13 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module): + self.model.set_input_embeddings(value) + + @validate_config_kwargs def forward( self, input_ids: torch.LongTensor = None, @@ -74,6 +81,8 @@ def forward( class AutoForSequenceClassification(PreTrainedModel): + config_class = AutoConfig + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -157,6 +166,7 @@ def forward( class AutoForQuestionAnswering(PreTrainedModel): base_model_prefix = "transformer" + config_class = AutoConfig # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama def __init__(self, config): @@ -225,6 +235,8 @@ def forward( class AutoForTokenClassification(PreTrainedModel): + config_class = AutoConfig + def __init__(self, config): super().__init__(config) self.num_labels = 
config.num_labels diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 1cd7ef2d1d11e3..de7fac72030b61 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -54,6 +54,7 @@ LlamaTokenizer, ) + class LlamaModelTester: def __init__( self, @@ -424,7 +425,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - def test_model_loading_old_rope_configs(self): def _reinitialize_config(base_config, new_kwargs): # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation From 3f68c7cf72b04b4527696e5924e929a1448b474c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 09:44:14 +0100 Subject: [PATCH 31/40] 9 left! --- src/transformers/models/auto/modeling_task.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 592c48991a52ab..5561626b57ee3e 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -39,7 +39,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value: nn.Module): self.model.set_input_embeddings(value) - @validate_config_kwargs + def forward( self, input_ids: torch.LongTensor = None, @@ -51,6 +51,8 @@ def forward( num_logits_to_keep: int = 1, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: + return_dict = return_dict if return_dict else self.config.return_dict + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -106,10 +108,8 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs ) -> Union[Tuple, SequenceClassifierOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -119,10 +119,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) @@ -183,6 +181,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @validate_config_kwargs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -192,8 +191,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, QuestionAnsweringModelOutput]: @@ -205,9 +202,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs ) sequence_output = outputs[0] @@ -259,7 +255,6 @@ def 
get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @validate_config_kwargs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -271,18 +266,14 @@ def forward( return_dict=False, **kwargs, ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ + return_dict = return_dict if return_dict else self.config.return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + return_dict=return_dict, **kwargs, ) sequence_output = outputs[0] From 1a5a834f536e179d9dd87d97246c5df840334945 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 09:53:53 +0100 Subject: [PATCH 32/40] fix auto? --- src/transformers/models/auto/auto_factory.py | 5 +++++ src/transformers/models/auto/configuration_auto.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index f7a80bd7d37c52..7cd9599372435a 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -566,6 +566,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) + else: + model_class = cls._model_mapping[PretrainedConfig] + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs + ) raise ValueError( f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." 
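
Both fallbacks introduced in this series (the one added to `from_config` earlier and the one added to `from_pretrained` just above) resolve configs that have no dedicated mapping entry to the generic task heads registered under the "auto" model type / `PretrainedConfig` key. A simplified, standalone sketch of that lookup order, using stand-in classes rather than the real transformers ones:

    # Sketch of the dispatch order: an exact config class wins, anything else falls back
    # to the generic head keyed on the base PretrainedConfig. Stand-in classes only.
    from collections import OrderedDict

    class PretrainedConfig:
        pass

    class MyCustomConfig(PretrainedConfig):
        pass

    class AutoForCausalLM:
        def __init__(self, config):
            self.config = config

    MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict([(PretrainedConfig, AutoForCausalLM)])

    def resolve_model_class(config, mapping):
        if type(config) in mapping:        # a dedicated head is registered for this config
            return mapping[type(config)]
        return mapping[PretrainedConfig]   # otherwise use the generic "auto" head

    assert resolve_model_class(MyCustomConfig(), MODEL_FOR_CAUSAL_LM_MAPPING) is AutoForCausalLM
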
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index fe8dcc88e7a034..3b28c086dd819f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -37,6 +37,7 @@ ("altclip", "AltCLIPConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), + ("auto", "PretrainedConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), ("beit", "BeitConfig"), @@ -328,6 +329,7 @@ ("altclip", "AltCLIP"), ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), ("autoformer", "Autoformer"), + ("auto", "Auto"), ("bark", "Bark"), ("bart", "BART"), ("barthez", "BARThez"), From 53450ac36515fe347b9c4a0e616272eb6e888d2d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 10:14:20 +0100 Subject: [PATCH 33/40] fix --- src/transformers/models/auto/modeling_task.py | 26 ++- src/transformers/utils/generic.py | 8 +- tests/test_modeling_common.py | 184 +++++++++--------- 3 files changed, 109 insertions(+), 109 deletions(-) diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 5561626b57ee3e..0cae15b604a75f 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -48,20 +48,20 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 1, + num_logits_to_keep: int = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: - return_dict = return_dict if return_dict else self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - return_dict=return_dict, + return_dict=True, **kwargs, ) - # self.lm_head.weight.data = self.model.embed_tokens.weight.data # TODO fix me! 
+ hidden_states = outputs[0] logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) @@ -69,11 +69,7 @@ def forward( if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( + output = CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, @@ -81,6 +77,8 @@ def forward( attentions=outputs.attentions, ) + return output if return_dict else output.to_tuple() + class AutoForSequenceClassification(PreTrainedModel): config_class = AutoConfig @@ -119,7 +117,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - return_dict=return_dict, + return_dict=True, **kwargs ) hidden_states = transformer_outputs[0] @@ -202,7 +200,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - return_dict=return_dict, + return_dict=True, **kwargs ) @@ -263,17 +261,17 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - return_dict=False, + return_dict=None, **kwargs, ) -> Union[Tuple, TokenClassifierOutput]: - return_dict = return_dict if return_dict else self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - return_dict=return_dict, + return_dict=True, **kwargs, ) sequence_output = outputs[0] diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index a3c6503cd04f16..2d7279569c0f2d 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -905,10 +905,10 @@ def wrapper(*args, **kwargs): # Merge provided kwargs with defaults validated_kwargs = {**default_kwargs, **kwargs} - # Validate kwargs against TypedDict - for key in validated_kwargs: - if key not in KwargsForCausalLM.__annotations__: - raise ValueError(f"Invalid keyword argument: {key}") + # # Validate kwargs against TypedDict + # for key in validated_kwargs: + # if key not in KwargsForCausalLM.__annotations__: + # raise ValueError(f"Invalid keyword argument: {key}") if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]: validated_kwargs["use_cache"] = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f3f326a4ce8112..1cf6c63697eac8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -927,78 +927,79 @@ def test_attention_outputs(self): encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model 
= model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + with self.subTest(model_class.__name__): + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - out_len = len(outputs) - - if self.is_encoder_decoder: - correct_outlen = 5 - - # loss is at first position - if "labels" in inputs_dict: - correct_outlen += 1 # loss is added to beginning - # Question Answering model returns start_logits and end_logits - if model_class.__name__ in [ - *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), - *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES), - ]: - correct_outlen += 1 # start_logits and end_logits instead of only 1 output - if "past_key_values" in outputs: - correct_outlen += 1 # past_key_values have been returned - - self.assertEqual(out_len, correct_outlen) - - # decoder attentions - decoder_attentions = outputs.decoder_attentions - self.assertIsInstance(decoder_attentions, (list, tuple)) - self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(decoder_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], - ) + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + # Question Answering model returns start_logits and end_logits + if model_class.__name__ in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), + *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES), + ]: + correct_outlen += 1 # start_logits and end_logits instead of only 1 output + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been 
returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) - # cross attentions - cross_attentions = outputs.cross_attentions - self.assertIsInstance(cross_attentions, (list, tuple)) - self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(cross_attentions[0].shape[-3:]), - [ - self.model_tester.num_attention_heads, - decoder_seq_length, - encoder_key_length, - ], - ) + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) # Check attention is always last and order is fine inputs_dict["output_attentions"] = True @@ -1006,30 +1007,31 @@ def test_attention_outputs(self): model = model_class(config) model.to(torch_device) model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + with self.subTest(model_class.__name__): + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - self.assertEqual(out_len + added_hidden_states, len(outputs)) + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) @slow def test_torchscript_simple(self): From 2016bc47d0a6f104d768d1c4cc1d7ab8dd2ca8e5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 10:18:38 +0100 Subject: [PATCH 34/40] default init weights --- src/transformers/modeling_utils.py | 10 +++++++++- src/transformers/models/llama/modeling_llama.py | 11 
----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b9e29e35dac127..a4f1bb25f0a8e3 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1887,7 +1887,15 @@ def _init_weights(self, module): using `from_pretrained`. Any attempt to initialize outside of this function will be useless as the torch.nn.init function are all replaced with skip. """ - pass + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() def _initialize_weights(self, module): """ diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f87ff3650a4a2b..8de23ba7146625 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -334,17 +334,6 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_static_cache = True gradient_checkpointing = False - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - class LlamaModel(LlamaPreTrainedModel): _input_embedding = "embed_tokens" From 4f36712da1e805e02f03da17e47ac93ca2eecd60 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 10:24:11 +0100 Subject: [PATCH 35/40] nit? --- src/transformers/modeling_utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a4f1bb25f0a8e3..3dcd54303a8294 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1356,7 +1356,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _output_embedding = None _input_embedding = None - gradient_checkpointing = False @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: @@ -2576,6 +2575,13 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): if isinstance(layer, GradientCheckpointLayer): layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + @property + def gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + for layer in list(self.modules()): + if isinstance(layer, GradientCheckpointLayer): + return layer.gradient_checkpointing + return False + @property def is_gradient_checkpointing(self) -> bool: """ @@ -5537,6 +5543,11 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): class GradientCheckpointLayer(torch.nn.Module): + + def __init__(self, *args, **kwargs): + self.gradient_checkpointing = False + super().__init__( *args, **kwargs) + def __call__(self, *args, **kwargs): """ Adjust the behavior of the inherited class by overriding `__call__`. 
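
Patch 34 above removes the Llama-specific `_init_weights` and makes the same logic the default on `PreTrainedModel`. A small self-contained sketch of that default behaviour, simplified and assuming `initializer_range = 0.02` instead of reading it from a config:

    # Default init sketch: Linear/Embedding weights ~ N(0, initializer_range), biases and
    # the embedding padding row zeroed. Not the transformers code itself.
    import torch
    import torch.nn as nn

    def default_init_weights(module, std=0.02):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    model = nn.Sequential(nn.Embedding(10, 4, padding_idx=0), nn.Linear(4, 4))
    model.apply(default_init_weights)
    assert torch.all(model[0].weight[0] == 0)  # the padding row is zeroed
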
From 9461039d87f9060f42997602334ebb97cd4bd539 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 10:36:49 +0100 Subject: [PATCH 36/40] nits --- src/transformers/models/llama/modeling_llama.py | 1 - src/transformers/utils/generic.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8de23ba7146625..b00c48766e3255 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -332,7 +332,6 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - gradient_checkpointing = False class LlamaModel(LlamaPreTrainedModel): diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 2d7279569c0f2d..b462c5bb16b119 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -30,7 +30,6 @@ import torch from packaging import version -from ..modeling_flash_attention_utils import FlashAttentionKwargs from .import_utils import ( get_torch_version, is_flax_available, @@ -871,7 +870,7 @@ class LossKwargs(TypedDict, total=False): num_items_in_batch: Optional[int] -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): +class KwargsForCausalLM(LossKwargs): input_ids: torch.LongTensor = None attention_mask: Optional[torch.Tensor] = None position_ids: Optional[torch.LongTensor] = None From 584b4430964e24951397e4ccc5c3812a92e79218 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 10:43:13 +0100 Subject: [PATCH 37/40] fix unpack imoprt --- src/transformers/models/auto/modeling_task.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 0cae15b604a75f..61d3e78511e494 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -1,4 +1,6 @@ -from typing import List, Optional, Tuple, Union, Unpack +from typing import List, Optional, Tuple, Union + +from ...processing_utils import Unpack import torch import torch.nn as nn From 95cb944ee686f78cbd3fc4c8a70dcdd99bd2627f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 11:33:37 +0100 Subject: [PATCH 38/40] be permissive --- src/transformers/modeling_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 272b4af5af0128..b31d3b8d4a0646 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1493,7 +1493,10 @@ def _autoset_attn_implementation( message += ( ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' ) - raise ValueError(message + ".") + if config._attn_implementation in ALL_ATTENTION_FUNCTIONS: + pass + else: + raise ValueError(message + ".") # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. 
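
The permissive branch added in patch 38 only skips the error when `config._attn_implementation` is a key of `ALL_ATTENTION_FUNCTIONS`, which is what lets externally registered kernels through. A hypothetical usage sketch — the registry pattern comes from this series, but the standalone dict and the `"my_sdpa"` kernel are made up for illustration:

    # Hypothetical registration of a custom attention callable; the dict stands in for
    # transformers' ALL_ATTENTION_FUNCTIONS and the kernel just wraps torch SDPA.
    import torch

    ALL_ATTENTION_FUNCTIONS = {}

    def my_sdpa(module, query, key, value, attention_mask=None, **kwargs):
        # query/key/value arrive as (batch, num_heads, seq_len, head_dim); this toy kernel
        # assumes key/value already have the same number of heads as query (no GQA repeat).
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )
        # hand back (batch, seq_len, num_heads, head_dim) plus no attention weights
        return attn_output.transpose(1, 2).contiguous(), None

    ALL_ATTENTION_FUNCTIONS["my_sdpa"] = my_sdpa
    # With the relaxed check, config._attn_implementation = "my_sdpa" is accepted because the
    # key exists in the registry, and LlamaAttention dispatches to it at forward time.
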
requested_attn_implementation = config._attn_implementation_internal From caaa5e550845b569840fdf9d8cf96c209fbe5d89 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 12 Dec 2024 18:29:26 +0000 Subject: [PATCH 39/40] tgi update --- src/transformers/integrations/flash_attention.py | 13 +++++++++---- src/transformers/models/auto/modeling_task.py | 2 ++ src/transformers/models/llama/modeling_llama.py | 1 + 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 2189baef8158b6..06d7072408800d 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -4,15 +4,20 @@ def flash_attention_forward( - config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs + config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs ): - if attentions_mask is not None: - seq_len = attentions_mask.shape[1] + if attention_mask is not None: + seq_len = attention_mask.shape[1] query = query[:, :, :seq_len] value = value[:, :, :seq_len] else: seq_len = query.shape[1] + # Re-transpose them + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + dropout_rate = config.attention_dropout if training else 0.0 input_dtype = query.dtype @@ -25,7 +30,7 @@ def flash_attention_forward( query, key, value, - attentions_mask, + attention_mask, seq_len, config=config, dropout=dropout_rate, diff --git a/src/transformers/models/auto/modeling_task.py b/src/transformers/models/auto/modeling_task.py index 61d3e78511e494..322a65e76f8371 100644 --- a/src/transformers/models/auto/modeling_task.py +++ b/src/transformers/models/auto/modeling_task.py @@ -25,6 +25,8 @@ class AutoForCausalLM(PreTrainedModel, GenerationMixin): _no_split_modules = [] _supports_cache_class = True config_class = AutoConfig + _supports_flash_attn_2 = True + _supports_sdpa = True def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b00c48766e3255..0b9897c6116fb6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -269,6 +269,7 @@ def forward( if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + kwargs["layer_idx"] = self.layer_idx attn_output, attn_weights = attention_interface( self, query_states, From 5060a334de16361a9f0da117c9e39162fbcf94c2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 13 Dec 2024 14:07:01 +0000 Subject: [PATCH 40/40] remove layer_idx --- src/transformers/models/llama/modeling_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0b9897c6116fb6..b00c48766e3255 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -269,7 +269,6 @@ def forward( if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - kwargs["layer_idx"] = self.layer_idx attn_output, attn_weights = attention_interface( self, query_states,