Skip to content

Commit

Permalink
fix modular order (#35297)
Browse files Browse the repository at this point in the history
* fix modular ordre

* fix

* style
  • Loading branch information
ArthurZucker authored Dec 17, 2024
1 parent f5620a7 commit a7f5479
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
62 changes: 35 additions & 27 deletions utils/create_dependency_mapping.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,48 @@
import ast
from collections import defaultdict, deque
from collections import defaultdict


# Function to perform topological sorting
def topological_sort(dependencies):
# Create a graph and in-degree count for each node
new_dependencies = {}
graph = defaultdict(list)
in_degree = defaultdict(int)

# Build the graph
for node, deps in dependencies.items():
for dep in deps:
graph[dep].append(node) # node depends on dep
in_degree[node] += 1 # increase in-degree of node
if "example" not in node and "auto" not in dep:
graph[dep.split(".")[-2]].append(node.split("/")[-2])
new_dependencies[node.split("/")[-2]] = node

# Add all nodes with zero in-degree to the queue
zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])
# Create a graph and in-degree count for each node
def filter_one_by_one(filtered_list, reverse):
if len(reverse) == 0:
return filtered_list

sorted_list = []
# Perform topological sorting
while zero_in_degree_queue:
current = zero_in_degree_queue.popleft()
sorted_list.append(current)
graph = defaultdict(list)
# Build the graph
for node, deps in reverse.items():
for dep in deps:
graph[dep].append(node)

# For each node that current points to, reduce its in-degree
for neighbor in graph[current]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
zero_in_degree_queue.append(neighbor)
base_modules = set(reverse.keys()) - set(graph.keys())
if base_modules == reverse.keys():
# we are at the end
return filtered_list + list(graph.keys())
to_add = []
for k in graph.keys():
if len(graph[k]) == 1 and graph[k][0] in base_modules:
if graph[k][0] in reverse:
del reverse[graph[k][0]]
if k not in filtered_list:
to_add += [k]
for k in base_modules:
if k not in filtered_list:
to_add += [k]
filtered_list += list(to_add)
return filter_one_by_one(filtered_list, reverse)

# Handle nodes that have no dependencies and were not initially part of the loop
for node in dependencies:
if node not in sorted_list:
sorted_list.append(node)
final_order = filter_one_by_one([], graph)

return sorted_list
return [new_dependencies.get(k) for k in final_order if k in new_dependencies]


# Function to extract class and import info from a file
Expand All @@ -46,7 +54,7 @@ def extract_classes_and_imports(file_path):
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
module = node.module if isinstance(node, ast.ImportFrom) else None
if module and "transformers" in module:
if module and (".modeling_" in module):
imports.add(module)
return imports

Expand All @@ -56,7 +64,7 @@ def map_dependencies(py_files):
dependencies = defaultdict(set)
# First pass: Extract all classes and map to files
for file_path in py_files:
dependencies[file_path].add(None)
# dependencies[file_path].add(None)
class_to_file = extract_classes_and_imports(file_path)
for module in class_to_file:
dependencies[file_path].add(module)
Expand All @@ -66,4 +74,4 @@ def map_dependencies(py_files):
def find_priority_list(py_files):
dependencies = map_dependencies(py_files)
ordered_classes = topological_sort(dependencies)
return ordered_classes[::-1]
return ordered_classes
2 changes: 1 addition & 1 deletion utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,7 +1678,7 @@ def save_modeling_file(modular_file, converted_file):
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/aria/modular_aria.py"],
default=["all"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
Expand Down

0 comments on commit a7f5479

Please sign in to comment.