Commit

current state
ArthurZucker committed May 30, 2024
1 parent 64422e5 commit 8a85473
Showing 9 changed files with 49 additions and 41 deletions.
12 changes: 8 additions & 4 deletions examples/diff-conversion/diff_dummy.py
@@ -1,14 +1,18 @@
from transformers.models.llama.modeling_llama import LlamaModel
from typing import *
import torch
from math import log
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


def _pre_process_input(input_ids):
print(log(input_ids))
return input_ids


# example where we need some deps and some functions
class DummyModel(LlamaModel):
def forward(
12 changes: 6 additions & 6 deletions examples/diff-conversion/diff_my_new_model.py
@@ -1,13 +1,13 @@
from transformers.models.llama.modeling_llama import LlamaConfig


# Example where we only want to add a new config argument and new arg doc
# here there is no `ARG`, so we are going to take the parent doc
class MyNewModelConfig(LlamaConfig):
r"""
mlp_bias (`bool`, *optional*, defaults to `False`)
mlp_bias (`bool`, *optional*, defaults to `False`)
"""
def __init__(
self,
mlp_bias=False
):

def __init__(self, mlp_bias=False):
self.mlp_bias = mlp_bias
super().__init__(self)
super().__init__(self)
3 changes: 3 additions & 0 deletions examples/diff-conversion/diff_my_new_model2.py
@@ -1,5 +1,7 @@
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
from transformers.models.llama.configuration_llama import LlamaConfig


# Example where we only want to modify the docstring
class MyNewModel2Config(LlamaConfig):
r"""
@@ -23,6 +25,7 @@ class MyNewModel2Config(LlamaConfig):
>>> configuration = model.config
```"""


# Example where alllllll the dependencies are fetched to just copy the entire class
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
pass
4 changes: 3 additions & 1 deletion examples/diff-conversion/diff_new_model.py
@@ -1,6 +1,8 @@
# Example where we only want to overwrite the defaults of an init

from transformers.models.gemma.configuration_gemma import GemmaConfig


class NewModelConfig(GemmaConfig):
def __init__(
self,
@@ -25,4 +27,4 @@ def __init__(
attention_bias=False,
attention_dropout=0.0,
):
super().__init__(self)
super().__init__(self)
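
The example diff files above are the inputs for the converter updated later in this commit (utils/diff_model_converter.py). Purely as a hedged sketch — assuming the repository root is on sys.path and relying on the diff_ → modeling_ output rename visible in convert_file further down — converting one of them programmatically might look like:

# Hedged sketch, not part of this commit: convert a single example diff file.
from utils.diff_model_converter import convert_file

# convert_file reads the diff file, pulls in the parent-model definitions it
# references, and writes the generated file next to it
# (diff_my_new_model.py -> modeling_my_new_model.py), prefixed with the
# AUTO_GENERATED_MESSAGE banner.
convert_file("examples/diff-conversion/diff_my_new_model.py")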
3 changes: 1 addition & 2 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -1,7 +1,7 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
@@ -21,7 +21,6 @@
# limitations under the License.



from transformers import PretrainedConfig


1 change: 1 addition & 0 deletions src/transformers/models/gemma/diff_gemma.py
@@ -162,6 +162,7 @@ def __init__(
**kwargs,
)


class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
17 changes: 6 additions & 11 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -1,7 +1,7 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
@@ -24,29 +24,22 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers import PretrainedConfig
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
import math
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -81,6 +74,7 @@ def _get_unpad_data(attention_mask):
max_seqlen_in_batch,
)


class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -734,6 +728,7 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


_CONFIG_FOR_DOC = "GemmaConfig"


9 changes: 5 additions & 4 deletions src/transformers/models/llama/modeling_llama.py
@@ -1,7 +1,7 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
@@ -30,6 +30,8 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -48,14 +50,13 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


"""PyTorch LLaMA model."""


29 changes: 16 additions & 13 deletions utils/diff_model_converter.py
@@ -33,7 +33,7 @@
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
@@ -330,8 +330,8 @@ def __init__(self, python_module):
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.global_scope_index += 100
self.new_body[node.name.value] = {"insert_idx":self.global_scope_index, "node":node}
self.global_scope_index += 100
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
return node

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
@@ -364,7 +364,10 @@ def leave_SimpleStatementLine(self, original_node, updated_node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.global_scope_index += 100
self.new_body[self.python_module.code_for_node(updated_node.body[0])] = {"insert_idx":self.global_scope_index, "node":updated_node}
self.new_body[self.python_module.code_for_node(updated_node.body[0])] = {
"insert_idx": self.global_scope_index,
"node": updated_node,
}
return updated_node

def leave_ClassDef(self, original_node, updated_node):
@@ -402,22 +405,22 @@ def leave_ClassDef(self, original_node, updated_node):
}

list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
start_insert_idx = self.global_scope_index
start_insert_idx = self.global_scope_index
for dependency, _ in list_dependencies:
node = class_finder.global_nodes.get(dependency, None)
if node is not None:
if dependency not in self.new_body:
start_insert_idx -= 1
self.new_body[dependency] = {"insert_idx":start_insert_idx, "node":node}
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
elif dependency not in self.inserted_deps:
# make sure the node is written after its dependencies
start_insert_idx = self.new_body[dependency]["insert_idx"]-1
start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
self.inserted_deps.append(dependency)
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
if "Config" in class_name:
self.config_body = [updated_node]
else:
self.new_body[class_name] = {"insert_idx":self.global_scope_index, "node":updated_node}
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node

# def leave_If(self, original_node, node):
@@ -431,14 +434,14 @@ def leave_ClassDef(self, original_node, updated_node):
# return node

def leave_Module(self, original_node: cst.Assign, node):
imports = {self.python_module.code_for_node(k):k for k in self.all_imports }
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
for visiter in self.visited_module.values():
imports.update({self.python_module.code_for_node(k):k for k in visiter.imports.values()})
imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})

new_body = list(imports.values())
if hasattr(self, "config_body"):
self.config_body = self.all_imports + self.config_body
new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x:x[1]["insert_idx"])]
new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
return node.with_changes(body=[*new_body])
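
To make the bookkeeping above concrete: each top-level node gets an insert_idx spaced 100 apart, every dependency fetched for a class is slotted just before that class by decrementing the index, and leave_Module finally writes the collected nodes out sorted by insert_idx. A minimal, self-contained sketch of that ordering (the class and dependency names are invented for illustration):

# Standalone illustration of the insert_idx scheme; not code from this commit.
new_body = {}
global_scope_index = 0

for name in ["GemmaConfig", "GemmaModel"]:  # top-level definitions, spaced 100 apart
    global_scope_index += 100
    new_body[name] = {"insert_idx": global_scope_index, "node": f"<def {name}>"}

# Dependencies fetched for GemmaModel are slotted in just before it.
start_insert_idx = global_scope_index
for dep in ["GemmaRMSNorm", "GemmaMLP"]:
    start_insert_idx -= 1
    new_body[dep] = {"insert_idx": start_insert_idx, "node": f"<def {dep}>"}

# leave_Module then emits everything sorted by insert_idx:
ordered = [name for name, _ in sorted(new_body.items(), key=lambda x: x[1]["insert_idx"])]
print(ordered)  # ['GemmaConfig', 'GemmaMLP', 'GemmaRMSNorm', 'GemmaModel']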


@@ -451,7 +454,7 @@ def convert_file(diff_file, cst_transformers=None):
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = new_mod.code #run_ruff(new_mod.code, True)
ruffed_code = new_mod.code # run_ruff(new_mod.code, True)
if len(ruffed_code) > 0:
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
f.write(AUTO_GENERATED_MESSAGE + ruffed_code)
@@ -471,7 +474,7 @@ def convert_file(diff_file, cst_transformers=None):
parser.add_argument(
"--files_to_parse",
default="all",
nargs='+',
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
args = parser.parse_args()
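
With this flag, the script is presumably invoked as something like `python utils/diff_model_converter.py --files_to_parse examples/diff-conversion/diff_my_new_model.py` (paths illustrative; `--files_to_parse` defaults to "all").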
