Skip to content

Commit

Permalink
Add check for builtin functions
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Nov 16, 2023
1 parent 33f2483 commit 97febf7
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
10 changes: 10 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from libcst.metadata import (
Assignment,
BaseAssignment,
BuiltinAssignment,
ImportAssignment,
ScopeProvider,
)
Expand Down Expand Up @@ -149,6 +150,15 @@ def find_single_assignment(
return next(iter(assignments))
return None

def is_builtin_function(self, node: cst.Call):
"""
Given a `Call` node, find if it refers to a builtin function
"""
maybe_assignment = self.find_single_assignment(node)
if maybe_assignment and isinstance(maybe_assignment, BuiltinAssignment):
return matchers.matches(node.func, matchers.Name())
return False


def iterate_left_expressions(node: cst.BaseExpression):
yield node
Expand Down
36 changes: 19 additions & 17 deletions src/core_codemods/use_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import libcst as cst

from codemodder.codemods.api import BaseCodemod, ReviewGuidance
from codemodder.codemods.utils_mixin import NameResolutionMixin


class UseGenerator(BaseCodemod):
class UseGenerator(BaseCodemod, NameResolutionMixin):
NAME = "use-generator"
SUMMARY = "Use generators for lazy evaluation"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
Expand All @@ -28,22 +29,23 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
# NOTE: could also support things like `list` and `tuple`
# but it's a less compelling use case
case cst.Name("any" | "all" | "sum" | "min" | "max"):
match original_node.args[0].value:
case cst.ListComp(elt=elt, for_in=for_in):
self.add_change(original_node, self.CHANGE_DESCRIPTION)
return updated_node.with_changes(
args=[
cst.Arg(
value=cst.GeneratorExp(
elt=elt, # type: ignore
for_in=for_in, # type: ignore
# No parens necessary since they are
# already included by the call expr itself
lpar=[],
rpar=[],
if self.is_builtin_function(original_node):
match original_node.args[0].value:
case cst.ListComp(elt=elt, for_in=for_in):
self.add_change(original_node, self.CHANGE_DESCRIPTION)
return updated_node.with_changes(
args=[
cst.Arg(
value=cst.GeneratorExp(
elt=elt, # type: ignore
for_in=for_in, # type: ignore
# No parens necessary since they are
# already included by the call expr itself
lpar=[],
rpar=[],
)
)
)
],
)
],
)

return original_node
1 change: 0 additions & 1 deletion tests/codemods/test_use_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def test_not_special_builtin(self, tmpdir):
"""
self.run_and_assert(tmpdir, original_code, expected)

@pytest.mark.xfail(reason="TODO: check for built-in names")
def test_not_global_function(self, tmpdir):
expected = original_code = """
from foo import any
Expand Down

0 comments on commit 97febf7

Please sign in to comment.