Skip to content

Commit

Permalink
update diff
Browse files Browse the repository at this point in the history
  • Loading branch information
zucchini-nlp committed Jun 11, 2024
1 parent c9164a4 commit 1dc56ae
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/instructblipvideo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

_import_structure = {
"configuration_instructblipvideo": [
"InstructBlipVideoVisionConfig",
"InstructBlipVideoQFormerConfig",
"InstructBlipVideoConfig",
"InstructBlipVideoQFormerConfig",
"InstructBlipVideoVisionConfig",
],
"processing_instructblipvideo": ["InstructBlipVideoProcessor"],
}
Expand Down Expand Up @@ -63,7 +63,7 @@
pass
else:
from .image_processing_instructblipvideo import InstructBlipVideoImageProcessor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
Expand Down
1 change: 0 additions & 1 deletion tests/models/instructblip/test_modeling_instructblip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import numpy as np
import requests
from huggingface_hub import hf_hub_download

from transformers import (
CONFIG_MAPPING,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/instructblip/test_processor_instructblip.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from transformers import (
AutoProcessor,
BertTokenizerFast,
GPT2Tokenizer,
BlipImageProcessor,
GPT2Tokenizer,
InstructBlipProcessor,
PreTrainedTokenizerFast,
)
Expand Down
34 changes: 26 additions & 8 deletions utils/diff_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
- LLaMa -> MyNewModel abd MyNewModel -> Llama
"""

def __init__(self, old_name, new_name):
def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
super().__init__()
self.old_name = old_name
self.new_name = new_name
Expand All @@ -183,6 +183,8 @@ def __init__(self, old_name, new_name):
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
}
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
self.patterns[given_old_name] = given_new_name

def preserve_case_replace(self, text):
# Create a regex pattern to match all variations
Expand All @@ -201,9 +203,9 @@ def replace_name(self, original_node, updated_node):
return updated_node.with_changes(value=update)


def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
"""Helper function to rename and then parse a source file using the ClassFinder"""
transformer = ReplaceNameTransformer(old_id, new_id)
transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
new_module = module.visit(transformer)

wrapper = MetadataWrapper(new_module)
Expand Down Expand Up @@ -356,11 +358,13 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
class DiffConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

def __init__(self, python_module, new_name):
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
super().__init__()
self.model_name = (
new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` our `phi3`
)
self.given_old_name = given_old_name
self.given_new_name = given_new_name
# fmt: off
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
Expand Down Expand Up @@ -460,7 +464,11 @@ def leave_ClassDef(self, original_node, updated_node):

if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name], model_name, self.model_name
self.transformers_imports[super_file_name],
model_name,
self.model_name,
self.given_old_name,
self.given_new_name,
)
self.visited_module[super_file_name] = class_finder
else: # we are re-using the previously parsed data
Expand Down Expand Up @@ -517,15 +525,15 @@ def leave_Module(self, original_node: cst.Assign, node):
return node.with_changes(body=[*new_body])


def convert_file(diff_file, cst_transformers=None):
def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
# Parse the Python file
with open(diff_file, "r") as file:
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module, model_name)
cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = run_ruff(new_mod.code, True)
formatted_code = run_ruff(ruffed_code, False)
Expand All @@ -552,10 +560,20 @@ def convert_file(diff_file, cst_transformers=None):
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
for file_name in args.files_to_parse:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converter = convert_file(file_name)
converter = convert_file(file_name, args.old_model_name, args.new_model_name)

0 comments on commit 1dc56ae

Please sign in to comment.