From c72e11220d3554ca13e474cf91a57f5b6d5f137b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 20:25:48 +0100 Subject: [PATCH] Add types_to_file_type + tweak annotation handling --- utils/modular_model_converter.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index fad96548878bef..df9a6962764580 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -490,9 +490,13 @@ def __init__(self, class_name: str, global_names: set | None): def visit_Name(self, node): if node.value != self.class_name and node.value in self.global_names: parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - # If it is only an annotation, do not add dependency - if not m.matches(parent_node, m.Annotation()): - self.dependencies.add(node.value) + # If it is only an annotation inside a method definition, do not add dependency (however, do it for + # annotations that are variable definitions, i.e. for Kwargs classes) + if m.matches(parent_node, m.Annotation()): + grand_parent = self.get_metadata(cst.metadata.ParentNodeProvider, parent_node) + if m.matches(grand_parent, m.Param() | m.FunctionDef()): + return + self.dependencies.add(node.value) def dependencies_for_class_node(node: cst.ClassDef, global_names: set) -> set: @@ -890,6 +894,9 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename "Processor": "processing", "ImageProcessor": "image_processing", "FeatureExtractor": "feature_extractor", + "ProcessorKwargs": "processing", + "ImagesKwargs": "processing", + "TextKwargs": "processing", }