diff --git a/examples/diff-conversion/diff_dummy.py b/examples/diff-conversion/diff_dummy.py
index b53ed3b8990204..c5fd57f9f66eb5 100644
--- a/examples/diff-conversion/diff_dummy.py
+++ b/examples/diff-conversion/diff_dummy.py
@@ -1,14 +1,18 @@
-from transformers.models.llama.modeling_llama import LlamaModel
-from typing import *
-import torch
 from math import log
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from typing import List, Optional, Tuple, Union
+
+import torch
+
 from transformers import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaModel
+
 
 def _pre_process_input(input_ids):
     print(log(input_ids))
     return input_ids
 
+
 # example where we need some deps and some functions
 class DummyModel(LlamaModel):
     def forward(
diff --git a/examples/diff-conversion/diff_my_new_model.py b/examples/diff-conversion/diff_my_new_model.py
index d86fa8937506ff..3ede01855e5d0d 100644
--- a/examples/diff-conversion/diff_my_new_model.py
+++ b/examples/diff-conversion/diff_my_new_model.py
@@ -1,13 +1,13 @@
 from transformers.models.llama.modeling_llama import LlamaConfig
+
+
 # Example where we only want to only add a new config argument and new arg doc
 # here there is no `ARG` so we are gonna take parent doc
 class MyNewModelConfig(LlamaConfig):
     r"""
-    mlp_bias (`bool`, *optional*, defaults to `False`)
+    mlp_bias (`bool`, *optional*, defaults to `False`)
     """
-    def __init__(
-        self,
-        mlp_bias=False
-    ):
+
+    def __init__(self, mlp_bias=False):
         self.mlp_bias = mlp_bias
-        super().__init__(self)
\ No newline at end of file
+        super().__init__(self)
diff --git a/examples/diff-conversion/diff_my_new_model2.py b/examples/diff-conversion/diff_my_new_model2.py
index a6f9edf0032658..2e449e06b16225 100644
--- a/examples/diff-conversion/diff_my_new_model2.py
+++ b/examples/diff-conversion/diff_my_new_model2.py
@@ -1,5 +1,7 @@
 from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
 from transformers.models.llama.configuration_llama import LlamaConfig
+
+
 # Example where we only want to only modify the docstring
 class MyNewModel2Config(LlamaConfig):
     r"""
@@ -23,6 +25,7 @@ class MyNewModel2Config(LlamaConfig):
     >>> configuration = model.config
     ```"""
+
 
 # Example where alllllll the dependencies are fetched to just copy the entire class
 class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
     pass
diff --git a/examples/diff-conversion/diff_new_model.py b/examples/diff-conversion/diff_new_model.py
index bf04fae1e3d289..1486d40c6cdbd5 100644
--- a/examples/diff-conversion/diff_new_model.py
+++ b/examples/diff-conversion/diff_new_model.py
@@ -1,6 +1,8 @@
 # Example where we only want to overwrite the defaults of an init
 from transformers.models.gemma.configuration_gemma import GemmaConfig
+
+
 class NewModelConfig(GemmaConfig):
     def __init__(
         self,
@@ -25,4 +27,4 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
     ):
-        super().__init__(self)
\ No newline at end of file
+        super().__init__(self)
diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py
index 8b766a60c2c3a2..6d2418ee1c31cc 100644
--- a/src/transformers/models/gemma/configuration_gemma.py
+++ b/src/transformers/models/gemma/configuration_gemma.py
@@ -1,7 +1,7 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # This file was automatically generated from .
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
+# the file from the diff. If any change should be done, please apply the change to the
 # diff.py file directly.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
@@ -21,7 +21,6 @@
 # limitations under the License.
 
 
-
 from transformers import PretrainedConfig
diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py
index d972c42f598782..1cc0f3d46d348e 100644
--- a/src/transformers/models/gemma/diff_gemma.py
+++ b/src/transformers/models/gemma/diff_gemma.py
@@ -162,6 +162,7 @@ def __init__(
             **kwargs,
         )
 
+
 class GemmaRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index bf79938bae7285..ba91aebcab41ad 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -1,7 +1,7 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # This file was automatically generated from .
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
+# the file from the diff. If any change should be done, please apply the change to the
 # diff.py file directly.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
@@ -24,29 +24,22 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-
-from transformers import PretrainedConfig
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache
-from ...modeling_outputs import CausalLMOutputWithPast
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import logging
-import math
-import torch.nn.functional as F
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -81,6 +74,7 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
+
 class GemmaRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -734,6 +728,7 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+
 _CONFIG_FOR_DOC = "GemmaConfig"
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8c4a208448e924..83f8c650a16366 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1,7 +1,7 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # This file was automatically generated from .
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
+# the file from the diff. If any change should be done, please apply the change to the
 # diff.py file directly.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
@@ -30,6 +30,8 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -48,14 +50,13 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
 from .configuration_llama import LlamaConfig
-from flash_attn import flash_attn_func, flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 
 """PyTorch LLaMA model."""
diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py
index 861697cec63673..7fabe5770361d1 100644
--- a/utils/diff_model_converter.py
+++ b/utils/diff_model_converter.py
@@ -33,7 +33,7 @@
 AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # This file was automatically generated from .
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
+# the file from the diff. If any change should be done, please apply the change to the
 # diff.py file directly.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 """
@@ -330,8 +330,8 @@ def __init__(self, python_module):
     def leave_FunctionDef(self, original_node, node):
         parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
         if m.matches(parent_node, m.Module()):
-            self.global_scope_index += 100
-            self.new_body[node.name.value] = {"insert_idx":self.global_scope_index, "node":node}
+            self.global_scope_index += 100
+            self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
         return node
 
     def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
@@ -364,7 +364,10 @@ def leave_SimpleStatementLine(self, original_node, updated_node):
         parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
         if m.matches(parent_node, m.Module()):
             self.global_scope_index += 100
-            self.new_body[self.python_module.code_for_node(updated_node.body[0])] = {"insert_idx":self.global_scope_index, "node":updated_node}
+            self.new_body[self.python_module.code_for_node(updated_node.body[0])] = {
+                "insert_idx": self.global_scope_index,
+                "node": updated_node,
+            }
         return updated_node
 
     def leave_ClassDef(self, original_node, updated_node):
@@ -402,22 +405,22 @@ def leave_ClassDef(self, original_node, updated_node):
             }
             list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
 
-            start_insert_idx = self.global_scope_index
+            start_insert_idx = self.global_scope_index
             for dependency, _ in list_dependencies:
                 node = class_finder.global_nodes.get(dependency, None)
                 if node is not None:
                     if dependency not in self.new_body:
                         start_insert_idx -= 1
-                        self.new_body[dependency] = {"insert_idx":start_insert_idx, "node":node}
+                        self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
                     elif dependency not in self.inserted_deps:
                         # make sure the node is written after it's dependencies
-                        start_insert_idx = self.new_body[dependency]["insert_idx"]-1
+                        start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
                     self.inserted_deps.append(dependency)
             updated_node = replace_call_to_super(class_finder, updated_node, class_name)
         if "Config" in class_name:
             self.config_body = [updated_node]
         else:
-            self.new_body[class_name] = {"insert_idx":self.global_scope_index, "node":updated_node}
+            self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
         return updated_node
 
     # def leave_If(self, original_node, node):
@@ -431,14 +434,14 @@ def leave_ClassDef(self, original_node, updated_node):
     #     return node
 
     def leave_Module(self, original_node: cst.Assign, node):
-        imports = {self.python_module.code_for_node(k):k for k in self.all_imports }
+        imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
         for visiter in self.visited_module.values():
-            imports.update({self.python_module.code_for_node(k):k for k in visiter.imports.values()})
+            imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
         new_body = list(imports.values())
         if hasattr(self, "config_body"):
             self.config_body = self.all_imports + self.config_body
-        new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x:x[1]["insert_idx"])]
+        new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
         return node.with_changes(body=[*new_body])
 
 
@@ -451,7 +454,7 @@ def convert_file(diff_file, cst_transformers=None):
     if cst_transformers is None:
         cst_transformers = DiffConverterTransformer(module)
     new_mod = wrapper.visit(cst_transformers)
-    ruffed_code = new_mod.code #run_ruff(new_mod.code, True)
+    ruffed_code = new_mod.code  # run_ruff(new_mod.code, True)
     if len(ruffed_code) > 0:
         with open(diff_file.replace("diff_", "modeling_"), "w") as f:
             f.write(AUTO_GENERATED_MESSAGE + ruffed_code)
@@ -471,7 +474,7 @@ def convert_file(diff_file, cst_transformers=None):
     parser.add_argument(
         "--files_to_parse",
         default="all",
-        nargs='+',
+        nargs="+",
         help="A list of `diff_xxxx` files that should be converted to single model file",
     )
     args = parser.parse_args()
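For context, a minimal sketch of how the converter touched above can be driven, assuming the command is run from the repository root: the `--files_to_parse` flag comes from the argparse setup in utils/diff_model_converter.py, and the `diff_` to `modeling_` output naming follows the replacement done in `convert_file`; the specific example file chosen here is only an illustration.

# Sketch only: regenerate a modeling_*.py file from its diff_*.py counterpart.
# Assumes the working directory is the transformers repository root; the flag name
# matches the argparse definition above, and the output path follows the
# `diff_` -> `modeling_` replacement performed inside convert_file().
import subprocess

subprocess.run(
    [
        "python",
        "utils/diff_model_converter.py",
        "--files_to_parse",
        "examples/diff-conversion/diff_dummy.py",
    ],
    check=True,
)
# Per convert_file() above, this writes examples/diff-conversion/modeling_dummy.py
# with AUTO_GENERATED_MESSAGE prepended to the generated code.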