diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index b2c01e3a..d50721b0 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -2,7 +2,9 @@ from pathlib import Path from typing import Optional, Any -from libcst import matchers +from libcst import MetadataDependent, matchers +from libcst.codemod import CodemodContext +from libcst.matchers import MatcherDecoratableTransformer import libcst as cst @@ -104,6 +106,40 @@ def on_leave(self, original_node, updated_node): return updated_node +class MetadataPreservingTransformer( + MatcherDecoratableTransformer, cst.MetadataDependent +): + """ + The CSTTransformer equivalent of ContextAwareVisitor. Will preserve metadata passed through a context. You should not chain more than one of these, otherwise metadata will not reflect the state of the tree. + """ + + def __init__(self, context: CodemodContext) -> None: + MetadataDependent.__init__(self) + MatcherDecoratableTransformer.__init__(self) + self.context = context + dependencies = self.get_inherited_dependencies() + if dependencies: + wrapper = self.context.wrapper + if wrapper is None: + # pylint: disable-next=broad-exception-raised + raise Exception( + f"Attempting to instantiate {self.__class__.__name__} outside of " + + "an active transform. This means that metadata hasn't been " + + "calculated and we cannot successfully create this visitor." + ) + for dep in dependencies: + if dep not in wrapper._metadata: + # pylint: disable-next=broad-exception-raised + raise Exception( + f"Attempting to access metadata {dep.__name__} that was not a " + + "declared dependency of parent transform! This means it is " + + "not possible to compute this value. Please ensure that all " + + f"parent transforms of {self.__class__.__name__} declare " + + f"{dep.__name__} as a metadata dependency." + ) + self.metadata = {dep: wrapper._metadata[dep] for dep in dependencies} + + def is_django_settings_file(file_path: Path): if "settings.py" not in file_path.name: return False diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 60a4b1be..641f274f 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -1,12 +1,14 @@ -from typing import Any, Optional, Tuple, Union +from typing import Any, Collection, Optional, Tuple, Union import libcst as cst from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node from libcst.metadata import ( + Access, Assignment, BaseAssignment, BuiltinAssignment, ImportAssignment, + ParentNodeProvider, ScopeProvider, ) from libcst.metadata.scope_provider import GlobalScope @@ -158,6 +160,141 @@ def is_builtin_function(self, node: cst.Call): return matchers.matches(node.func, matchers.Name()) return False + def find_accesses(self, node) -> Collection[Access]: + scope = self.get_metadata(ScopeProvider, node, None) + if scope: + return scope.accesses[node] + return {} + + +class AncestorPatternsMixin(MetadataDependent): + METADATA_DEPENDENCIES: Tuple[Any, ...] = (ParentNodeProvider,) + + def is_value_of_assignment( + self, expr + ) -> Optional[cst.AnnAssign | cst.Assign | cst.WithItem | cst.NamedExpr]: + """ + Tests if expr is the value in an assignment. + """ + parent = self.get_metadata(ParentNodeProvider, expr) + match parent: + case cst.AnnAssign(value=value) | cst.Assign(value=value) | cst.WithItem( + item=value + ) | cst.NamedExpr( + value=value + ) if expr == value: # type: ignore + return parent + return None + + def has_attr_called(self, node: cst.CSTNode) -> Optional[cst.Name]: + """ + Checks if node is part of an expression of the form: .call(). + """ + maybe_attr = self.is_attribute_value(node) + maybe_call = self.is_call_func(maybe_attr) if maybe_attr else None + if maybe_attr and maybe_call: + return maybe_attr.attr + return None + + def is_argument_of_call(self, node: cst.CSTNode) -> Optional[cst.Arg]: + """ + Checks if the node is an argument of a call. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.Arg(value=node): + return maybe_parent + return None + + def is_yield_value(self, node: cst.CSTNode) -> Optional[cst.Yield]: + """ + Checks if the node is the value of a Yield statement. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.Yield(value=node): + return maybe_parent + return None + + def is_return_value(self, node: cst.CSTNode) -> Optional[cst.Return]: + """ + Checks if the node is the value of a Return statement. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.Return(value=node): + return maybe_parent + return None + + def is_with_item(self, node: cst.CSTNode) -> Optional[cst.WithItem]: + """ + Checks if the node is the name of a WithItem. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.WithItem(item=node): + return maybe_parent + return None + + def is_call_func(self, node: cst.CSTNode) -> Optional[cst.Call]: + """ + Checks if the node is the func of an Call. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.Call(func=node): + return maybe_parent + return None + + def is_attribute_value(self, node: cst.CSTNode) -> Optional[cst.Attribute]: + """ + Checks if node is the value of an Attribute. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.Attribute(value=node): + return maybe_parent + return None + + def path_to_root(self, node: cst.CSTNode) -> list[cst.CSTNode]: + """ + Returns node's path to root. Includes self. + """ + path = [] + maybe_parent = node + while maybe_parent: + path.append(maybe_parent) + maybe_parent = self.get_parent(maybe_parent) + return path + + def path_to_root_as_set(self, node: cst.CSTNode) -> set[cst.CSTNode]: + """ + Returns the set of nodes in node's path to root. Includes self. + """ + path = set() + maybe_parent = node + while maybe_parent: + path.add(maybe_parent) + maybe_parent = self.get_parent(maybe_parent) + return path + + def is_ancestor(self, node: cst.CSTNode, other_node: cst.CSTNode) -> bool: + """ + Tests if other_node is an ancestor of node in the CST. + """ + path = self.path_to_root_as_set(node) + return other_node in path + + def get_parent(self, node: cst.CSTNode) -> Optional[cst.CSTNode]: + """ + Retrieves the parent of node. Will return None for the root. + """ + try: + return self.get_metadata(ParentNodeProvider, node, None) + except Exception: + pass + return None + def iterate_left_expressions(node: cst.BaseExpression): yield node diff --git a/src/core_codemods/file_resource_leak.py b/src/core_codemods/file_resource_leak.py index 177c3130..b58b66af 100644 --- a/src/core_codemods/file_resource_leak.py +++ b/src/core_codemods/file_resource_leak.py @@ -1,14 +1,12 @@ -from typing import Collection, Optional, Sequence +from typing import Optional, Sequence import libcst as cst -from libcst import MetadataDependent, ensure_type, matchers -from libcst.matchers import MatcherDecoratableTransformer +from libcst import ensure_type, matchers from libcst.codemod import ( Codemod, CodemodContext, ContextAwareVisitor, ) from libcst.metadata import ( - Access, BuiltinAssignment, ParentNodeProvider, PositionProvider, @@ -21,45 +19,12 @@ ReviewGuidance, ) from codemodder.codemods.base_visitor import UtilsMixin -from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.codemods.utils import MetadataPreservingTransformer +from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin from codemodder.file_context import FileContext from functools import partial -class MetadataPreservingTransformer( - MatcherDecoratableTransformer, cst.MetadataDependent -): - """ - The CSTTransformer equivalent of ContextAwareVisitor. Will preserve metadata passed through a context. You should not chain more than one of these, otherwise metadata will not reflect the state of the tree. - """ - - def __init__(self, context: CodemodContext) -> None: - MetadataDependent.__init__(self) - MatcherDecoratableTransformer.__init__(self) - self.context = context - dependencies = self.get_inherited_dependencies() - if dependencies: - wrapper = self.context.wrapper - if wrapper is None: - # pylint: disable-next=broad-exception-raised - raise Exception( - f"Attempting to instantiate {self.__class__.__name__} outside of " - + "an active transform. This means that metadata hasn't been " - + "calculated and we cannot successfully create this visitor." - ) - for dep in dependencies: - if dep not in wrapper._metadata: - # pylint: disable-next=broad-exception-raised - raise Exception( - f"Attempting to access metadata {dep.__name__} that was not a " - + "declared dependency of parent transform! This means it is " - + "not possible to compute this value. Please ensure that all " - + f"parent transforms of {self.__class__.__name__} declare " - + f"{dep.__name__} as a metadata dependency." - ) - self.metadata = {dep: wrapper._metadata[dep] for dep in dependencies} - - class FileResourceLeak(BaseCodemod, UtilsMixin, Codemod): SUMMARY = "Automatically close resources" METADATA = CodemodMetadata( @@ -111,13 +76,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: return result -class FindResources(ContextAwareVisitor, NameResolutionMixin): +class FindResources(ContextAwareVisitor, NameResolutionMixin, AncestorPatternsMixin): """ Finds and all the patterns of the form x = resource(...), where resource is an call that results in an open resource. It gathers the path in the tree corresponding to the mentioned pattern. """ - METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, context: CodemodContext) -> None: super().__init__(context) self.assigned_resources: list[ @@ -188,15 +151,10 @@ def _is_resource(self, call: cst.Call) -> bool: return True return False - def get_parent(self, node: cst.CSTNode) -> Optional[cst.CSTNode]: - try: - return self.get_metadata(ParentNodeProvider, node, None) - except Exception: - pass - return None - -class ResourceLeakFixer(MetadataPreservingTransformer, NameResolutionMixin): +class ResourceLeakFixer( + MetadataPreservingTransformer, NameResolutionMixin, AncestorPatternsMixin +): METADATA_DEPENDENCIES = ( PositionProvider, ScopeProvider, @@ -464,134 +422,3 @@ def _is_arg_of_contextlib_function_in_with_item( if true_name and true_name.startswith("contextlib."): return self.is_with_item(maybe_gparent) return None - - def find_accesses(self, node) -> Collection[Access]: - scope = self.get_metadata(ScopeProvider, node, None) - if scope: - return scope.accesses[node] - return {} - - def is_value_of_assignment( - self, expr - ) -> Optional[cst.AnnAssign | cst.Assign | cst.WithItem | cst.NamedExpr]: - """ - Tests if expr is the value in an assignment. - """ - parent = self.get_metadata(ParentNodeProvider, expr) - match parent: - case cst.AnnAssign(value=value) | cst.Assign(value=value) | cst.WithItem( - item=value - ) | cst.NamedExpr( - value=value - ) if expr == value: # type: ignore - return parent - return None - - def has_attr_called(self, node: cst.CSTNode) -> Optional[cst.Name]: - """ - Checks if node is part of an expression of the form: .call(). - """ - maybe_attr = self.is_attribute_value(node) - maybe_call = self.is_call_func(maybe_attr) if maybe_attr else None - if maybe_attr and maybe_call: - return maybe_attr.attr - return None - - def is_argument_of_call(self, node: cst.CSTNode) -> Optional[cst.Arg]: - """ - Checks if the node is an argument of a call. - """ - maybe_parent = self.get_parent(node) - match maybe_parent: - case cst.Arg(value=node): - return maybe_parent - return None - - def is_yield_value(self, node: cst.CSTNode) -> Optional[cst.Yield]: - """ - Checks if the node is the value of a Yield statement. - """ - maybe_parent = self.get_parent(node) - match maybe_parent: - case cst.Yield(value=node): - return maybe_parent - return None - - def is_return_value(self, node: cst.CSTNode) -> Optional[cst.Return]: - """ - Checks if the node is the value of a Return statement. - """ - maybe_parent = self.get_parent(node) - match maybe_parent: - case cst.Return(value=node): - return maybe_parent - return None - - def is_with_item(self, node: cst.CSTNode) -> Optional[cst.WithItem]: - """ - Checks if the node is the name of a WithItem. - """ - maybe_parent = self.get_parent(node) - match maybe_parent: - case cst.WithItem(item=node): - return maybe_parent - return None - - def is_call_func(self, node: cst.CSTNode) -> Optional[cst.Call]: - """ - Checks if the node is the func of an Call. - """ - maybe_parent = self.get_parent(node) - match maybe_parent: - case cst.Call(func=node): - return maybe_parent - return None - - def is_attribute_value(self, node: cst.CSTNode) -> Optional[cst.Attribute]: - """ - Checks if node is the value of an Attribute. - """ - maybe_parent = self.get_parent(node) - match maybe_parent: - case cst.Attribute(value=node): - return maybe_parent - return None - - def path_to_root(self, node: cst.CSTNode) -> list[cst.CSTNode]: - """ - Returns node's path to root. Includes self. - """ - path = [] - maybe_parent = node - while maybe_parent: - path.append(maybe_parent) - maybe_parent = self.get_parent(maybe_parent) - return path - - def path_to_root_as_set(self, node: cst.CSTNode) -> set[cst.CSTNode]: - """ - Returns the set of nodes in node's path to root. Includes self. - """ - path = set() - maybe_parent = node - while maybe_parent: - path.add(maybe_parent) - maybe_parent = self.get_parent(maybe_parent) - return path - - def is_ancestor(self, node: cst.CSTNode, other_node: cst.CSTNode) -> bool: - """ - Tests if other_node is an ancestor of node in the CST. - """ - path = self.path_to_root_as_set(node) - return other_node in path - - def get_parent(self, node: cst.CSTNode) -> Optional[cst.CSTNode]: - """ - Retrieves the parent of node. Will return None for the root. - """ - try: - return self.get_metadata(ParentNodeProvider, node, None) - except Exception: - pass - return None diff --git a/tests/test_ancestorpatterns_mixin.py b/tests/test_ancestorpatterns_mixin.py new file mode 100644 index 00000000..f749e23f --- /dev/null +++ b/tests/test_ancestorpatterns_mixin.py @@ -0,0 +1,91 @@ +import libcst as cst +from libcst.codemod import Codemod, CodemodContext +from codemodder.codemods.utils_mixin import AncestorPatternsMixin +from textwrap import dedent + + +class TestNameResolutionMixin: + def test_get_parent(self): + class TestCodemod(Codemod, AncestorPatternsMixin): + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + stmt = cst.ensure_type(tree.body[-1], cst.SimpleStatementLine) + + maybe_parent = self.get_parent(stmt) + assert maybe_parent == tree + return tree + + input_code = dedent( + """\ + a = 1 + """ + ) + tree = cst.parse_module(input_code) + TestCodemod(CodemodContext()).transform_module(tree) + + def test_get_parent_root(self): + class TestCodemod(Codemod, AncestorPatternsMixin): + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + maybe_parent = self.get_parent(tree) + assert maybe_parent == None + return tree + + input_code = dedent( + """\ + a = 1 + """ + ) + tree = cst.parse_module(input_code) + TestCodemod(CodemodContext()).transform_module(tree) + + def test_get_path_to_root(self): + class TestCodemod(Codemod, AncestorPatternsMixin): + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + stmt = cst.ensure_type(tree.body[-1], cst.SimpleStatementLine) + node = stmt.body[0] + + path = self.path_to_root(node) + assert len(path) == 3 + return tree + + input_code = dedent( + """\ + a = 1 + """ + ) + tree = cst.parse_module(input_code) + TestCodemod(CodemodContext()).transform_module(tree) + + def test_get_path_to_root_set(self): + class TestCodemod(Codemod, AncestorPatternsMixin): + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + stmt = cst.ensure_type(tree.body[-1], cst.SimpleStatementLine) + node = stmt.body[0] + + path = self.path_to_root_as_set(node) + assert len(path) == 3 + return tree + + input_code = dedent( + """\ + a = 1 + """ + ) + tree = cst.parse_module(input_code) + TestCodemod(CodemodContext()).transform_module(tree) + + def test_is_ancestor(self): + class TestCodemod(Codemod, AncestorPatternsMixin): + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + stmt = cst.ensure_type(tree.body[-1], cst.SimpleStatementLine) + node = stmt.body[0] + + assert self.is_ancestor(node, tree) + return tree + + input_code = dedent( + """\ + a = 1 + """ + ) + tree = cst.parse_module(input_code) + TestCodemod(CodemodContext()).transform_module(tree)