From 4bff54f9214e2156e30981eddf7b32cb4e655cba Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Tue, 19 Nov 2024 13:52:38 +0100
Subject: [PATCH] Gemma capping (#34282)

* softcapping
* soft cap before the mask
* style
* ...
* super nit
* update
* fixes
* update
* small issue with modular
* fix modular imports
* update
* fixup
* simplify a hell lot
* simplify cleaning imports
* finish fixing
* update our design
* nits
* use a deprecation cycle
* updates
* Fix modular (recursive deps need to always be computed after merges!)
* push
* fix
* update
* fix modular order
* make fix-copies
* updates
* update
* ?
* don't compile for now
* ?
* fix some stuff
* donc!
* fix copies
* update
* fixup
* ?
* fix two tests
* fix?
* for now, don't use head info
* eager when output attentoin and sdpa or flash as it's the simplest behaviour (for our tests as well :))
* fix-copies
* revert sdpa check
* Apply suggestions from code review

Co-authored-by: Cyril Vallez

* rebase, fix-copies and push
* add a slow integration test
* update the test
* fix left padding issue
* fix test
* remove duplicate scaling
* quality
* add a small test and make sure it works
* 2b

---------

Co-authored-by: Cyril Vallez
Co-authored-by: Cyril Vallez
---
 src/transformers/modeling_utils.py          |   1 +
 .../models/gemma2/modeling_gemma2.py        | 414 +++++++----------
 .../models/gemma2/modular_gemma2.py         | 418 ++++++++----------
 src/transformers/utils/__init__.py          |   1 +
 src/transformers/utils/import_utils.py      |   8 +
 tests/generation/test_utils.py              |   2 +-
 tests/models/gemma2/test_modeling_gemma2.py |  60 ++-
 utils/modular_model_converter.py            |  58 +--
 8 files changed, 426 insertions(+), 536 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7672df0b9a0e46..d68166d5268404 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1519,6 +1519,7 @@ def _autoset_attn_implementation(
                 "eager",
                 "sdpa",
                 "flash_attention_2",
+                "flex_attention",
             ]:
                 message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported.
The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c439ec069f7a93..6111261830b8f0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -41,7 +41,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, replace_return_docstrings, ) @@ -51,6 +51,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention logger = logging.get_logger(__name__) @@ -168,6 +170,127 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. 
We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None + + +GEMMA2_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -175,12 +298,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -192,7 +309,8 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.rope_theta = config.rope_theta self.is_causal = True self.scaling = config.query_pre_attn_scalar**-0.5 - + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.attn_logit_softcapping = config.attn_logit_softcapping if self.hidden_size % self.num_heads != 0: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -208,7 +326,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -243,33 +360,17 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "eager" + else: + attention_type = self.config._attn_implementation - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions + ) - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -279,230 +380,36 @@ def forward( class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
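
The net effect of the refactor above is that each attention backend becomes a plain module-level function keyed by `config._attn_implementation`, and `Gemma2Attention.forward` falls back to eager whenever the caller asks for attention weights that fused kernels cannot return. A minimal sketch of that dispatch pattern, with placeholder lambdas standing in for the real attention functions (illustrative only, not code from this patch):

    ATTENTION_FUNCTIONS = {
        "eager": lambda *args, **kwargs: ("attn_output", "attn_weights"),
        "sdpa": lambda *args, **kwargs: ("attn_output", None),
    }

    def select_backend(attn_implementation: str, output_attentions: bool) -> str:
        # Fused kernels cannot expose per-head weights, so fall back to eager.
        if output_attentions and attn_implementation in ("sdpa", "flash_attention_2"):
            return "eager"
        return attn_implementation

    backend = select_backend("sdpa", output_attentions=True)
    print(backend, ATTENTION_FUNCTIONS[backend]())  # eager ('attn_output', 'attn_weights')
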
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! 
It will be removed in v4.48" ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -GEMMA2_ATTENTION_CLASSES = { - "eager": Gemma2Attention, - "flash_attention_2": Gemma2FlashAttention2, - "sdpa": Gemma2SdpaAttention, -} - class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.config = config + self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.config = config - self.is_sliding = not bool(layer_idx % 2) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -517,25 +424,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor if self.config._attn_implementation == "flash_attention_2": diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index ff2d42d671c3c4..8d86238632365f 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,18 +29,17 @@ from ...utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, ) from ..gemma.modeling_gemma import ( - GemmaAttention, - GemmaDecoderLayer, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, GemmaRMSNorm, + GemmaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) @@ -49,6 +48,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + _CHECKPOINT_FOR_DOC = "google/gemma2-7b" @@ -207,118 +209,183 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -class Gemma2Attention(GemmaAttention): +class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): + pass + + +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. 
We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
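
For intuition, the `attn_logit_softcapping` applied in `eager_attention_forward` and in the flex `score_mod` above is a smooth clamp: logits are divided by the cap, squashed by tanh, then rescaled, so no logit can exceed the cap in magnitude before the softmax. A standalone illustration (the cap of 50.0 is only an example value, not read from any config):

    import torch

    def soft_cap(scores: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
        # tanh bounds scores / cap to (-1, 1); multiplying back by cap bounds the
        # result to (-cap, cap) while staying roughly linear for small logits.
        return cap * torch.tanh(scores / cap)

    logits = torch.tensor([-500.0, -30.0, 0.0, 30.0, 500.0])
    print(soft_cap(logits))  # approximately tensor([-50.00, -26.85, 0.00, 26.85, 50.00])
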
+ is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None + + +GEMMA2_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = 
nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + self.attn_logit_softcapping = config.attn_logit_softcapping + if self.hidden_size % self.num_heads != 0: raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." ) - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
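
The new `Gemma2Attention.__init__` above keeps Gemma 2's alternating attention pattern in a single class: even layer indices get a local sliding window, odd layer indices attend globally. A small illustration of that assignment (the window size 4096 is used only as an example):

    def layer_sliding_windows(num_layers: int, sliding_window: int = 4096):
        # Mirrors: self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
        return {idx: sliding_window if not bool(idx % 2) else None for idx in range(num_layers)}

    print(layer_sliding_windows(6))  # {0: 4096, 1: None, 2: 4096, 3: None, 4: 4096, 5: None}
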
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = Gemma2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -336,57 +403,14 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "eager" + else: + attention_type = self.config._attn_implementation - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -398,105 +422,37 @@ def forward( return attn_output, attn_weights, past_key_value -class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, +class Gemma2FlashAttention2(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value +class Gemma2SdpaAttention(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! 
It will be removed in v4.48" + ) -class Gemma2DecoderLayer(GemmaDecoderLayer): +class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): - super().__init__(config, layer_idx) + super().__init__() + self.hidden_size = config.hidden_size self.config = config self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2a10bcaa3c9412..492642d61babb5 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -209,6 +209,7 @@ is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_fx_proxy, + is_torch_greater_or_equal, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 173aee9b1ac739..6306efa2face19 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str): return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) +@lru_cache() +def is_torch_greater_or_equal(library_version: str): + if not _is_package_available("torch"): + return False + + return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) + + def is_torchdistx_available(): return _torchdistx_available diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6630fc2ba9d152..b1d0042c65167f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1496,7 +1496,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] # They should result in very similar logits - self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5)) + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5) @pytest.mark.generate def test_past_key_values_format(self): diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 7bca83f96d73ab..06116c4dbafbd4 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -199,19 +199,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l def test_sdpa_equivalence(self): pass - def test_eager_attention_loaded_by_default(self): - """Gemma 2 + SDPA = inferior results, because of the logit softcapping. 
Eager is the default.""" - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - # Usually we enable SDPA by default, but not for Gemma2 - model = Gemma2Model(config) - self.assertTrue(model.config._attn_implementation == "eager") - - # We can still force SDPA - config._attn_implementation = "sdpa" - model = Gemma2Model(config) - self.assertTrue(model.config._attn_implementation == "sdpa") - @slow @require_torch_gpu @@ -277,9 +264,30 @@ def test_model_9b_pipeline_bf16(self): "Hi today I'm going to be talking about the history of the United States. The United States of America", ] - model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( - torch_device - ) + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_id) + pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) + + output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True) + + self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0]) + self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1]) + + @require_read_token + def test_model_2b_pipeline_bf16_flex_attention(self): + # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR + model_id = "google/gemma-2-2b" + # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1960s and I am trying to find out what the average", + "Hi today I'm going to be talking about the 10 best anime of all time.\n\n1", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) @@ -365,3 +373,23 @@ def test_export_static_cache(self): ) ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + @require_read_token + def test_model_9b_bf16_flex_attention(self): + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. 
The United States of America", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + + self.assertEqual(output_text, EXPECTED_TEXTS) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index b1dfa18a7a9c84..e5f6e34ece0e08 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -153,9 +153,9 @@ def __init__(self, all_bases: Set[str]): def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: # Handle ClassB.call_to_method if ( - isinstance(original_node.value, cst.Name) + m.matches(original_node.value, m.Name()) and original_node.value.value in self.all_bases - and isinstance(original_node.attr, cst.Name) + and m.matches(original_node.attr, m.Name()) ): # Replace with super().call_to_method return updated_node.with_changes( @@ -163,10 +163,10 @@ def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attrib ) # Handle ClassB().call_to_method elif ( - isinstance(original_node.value, cst.Call) - and isinstance(original_node.value.func, cst.Name) + m.matches(original_node.value, m.Call()) + and m.matches(original_node.value.func, m.Name()) and original_node.value.func.value in self.all_bases - and isinstance(original_node.attr, cst.Name) + and m.matches(original_node.attr, m.Name()) ): # Replace with super().call_to_method return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super")))) @@ -174,16 +174,16 @@ def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attrib def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: # Check if the function being called is of the form ClassB().func_a or ClassB.func_a - if isinstance(original_node.func, cst.Attribute) and ( + if m.matches(original_node.func, m.Attribute()) and ( # Match ClassB().func_a(...) ( - isinstance(original_node.func.value, cst.Call) - and isinstance(original_node.func.value.func, cst.Name) + m.matches(original_node.func.value, m.Call()) + and m.matches(original_node.func.value.func, m.Name()) and original_node.func.value.func.value in self.all_bases ) or # Match ClassB.func_a(...) - (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases) + (m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases) ): # Check if the first argument is 'self', and remove it if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")): @@ -632,8 +632,10 @@ def leave_Module(self, node): for id, node in self.global_nodes.items(): self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line - # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that - # are not part of the recorded objects (i.e. built-in variables, imports, etc) + def _restrict_dependencies_to_known_entities(self): + """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that + are not part of the recorded objects in `self.global_nodes` (i.e. 
built-in variables, imports, etc). + This should be called only after all merging operations have been finalized!!""" global_objects = set(self.global_nodes.keys()) for object_name, dependencies in self.object_dependency_mapping.items(): self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} @@ -814,6 +816,8 @@ def merge_modular_dependencies(self, classes, functions, assignments, object_map # Correctly re-set the global nodes at this point self.global_nodes.update(self.functions) self.global_nodes.update(self.assignments) + # Restrict the dependency mappings to the know entities to avoid Python's built-ins + self._restrict_dependencies_to_known_entities() # Create the global mapping of recursive dependencies for functions and assignments self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() @@ -1142,22 +1146,20 @@ def visit_SimpleStatementLine(self, node): if assigned_variable == "__all__": self.all_all_to_add = split_all_assignment(node) else: + self.current_assignment = assigned_variable self.assignments[assigned_variable] = node def leave_Module(self, node): """When we leave the modular file, we do the following in order: - 1. compute the nested (recursive) function and assignment dependencies - 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update + 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update its dependency graph with the new function and assignment definitions found in the modular - 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 3. compute the nested (recursive) function and assignment dependencies """ # Takes care of finalizing our visit super().leave_Module(node) - # 1. compute the nested (recursive) function and assignment dependencies - self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() - - # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies + # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} self.renamers = {} for file, module in self.model_specific_modules.items(): @@ -1177,10 +1179,13 @@ def leave_Module(self, node): # We record it so that we can rename classes later the exact same way self.renamers[file] = renamer - # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the # definitions found in the visited files self.merge_model_specific_imports(self.visited_modules) + # 3. 
compute the nested (recursive) function and assignment dependencies + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later # Note that we may visit several of the same file types, thus we save them per file type, not file self.imported_objects_per_file = defaultdict(set) @@ -1200,9 +1205,9 @@ def merge_model_specific_imports(self, visited_modules): if object_name in visited_module.functions and object_name not in self.functions: self.functions[object_name] = visited_module.functions[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_dependency_mapping.get(object_name, None) if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies + self.object_dependency_mapping[object_name] = dependencies for dep in dependencies: if dep not in self.global_nodes: self.added_objects_file_mapping[dep] = file @@ -1212,9 +1217,9 @@ def merge_model_specific_imports(self, visited_modules): elif object_name in visited_module.assignments and object_name not in self.assignments: self.assignments[object_name] = visited_module.assignments[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_dependency_mapping.get(object_name, None) if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies + self.object_dependency_mapping[object_name] = dependencies for dep in dependencies: if dep not in self.global_nodes: self.added_objects_file_mapping[dep] = file @@ -1222,6 +1227,8 @@ def merge_model_specific_imports(self, visited_modules): # Do not forget to re-assign all nodes after the merge self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # And restric dependencies to those nodes only + self._restrict_dependencies_to_known_entities() def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that @@ -1239,10 +1246,11 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: else: original_dependencies.append(dep) # Sort all lists according to the order in their respective file - all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + all_dependencies = [] for file, dependencies in other_files_dependencies.items(): sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) all_dependencies += sorted_dependencies + all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x]) # Add all original node first, then merged ones (one file at a time) for dep in all_dependencies: @@ -1485,7 +1493,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/gemma2/modular_gemma2.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", )
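
As exercised by the new integration tests, the user-facing entry point for all of this is the `attn_implementation` argument at load time. A usage sketch along the lines of those tests (it assumes torch>=2.5, a CUDA device, and access to the gated google/gemma-2-2b checkpoint; the exact generations are not guaranteed):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "google/gemma-2-2b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation="flex_attention",  # newly accepted alongside eager/sdpa/flash_attention_2
    ).to("cuda")

    inputs = tokenizer("Hello I am doing", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.batch_decode(output, skip_special_tokens=True))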