From a7f5479b45a8040392af80bf1107a2bdd796931c Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 17 Dec 2024 08:05:35 +0100 Subject: [PATCH] fix modular order (#35297) * fix modular ordre * fix * style --- utils/create_dependency_mapping.py | 62 +++++++++++++++++------------- utils/modular_model_converter.py | 2 +- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py index f25a8fb5ca6ff1..0df782d1c21740 100644 --- a/utils/create_dependency_mapping.py +++ b/utils/create_dependency_mapping.py @@ -1,40 +1,48 @@ import ast -from collections import defaultdict, deque +from collections import defaultdict # Function to perform topological sorting def topological_sort(dependencies): - # Create a graph and in-degree count for each node + new_dependencies = {} graph = defaultdict(list) - in_degree = defaultdict(int) - - # Build the graph for node, deps in dependencies.items(): for dep in deps: - graph[dep].append(node) # node depends on dep - in_degree[node] += 1 # increase in-degree of node + if "example" not in node and "auto" not in dep: + graph[dep.split(".")[-2]].append(node.split("/")[-2]) + new_dependencies[node.split("/")[-2]] = node - # Add all nodes with zero in-degree to the queue - zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0]) + # Create a graph and in-degree count for each node + def filter_one_by_one(filtered_list, reverse): + if len(reverse) == 0: + return filtered_list - sorted_list = [] - # Perform topological sorting - while zero_in_degree_queue: - current = zero_in_degree_queue.popleft() - sorted_list.append(current) + graph = defaultdict(list) + # Build the graph + for node, deps in reverse.items(): + for dep in deps: + graph[dep].append(node) - # For each node that current points to, reduce its in-degree - for neighbor in graph[current]: - in_degree[neighbor] -= 1 - if in_degree[neighbor] == 0: - zero_in_degree_queue.append(neighbor) + base_modules = set(reverse.keys()) - set(graph.keys()) + if base_modules == reverse.keys(): + # we are at the end + return filtered_list + list(graph.keys()) + to_add = [] + for k in graph.keys(): + if len(graph[k]) == 1 and graph[k][0] in base_modules: + if graph[k][0] in reverse: + del reverse[graph[k][0]] + if k not in filtered_list: + to_add += [k] + for k in base_modules: + if k not in filtered_list: + to_add += [k] + filtered_list += list(to_add) + return filter_one_by_one(filtered_list, reverse) - # Handle nodes that have no dependencies and were not initially part of the loop - for node in dependencies: - if node not in sorted_list: - sorted_list.append(node) + final_order = filter_one_by_one([], graph) - return sorted_list + return [new_dependencies.get(k) for k in final_order if k in new_dependencies] # Function to extract class and import info from a file @@ -46,7 +54,7 @@ def extract_classes_and_imports(file_path): for node in ast.walk(tree): if isinstance(node, (ast.Import, ast.ImportFrom)): module = node.module if isinstance(node, ast.ImportFrom) else None - if module and "transformers" in module: + if module and (".modeling_" in module): imports.add(module) return imports @@ -56,7 +64,7 @@ def map_dependencies(py_files): dependencies = defaultdict(set) # First pass: Extract all classes and map to files for file_path in py_files: - dependencies[file_path].add(None) + # dependencies[file_path].add(None) class_to_file = extract_classes_and_imports(file_path) for module in class_to_file: dependencies[file_path].add(module) @@ -66,4 +74,4 @@ def map_dependencies(py_files): def find_priority_list(py_files): dependencies = map_dependencies(py_files) ordered_classes = topological_sort(dependencies) - return ordered_classes[::-1] + return ordered_classes diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index e8d117cd2af08f..28fcc4fc7b9e1a 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1678,7 +1678,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/aria/modular_aria.py"], + default=["all"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", )