diff --git a/examples/diff-conversion/README.md b/examples/diff-conversion/README.md new file mode 100644 index 00000000000000..a575a83b015c63 --- /dev/null +++ b/examples/diff-conversion/README.md @@ -0,0 +1,20 @@ +# Using the `diff_converter` linter + +`pip install libcst` is a must! + +# `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs + +The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular `diff` file like `diff_gemma.py` into a `single model single file`. + +Examples of possible usage are available in the `examples/diff-conversion`, or `diff_gemma` for a full model usage. + +`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"` + +## How it works +We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies. + +The code from the `diff` file and the class dependency mapping are "merged" to produce the single model single file. +We use ruff to automatically remove the potential duplicate imports. + +## Why we use libcst instead of the native AST? +AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst` \ No newline at end of file diff --git a/examples/diff-conversion/convert_examples.sh b/examples/diff-conversion/convert_examples.sh new file mode 100644 index 00000000000000..1cfdc3e33cdf82 --- /dev/null +++ b/examples/diff-conversion/convert_examples.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Iterate over each file in the current directory +for file in examples/diff-conversion/diff_*; do + # Check if it's a regular file + if [ -f "$file" ]; then + # Call the Python script with the file name as an argument + python utils/diff_model_converter.py --files_to_parse "$file" + fi +done \ No newline at end of file diff --git a/examples/diff-conversion/diff_dummy.py b/examples/diff-conversion/diff_dummy.py new file mode 100644 index 00000000000000..c5fd57f9f66eb5 --- /dev/null +++ b/examples/diff-conversion/diff_dummy.py @@ -0,0 +1,44 @@ +from math import log +from typing import List, Optional, Tuple, Union + +import torch + +from transformers import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaModel + + +def _pre_process_input(input_ids): + print(log(input_ids)) + return input_ids + + +# example where we need some deps and some functions +class DummyModel(LlamaModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + input_ids = _pre_process_input(input_ids) + + return super().forward( + None, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + 
return_dict, + cache_position, + ) diff --git a/examples/diff-conversion/diff_my_new_model.py b/examples/diff-conversion/diff_my_new_model.py new file mode 100644 index 00000000000000..dddcc1d61c11d6 --- /dev/null +++ b/examples/diff-conversion/diff_my_new_model.py @@ -0,0 +1,14 @@ +from transformers.models.llama.configuration_llama import LlamaConfig + + +# Example where we only want to only add a new config argument and new arg doc +# here there is no `ARG` so we are gonna take parent doc +class MyNewModelConfig(LlamaConfig): + r""" + mlp_bias (`bool`, *optional*, defaults to `False`) + """ + + def __init__(self, mlp_bias=True, new_param=0, **super_kwargs): + self.mlp_bias = mlp_bias + self.new_param = new_param + super().__init__(self, **super_kwargs) diff --git a/examples/diff-conversion/diff_my_new_model2.py b/examples/diff-conversion/diff_my_new_model2.py new file mode 100644 index 00000000000000..2e449e06b16225 --- /dev/null +++ b/examples/diff-conversion/diff_my_new_model2.py @@ -0,0 +1,31 @@ +from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification +from transformers.models.llama.configuration_llama import LlamaConfig + + +# Example where we only want to only modify the docstring +class MyNewModel2Config(LlamaConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + +# Example where alllllll the dependencies are fetched to just copy the entire class +class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification): + pass diff --git a/examples/diff-conversion/diff_new_model.py b/examples/diff-conversion/diff_new_model.py new file mode 100644 index 00000000000000..1486d40c6cdbd5 --- /dev/null +++ b/examples/diff-conversion/diff_new_model.py @@ -0,0 +1,30 @@ +# Example where we only want to overwrite the defaults of an init + +from transformers.models.gemma.configuration_gemma import GemmaConfig + + +class NewModelConfig(GemmaConfig): + def __init__( + self, + vocab_size=256030, + hidden_size=64, + intermediate_size=90, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=1500, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + ): + super().__init__(self) diff --git a/examples/diff-conversion/diff_super.py b/examples/diff-conversion/diff_super.py new file mode 100644 index 00000000000000..160f067ee01b85 --- /dev/null +++ b/examples/diff-conversion/diff_super.py @@ -0,0 +1,38 @@ +from typing import List, Optional, Tuple, Union + +import torch + +from transformers import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaModel + + +# example where we need some deps and some functions +class SuperModel(LlamaModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + out = super().forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, + ) + out.logits *= 2**4 + return out diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 3bf296a63b22fc..6d2418ee1c31cc 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -1,5 +1,12 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +19,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Gemma model configuration""" - -from ...configuration_utils import PretrainedConfig -from ...utils import logging -logger = logging.get_logger(__name__) +from transformers import PretrainedConfig class GemmaConfig(PretrainedConfig): @@ -26,13 +29,9 @@ class GemmaConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma-7B. - e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - - Args: vocab_size (`int`, *optional*, defaults to 256000): Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the @@ -83,16 +82,12 @@ class GemmaConfig(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - ```python >>> from transformers import GemmaModel, GemmaConfig - >>> # Initializing a Gemma gemma-7b style configuration >>> configuration = GemmaConfig() - >>> # Initializing a model from the gemma-7b style configuration >>> model = GemmaModel(configuration) - >>> # Accessing the model configuration >>> configuration = model.config ```""" diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py new file mode 100644 index 00000000000000..1f9645ade6021b --- /dev/null +++ b/src/transformers/models/gemma/diff_gemma.py @@ -0,0 +1,507 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
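+#
+# This `diff_gemma.py` file is the hand-written "diff" from which the converter regenerates the
+# single-file `modeling_gemma.py` and `configuration_gemma.py` (see the auto-generated headers in
+# those files). Illustrative invocation, assuming the command is run from the repository root:
+#   python utils/diff_model_converter.py --files_to_parse src/transformers/models/gemma/diff_gemma.py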
+import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers import PretrainedConfig +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + apply_rotary_pos_emb, + repeat_kv, +) + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...modeling_outputs import CausalLMOutputWithPast +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The legacy activation function. It is overwritten by the `hidden_activation`. + hidden_activation (`str` or `function`, *optional*): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + +ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) + + +class GemmaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, 
dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_activation is None: + logger.warning_once( + "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" + "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" + "`config.hidden_activation` if you want to override this behaviour.\n" + "See https://github.com/huggingface/transformers/pull/29402 for more details." + ) + config.hidden_activation = "gelu_pytorch_tanh" + hidden_activation = config.hidden_activation + self.act_fn = ACT2FN[hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class GemmaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = GemmaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GemmaModel(LlamaModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + 
return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False # noqa: F841 + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + return super().forward( + causal_mask, + position_ids, + past_key_values, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, + input_ids=None, + inputs_embeds=hidden_states, + ) + + +# Example where we ony modify the docstring and call super +class GemmaForCausalLM(LlamaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class GemmaForSequenceClassification(LlamaForSequenceClassification): + pass + + +class GemmaForTokenClassification(LlamaForTokenClassification): + pass diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 474dccf3081d49..ff0f7082e95293 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. # @@ -13,8 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""PyTorch Gemma model.""" - import math from typing import List, Optional, Tuple, Union @@ -26,10 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_causal_attention_mask, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -37,7 +38,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -46,7 +47,6 @@ logging, replace_return_docstrings, ) -from ...utils.import_utils import is_torch_fx_available from .configuration_gemma import GemmaConfig @@ -55,25 +55,14 @@ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "GemmaConfig" - def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -108,7 +97,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @@ -130,7 +118,35 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half +class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): + """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding): + """GemmaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -138,7 +154,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -190,7 +205,6 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -206,7 +220,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class GemmaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # Ignore copy def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config @@ -303,7 +316,6 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma class GemmaFlashAttention2(GemmaAttention): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays @@ -319,7 +331,6 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - # Ignore copy def forward( self, hidden_states: torch.Tensor, @@ -329,13 +340,13 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -351,8 +362,8 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -397,7 +408,7 @@ def forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -503,7 +514,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma class GemmaSdpaAttention(GemmaAttention): """ Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -511,7 +521,7 @@ class GemmaSdpaAttention(GemmaAttention): SDPA API. 
""" - # Ignore copy + # Adapted from GemmaAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -548,8 +558,8 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -584,7 +594,7 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -598,7 +608,6 @@ def forward( } -# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma class GemmaDecoderLayer(nn.Module): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() @@ -692,9 +701,8 @@ class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"] _no_split_modules = ["GemmaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -713,6 +721,9 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +_CONFIG_FOR_DOC = "GemmaConfig" + + GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -821,7 +832,6 @@ def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) - # Ignore copy def forward( self, input_ids: torch.LongTensor = None, @@ -989,6 +999,8 @@ def _update_causal_mask( if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( @@ -1020,7 +1032,6 @@ def _update_causal_mask( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma class GemmaForCausalLM(GemmaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1051,7 +1062,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - # Ignore copy @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1244,7 +1254,6 @@ def _reorder_cache(past_key_values, beam_idx): """, GEMMA_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma class GemmaForSequenceClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1360,7 +1369,6 @@ def forward( """, GEMMA_START_DOCSTRING, ) -# 
Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 226d14c18b991c..836528ee2104e7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,8 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch LLaMA model.""" - import math from typing import List, Optional, Tuple, Union diff --git a/utils/check_copies.py b/utils/check_copies.py index c4fa2fbaa0ca3d..b50f5845886b0b 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -559,8 +559,11 @@ def get_indent(code: str) -> str: return "" -def run_ruff(code): - command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] +def run_ruff(code, check=False): + if check: + command = ["ruff", "check", "-", "--fix", "--exit-zero"] + else: + command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) stdout, _ = process.communicate(input=code.encode()) return stdout.decode() diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py new file mode 100644 index 00000000000000..d9786a9b3c49fb --- /dev/null +++ b/utils/diff_model_converter.py @@ -0,0 +1,555 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import glob +import importlib +import re +from typing import Dict + +import libcst as cst +from check_copies import run_ruff +from libcst import ClassDef, CSTTransformer, CSTVisitor +from libcst import matchers as m +from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider + +from transformers import logging + + +logger = logging.get_logger(__name__) + + +AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +""" + + +def get_module_source_from_name(module_name: str) -> str: + # Extract the source code from the module name + spec = importlib.util.find_spec(module_name) + if spec is None or spec.origin is None: + return f"Module {module_name} not found" + + with open(spec.origin, "r") as file: + source_code = file.read() + return source_code + + +class ClassFinder(CSTVisitor): + """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions. + For example if the visited code has + ```python3 + def init_value(): return 1 + + class LlamaModel(PreTrainedModel): + def __init__(self): + super().__init__(self) + self.value = init_value() + ``` + then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]} + + The dependency mapping is updated via the `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, and by + checking the parent node, or the scope of a `cst.Name` or `cst.Arg` or `cst.Decorator` we are able to map the + dependence parent -> child. + + When visiting such nodes, we update the dependency of the parent node, to take into account the visited node. + + All `visit_XXX` correspond to the code executed when vising the cst.Node of type XXX. + """ + + METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) + + def __init__(self, python_module: cst.Module): + # fmt: off + self.python_module: cst.Module = python_module # original cst.Module being visited + self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node + self.imports = {} # stores all import statements + self.function_def = {} # stores global scope function definition + self.assignments = {} # LLAMA_DOCSTRING + self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] + # fmt: on + + def _update_class_dependency(self, name, value): + """Update the dependency mapping for `name` with `value` by appending the previous + dependencies to the new `value`. + """ + dep = set(self.class_dependency_mapping.get(value, set())) + dep |= set(self.class_dependency_mapping.get(name, {})) | set({value}) + self.class_dependency_mapping[name] = dep + + def visit_ClassDef(self, node: ClassDef) -> None: + """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies""" + self.classes[node.name.value] = node + for k in node.bases: # deal with inheritance + base_name = self.python_module.code_for_node(k) + self._update_class_dependency(node.name.value, base_name) + + def visit_SimpleStatementLine(self, node): + """ + Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements + are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. 
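+
+        Illustrative example (hypothetical statements): visiting
+            GEMMA_INPUTS_DOCSTRING = r"..."
+        records the node under `self.assignments["GEMMA_INPUTS_DOCSTRING"]`, while visiting
+            from ...utils import logging
+        records the node in `self.imports`.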
+ """ + if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( + self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() + ): + self.assignments[node.body[0].targets[0].target.value] = node + if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + self.imports[node.body[0].names] = node + + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.function_def[node.name.value] = node + + def leave_If(self, node): + for stmt in node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + self.imports[stmt.body[0].names] = node + + def leave_Name(self, node): + if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys(): + parent = self.get_metadata(cst.metadata.ScopeProvider, node) + if not isinstance(parent, cst.metadata.scope_provider.GlobalScope): + self._update_class_dependency(parent._name_prefix.split(".")[0], node.value) + + def leave_Arg(self, node): + if m.matches(node.value, m.Name()): + parent = self.get_metadata(ParentNodeProvider, node) + if m.matches(parent, m.ClassDef()) and parent.bases: + self._update_class_dependency(parent.name.value, node.value.value) + + def leave_Dict(self, node): + parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent, m.Assign(targets=[m.AssignTarget()])): + name = parent.targets[0].target.value + if name in self.assignments: + for k in node.elements: + dep_name = k.value.value + if dep_name in self.classes: + self._update_class_dependency(name, dep_name) + + def leave_Decorator(self, node): + if hasattr(node.decorator, "args"): + for k in node.decorator.args: + if k.value.value in self.assignments: + parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) + scope = self.get_metadata(cst.metadata.ScopeProvider, node) + name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value + self._update_class_dependency(name, k.value.value) + + def leave_Module(self, node): + """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def) + to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this. + """ + self.global_nodes = {**self.assignments, **self.classes, **self.function_def} + # now sort the class dependency_mapping based on the position of the nodes + self.class_start_line = {} + for id, node in self.global_nodes.items(): + self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + + +class ReplaceNameTransformer(m.MatcherDecoratableTransformer): + """A transformer that replaces `old_name` with `new_name` in comments, string and any references. + It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING. 
+ Supported renaming patterns: + - llama -> my_new_model and my_new_model -> llama + - Llama -> MyNewModel and MyNewModel -> Llama + - LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA + - LLaMa -> MyNewModel abd MyNewModel -> Llama + """ + + def __init__(self, old_name, new_name): + super().__init__() + self.old_name = old_name + self.new_name = new_name + self.default_name = "".join(x.title() for x in new_name.split("_")) + self.patterns = { + old_name: new_name, + old_name.upper(): new_name.upper(), + "".join(x.title() for x in old_name.split("_")): self.default_name, + } + + def preserve_case_replace(self, text): + # Create a regex pattern to match all variations + regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) + compiled_regex = re.compile(regex_pattern, re.IGNORECASE) + + def replace(match): + word = match.group(0) + return self.patterns.get(word, self.default_name) + + return compiled_regex.sub(replace, text) + + @m.leave(m.Name() | m.SimpleString() | m.Comment()) + def replace_name(self, original_node, updated_node): + update = self.preserve_case_replace(updated_node.value) + return updated_node.with_changes(value=update) + + +def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"): + """Helper function to rename and then parse a source file using the ClassFinder""" + transformer = ReplaceNameTransformer(old_id, new_id) + new_module = module.visit(transformer) + + wrapper = MetadataWrapper(new_module) + + class_finder = ClassFinder(new_module) + wrapper.visit(class_finder) + return class_finder + + +DOCSTRING_NODE = m.SimpleStatementLine( + body=[ + m.Expr( + value=m.SimpleString( + # match anything between """ """ + value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None) + ) + ) + ] +) + + +class SuperTransformer(cst.CSTTransformer): + METADATA_DEPENDENCIES = (ParentNodeProvider,) + + def __init__(self, python_module: cst.Module, original_methods, updated_methods): + self.python_module = python_module + self.original_methods = original_methods + self.updated_methods = updated_methods + + def update_body(self, existing_body, new_statements): + """ + Helper method to update the body by removing duplicates before adding new statements. + """ + deduplicated_new_body = [] + existing_nodes = { + self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode) + } + for stmt in existing_body: + if self.python_module.code_for_node(stmt).strip() not in existing_nodes: + if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring: + continue + deduplicated_new_body.append(stmt) + existing_nodes.add(stmt) + else: + logger.info(f"\nFound duplicate {self.python_module.code_for_node(stmt)}") + return deduplicated_new_body + + def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: + """Updates the body of the input `node`'s `func_name` function by replacing calls + to super().func_name() with the source code of the parent class' `func_name`. + It keeps everything that is defined before `super().func_name()`. 
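+
+        Illustrative sketch (hypothetical child method, not taken from a real model):
+            def forward(self, x):
+                x = self._pre_process(x)
+                return super().forward(x)
+        Here the `x = self._pre_process(x)` line is kept, and the `return super().forward(x)`
+        statement is replaced by the body of the parent class' `forward`.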
+ """ + new_body = [] + self.has_docstring = False + for expr in node.body: + self.has_docstring = m.matches(node.body[0], DOCSTRING_NODE) + if m.matches( + expr, + m.SimpleStatementLine( + body=[ + m.Return( + value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) + ) + | m.Expr( + value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) + ) + ] + ), + ): + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) + else: + new_body.append(expr) + return node.with_changes(body=new_body) + + def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: + if updated_node.name.value in self.updated_methods: + name = updated_node.name.value + new_body = self.replace_super_calls(updated_node.body, name) + return updated_node.with_changes(body=new_body, params=updated_node.params) + return updated_node + + def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode: + """ "When a return statement is reached, it is replaced with the unrolled super code""" + if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))): + func_def = self.get_metadata(ParentNodeProvider, original_node) + if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods: + updated_return_value = updated_node.value.with_changes( + args=[ + cst.Arg( + value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))]) + ) + ] + ) + return updated_node.with_changes(value=updated_return_value) + return updated_node + + +def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str): + """ + Given the `class_name`, the `updated_node`'s call to super are unpacked. 
+
+                   | ```python                          |         | ```python
+                   | class GemmaModel(LlamaModel):      |         | class GemmaModel(nn.Module):
+                   |     def __init__(self):            |         |     def __init__(self):
+    Going from:    |         self.dropout = 0.2         |   to:   |         self.dropout = 0.2
+                   |         super().__init__()         |         |         super().__init__(config)
+                   | ```                                |         |         self.padding_idx = config.pad_token_id
+                                                                  |         self.vocab_size = config.vocab_size
+                                                                  |         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+                                                                  |         self.layers = nn.ModuleList(
+                                                                  |             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+                                                                  |         )
+                                                                  |         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+                                                                  |         self.gradient_checkpointing = False
+                                                                  |         # Initialize weights and apply final processing
+                                                                  |         self.post_init()
+                                                                  | ```
+    """
+    original_node = class_finder.classes[class_name]
+    original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
+    updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
+    end_meth = []
+    for name, func in original_methods.items():
+        if name in updated_methods and updated_methods[name] is not None:
+            new_params = updated_methods[name].params
+            # Replace the method in the replacement class, preserving decorators
+            kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
+            if kwarg_name and kwarg_name.name.value == "super_kwargs":
+                parent_params = {k.name.value: k for k in func.params.params}
+                parent_params.update({k.name.value: k for k in new_params.params[1:]})
+                new_params = new_params.with_changes(
+                    params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
+                )
+            func = func.with_changes(body=updated_methods[name].body, params=new_params)
+        end_meth.append(func)
+
+    result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
+    temp_module = cst.Module(body=[result_node])
+    new_module = MetadataWrapper(temp_module)
+    new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
+    new_replacement_body = new_replacement_class.body[0].body  # get the indented block
+    return original_node.with_changes(body=new_replacement_body)
+
+
+class DiffConverterTransformer(CSTTransformer):
+    METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
+
+    def __init__(self, python_module, new_name):
+        super().__init__()
+        self.model_name = (
+            new_name  # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
+        )
+        # fmt: off
+        self.python_module = python_module    # we store the original module to use `code_for_node`
+        self.transformers_imports = {}        # maps the import names like "from transformers.models.xxx" to the parsed AST module
+        self.imported_mapping = {}            # stores the names of the imported classes, with their source {"LlamaModel": "transformers.models.llama.modeling_llama"}
+        self.visited_module = {}              # modules visited like "transformers.models.llama.modeling_llama"
+        self.new_body = {}                    # stores the new body; all global scope nodes should be added here
+        self.inserted_deps = []               # nodes inserted via super dependency
+        self.all_imports = []                 # just stores all of the imports
+        self.global_scope_index = 0
+        # fmt: on
+
+    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
+        """When visiting imports from `transformers.models.xxx`, we need to:
+        1. Get the original source code
+        2. Parse it into an AST tree
+        3. Add this import to `self.transformers_imports` as visited, so we don't parse it twice
+        """
+        import_statement = self.python_module.code_for_node(node.module)
+        if m.matches(node.module, m.Attribute()):
+            for imported_ in node.names:
+                _import = re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement)
+                if _import:
+                    source = _import.groups()[0]
+                    if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
+                        raise ValueError(
+                            f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
+                        )
+                if import_statement not in self.transformers_imports:
+                    source_code = get_module_source_from_name(import_statement)
+                    tree = cst.parse_module(source_code)
+                    self.transformers_imports[import_statement] = tree
+                imported_class = self.python_module.code_for_node(imported_.name)
+                self.imported_mapping[imported_class] = import_statement
+
+    def leave_FunctionDef(self, original_node, node):
+        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
+        if m.matches(parent_node, m.Module()):
+            self.global_scope_index += 100
+            self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
+        return node
+
+    def leave_SimpleStatementLine(self, original_node, updated_node):
+        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
+        if m.matches(parent_node, m.Module()):
+            if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
+                if parent_node not in self.all_imports:
+                    self.all_imports.append(updated_node)
+                return updated_node
+            elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
+                full_statement = self.python_module.code_for_node(updated_node.body[0].module)
+                if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
+                    return cst.RemoveFromParent()
+                if parent_node not in self.all_imports:
+                    self.all_imports.append(updated_node)
+                return updated_node
+            self.global_scope_index += 100
+            if m.matches(updated_node, m.SimpleStatementLine(body=[m.Assign()])):
+                # TODO This only works for single target assigns!
+                node_name = updated_node.body[0].targets[0].target.value
+            else:
+                node_name = self.python_module.code_for_node(updated_node.body[0])
+            self.new_body[node_name] = {
+                "insert_idx": self.global_scope_index,
+                "node": updated_node,
+            }
+        return updated_node
+
+    def leave_ClassDef(self, original_node, updated_node):
+        """
+        1. Filter the `base` classes of this class. If they are from `transformers.models.xx`, then:
+            - take the AST tree of the module it comes from and parse it with a `ClassFinder`
+            - rename every instance of `old_name` (llama) to `new_name` (gemma)
+        2. We insert the modules which the inherited base depends on. This has to be done in
+           the order of the dependencies. If one is already in the new_body (because it's defined in the diff file),
+           we remove it from the new body and add it again in the correct order.
+        3. Replace the calls to `super().xxxx`, merging in the parent code
+        """
+        class_name = original_node.name.value
+        bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
+        self.global_scope_index += 100
+        for super_class in bases:
+            if super_class not in self.imported_mapping:
+                raise ImportError(
+                    f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}`"
+                )
+
+            super_file_name = self.imported_mapping[super_class]  # we need to get the parsed tree
+            model_name = re.search(r"_(\S*)", super_file_name)
+            if model_name:
+                model_name = model_name.groups()[0]
+            else:
+                raise ValueError(
+                    f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
+                )
+
+            if super_file_name not in self.visited_module:  # only extract classes once
+                class_finder = find_classes_in_file(
+                    self.transformers_imports[super_file_name], model_name, self.model_name
+                )
+                self.visited_module[super_file_name] = class_finder
+            else:  # we are re-using the previously parsed data
+                class_finder = self.visited_module[super_file_name]
+
+            list_dependencies = {
+                dep: class_finder.class_start_line.get(dep, 1000)
+                for dep in class_finder.class_dependency_mapping.get(class_name, [])
+            }
+
+            list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
+            start_insert_idx = self.global_scope_index
+            for dependency, _ in list_dependencies:
+                node = class_finder.global_nodes.get(dependency, None)
+                if node is not None:
+                    if dependency not in self.new_body:
+                        start_insert_idx -= 1
+                        self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
+                    elif dependency not in self.inserted_deps:
+                        # make sure the node is written after its dependencies
+                        start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
+                    self.inserted_deps.append(dependency)
+            if len(list_dependencies) > 0:
+                updated_node = replace_call_to_super(class_finder, updated_node, class_name)
+        if "Config" in class_name:
+            self.config_body = [updated_node]
+        else:
+            self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
+        return updated_node
+
+    def leave_If(self, original_node, node):
+        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
+        if m.matches(parent_node, m.Module()):
+            full_statement = self.python_module.code_for_node(original_node.test)
+            if re.search(r"[\s\S]*is_.*available", full_statement):
+                self.all_imports.append(node)
+            elif full_statement not in self.new_body:
+                self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
+        return node
+
+    def leave_Module(self, original_node: cst.Module, node):
+        imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
+        dependency_imports = {}
+        for visitor in self.visited_module.values():
+            dependency_imports.update({self.python_module.code_for_node(k): k for k in visitor.imports.values()})
+        if hasattr(self, "config_body"):
+            self.config_body = list(imports.values()) + self.config_body
+        dependency_imports.update(imports)
+        new_body = list(dependency_imports.values())
+        if len(self.new_body.keys()) > 0:
+            new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
+        else:
+            new_body = []
+        return node.with_changes(body=[*new_body])
+
+
+def convert_file(diff_file, cst_transformers=None):
+    model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
+    # Parse the Python file
+    with open(diff_file, "r") as file:
+        code = file.read()
+    module = cst.parse_module(code)
+    wrapper = MetadataWrapper(module)
+    if cst_transformers is None:
+        cst_transformers = DiffConverterTransformer(module, model_name)
+    new_mod = wrapper.visit(cst_transformers)
+    ruffed_code = run_ruff(new_mod.code, True)
+    formatted_code = run_ruff(ruffed_code, False)
+    if len(formatted_code.strip()) > 0:
+        with open(diff_file.replace("diff_", "modeling_"), "w") as f:
+            f.write(AUTO_GENERATED_MESSAGE + formatted_code)
+
+    if hasattr(cst_transformers, "config_body"):
+        config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
+        with open(diff_file.replace("diff_", "configuration_"), "w") as f:
+            ruffed_code = run_ruff(config_module.code, True)
+            formatted_code = run_ruff(ruffed_code, False)
+            f.write(AUTO_GENERATED_MESSAGE + formatted_code)
+
+    # TODO optimize by re-using the class_finder
+    return cst_transformers
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--files_to_parse",
+        default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
+        nargs="+",
+        help="A list of `diff_xxxx` files that should be converted to a single model file",
+    )
+    args = parser.parse_args()
+    if args.files_to_parse == ["all"]:
+        args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
+    for file_name in args.files_to_parse:
+        print(f"Converting {file_name} to the single model single file format")
+        module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
+        converter = convert_file(file_name)
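+
+    # Note (illustrative): for a given `diff_xxx.py`, `convert_file` writes a `modeling_xxx.py`
+    # next to it, plus a `configuration_xxx.py` when the diff file defines a `*Config` class.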