From 23dd9ab81a2db413d55c9b66ff9f640afb91236a Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Thu, 16 Nov 2023 09:13:56 -0500 Subject: [PATCH] Add check for builtin functions --- src/codemodder/codemods/utils_mixin.py | 10 +++++++ src/core_codemods/use_generator.py | 36 ++++++++++++++------------ tests/codemods/test_use_generator.py | 1 - 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 787577b0..20736a00 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -5,6 +5,7 @@ from libcst.metadata import ( Assignment, BaseAssignment, + BuiltinAssignment, ImportAssignment, ScopeProvider, ) @@ -149,6 +150,15 @@ def find_single_assignment( return next(iter(assignments)) return None + def is_builtin_function(self, node: cst.Call): + """ + Given a `Call` node, find if it refers to a builtin function + """ + maybe_assignment = self.find_single_assignment(node) + if maybe_assignment and isinstance(maybe_assignment, BuiltinAssignment): + return matchers.matches(node.func, matchers.Name()) + return False + def iterate_left_expressions(node: cst.BaseExpression): yield node diff --git a/src/core_codemods/use_generator.py b/src/core_codemods/use_generator.py index 655a701d..e9b3839a 100644 --- a/src/core_codemods/use_generator.py +++ b/src/core_codemods/use_generator.py @@ -1,9 +1,10 @@ import libcst as cst from codemodder.codemods.api import BaseCodemod, ReviewGuidance +from codemodder.codemods.utils_mixin import NameResolutionMixin -class UseGenerator(BaseCodemod): +class UseGenerator(BaseCodemod, NameResolutionMixin): NAME = "use-generator" SUMMARY = "Use generators for lazy evaluation" REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW @@ -28,22 +29,23 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): # NOTE: could also support things like `list` and `tuple` # but it's a less compelling use case case cst.Name("any" | "all" | "sum" | "min" | "max"): - match original_node.args[0].value: - case cst.ListComp(elt=elt, for_in=for_in): - self.add_change(original_node, self.CHANGE_DESCRIPTION) - return updated_node.with_changes( - args=[ - cst.Arg( - value=cst.GeneratorExp( - elt=elt, # type: ignore - for_in=for_in, # type: ignore - # No parens necessary since they are - # already included by the call expr itself - lpar=[], - rpar=[], + if self.is_builtin_function(original_node): + match original_node.args[0].value: + case cst.ListComp(elt=elt, for_in=for_in): + self.add_change(original_node, self.CHANGE_DESCRIPTION) + return updated_node.with_changes( + args=[ + cst.Arg( + value=cst.GeneratorExp( + elt=elt, # type: ignore + for_in=for_in, # type: ignore + # No parens necessary since they are + # already included by the call expr itself + lpar=[], + rpar=[], + ) ) - ) - ], - ) + ], + ) return original_node diff --git a/tests/codemods/test_use_generator.py b/tests/codemods/test_use_generator.py index 658640d6..b8cacf2e 100644 --- a/tests/codemods/test_use_generator.py +++ b/tests/codemods/test_use_generator.py @@ -23,7 +23,6 @@ def test_not_special_builtin(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, expected) - @pytest.mark.xfail(reason="TODO: check for built-in names") def test_not_global_function(self, tmpdir): expected = original_code = """ from foo import any