Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New codemod: use-generator #135

Merged
merged 2 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions integration_tests/test_use_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from core_codemods.use_generator import UseGenerator
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestUseGenerator(BaseIntegrationTest):
codemod = UseGenerator
code_path = "tests/samples/use_generator.py"

original_code, expected_new_code = original_and_expected_from_code_path(
code_path,
[(5, "x = sum(i for i in range(1000))\n")],
)

expected_diff = """\
---
+++
@@ -3,5 +3,5 @@
yield i


-x = sum([i for i in range(1000)])
+x = sum(i for i in range(1000))
y = some([i for i in range(1000)])
"""

expected_line_change = "6"
change_description = UseGenerator.CHANGE_DESCRIPTION
2 changes: 2 additions & 0 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
CodemodMetadata,
BaseCodemod as _BaseCodemod,
SemgrepCodemod as _SemgrepCodemod,
# Make this available via the simplified API
ReviewGuidance, # pylint: disable=unused-import
)

from codemodder.codemods.base_visitor import BaseTransformer
Expand Down
10 changes: 10 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from libcst.metadata import (
Assignment,
BaseAssignment,
BuiltinAssignment,
ImportAssignment,
ScopeProvider,
)
Expand Down Expand Up @@ -149,6 +150,15 @@ def find_single_assignment(
return next(iter(assignments))
return None

def is_builtin_function(self, node: cst.Call):
"""
Given a `Call` node, find if it refers to a builtin function
"""
maybe_assignment = self.find_single_assignment(node)
if maybe_assignment and isinstance(maybe_assignment, BuiltinAssignment):
return matchers.matches(node.func, matchers.Name())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm 90% sure this is redundant, as in any function that that matches the if predicate will have node.func as Name().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair. I'll leave it for now but we could revisit.

return False


def iterate_left_expressions(node: cst.BaseExpression):
yield node
Expand Down
4 changes: 4 additions & 0 deletions src/codemodder/scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ class DocMetadata:
importance="High",
guidance_explained="Python has a wealth of database drivers that all use the same `dbapi2` interface detailed in [PEP249](https://peps.python.org/pep-0249/). Different drivers may require different string tokens used for parameterization, and Python's dynamic typing makes it quite hard, and sometimes impossible, to detect which driver is being used just by looking at the code.",
),
"use-generator": DocMetadata(
importance="Low",
guidance_explained="We believe this replacement is safe and leads to better performance.",
),
}


Expand Down
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .upgrade_sslcontext_tls import UpgradeSSLContextTLS
from .url_sandbox import UrlSandbox
from .use_defused_xml import UseDefusedXml
from .use_generator import UseGenerator
from .use_walrus_if import UseWalrusIf
from .with_threading_lock import WithThreadingLock

Expand Down Expand Up @@ -55,6 +56,7 @@
UpgradeSSLContextTLS,
UrlSandbox,
UseDefusedXml,
UseGenerator,
UseWalrusIf,
WithThreadingLock,
SQLQueryParameterization,
Expand Down
13 changes: 13 additions & 0 deletions src/core_codemods/docs/pixee_python_use-generator.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Imagine that someone handed you a pile of 100 apples and then asked you to count how many of them were green without putting any of them down. You'd probably find this quite challenging and you'd struggle to hold the pile of apples at all. Now imagine someone handed you the apples one at a time and asked you to just count the green ones. This would be a much easier task.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly a nitpick, but this is the kind of language I'd avoid in documentations. Be direct and on point.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this is a little too whimsical but it's important to remember that a big part of our vision here involves educating developers and security professionals on best practices. This means we actually need to make our content as engaging as possible. I'm not necessarily saying I've succeeded in that here, but I'm going to leave it as-is and see what feedback we receive.


In Python, when we use list comprehensions, it's like we've created the entire pile of apples and asked the interpreter to hold onto it. Sometimes, a better practice involves using generator expressions, which create iterators that yield objects one at a time. For large data sets, this can turn a slow, memory intensive operation into a relatively fast one.

Using generator expressions instead of list comprehensions can lead to better performance. This is especially true for functions such as `any` where it's not always necessary to evaluate the entire list before returning. For other functions such as `max` or `sum` it means that the program does not need to store the entire list in memory. These performance effects becomes more noticeable as the sizes of the lists involved grow large.

This codemod replaces the use of a list comprehension expression with a generator expression within certain function calls. Generators allow for lazy evaluation of the iterator, which can have performance benefits.

The changes from this codemod look like this:
```diff
- result = sum([x for x in range(1000)])
+ result = sum(x for x in range(1000))
```
51 changes: 51 additions & 0 deletions src/core_codemods/use_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import libcst as cst

from codemodder.codemods.api import BaseCodemod, ReviewGuidance
from codemodder.codemods.utils_mixin import NameResolutionMixin


class UseGenerator(BaseCodemod, NameResolutionMixin):
NAME = "use-generator"
SUMMARY = "Use generators for lazy evaluation"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
DESCRIPTION = "Replace list comprehension with generator expression"
REFERENCES = [
{
"url": "https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/use-a-generator.html",
"description": "",
},
{
"url": "https://docs.python.org/3/glossary.html#term-generator-expression",
"description": "",
},
{
"url": "https://docs.python.org/3/glossary.html#term-list-comprehension",
"description": "",
},
]

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
match original_node.func:
# NOTE: could also support things like `list` and `tuple`
# but it's a less compelling use case
case cst.Name("any" | "all" | "sum" | "min" | "max"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you have pointed it out, those could use an extra step to check if they are builtin functions. You can do this using the ScopeProvider. Query the metadata for assignments for this Name node and check if it is a BuiltinAssignment.

if self.is_builtin_function(original_node):
match original_node.args[0].value:
case cst.ListComp(elt=elt, for_in=for_in):
self.add_change(original_node, self.CHANGE_DESCRIPTION)
return updated_node.with_changes(
args=[
cst.Arg(
value=cst.GeneratorExp(
elt=elt, # type: ignore
for_in=for_in, # type: ignore
# No parens necessary since they are
# already included by the call expr itself
lpar=[],
rpar=[],
)
)
],
)

return original_node
31 changes: 31 additions & 0 deletions tests/codemods/test_use_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from core_codemods.use_generator import UseGenerator
from tests.codemods.base_codemod_test import BaseCodemodTest


class TestUseGenerator(BaseCodemodTest):
codemod = UseGenerator

@pytest.mark.parametrize("func", ["any", "all", "sum", "min", "max"])
def test_list_comprehension(self, tmpdir, func):
original_code = f"""
x = {func}([i for i in range(10)])
"""
new_code = f"""
x = {func}(i for i in range(10))
"""
self.run_and_assert(tmpdir, original_code, new_code)

def test_not_special_builtin(self, tmpdir):
expected = original_code = """
x = some([i for i in range(10)])
"""
self.run_and_assert(tmpdir, original_code, expected)

def test_not_global_function(self, tmpdir):
expected = original_code = """
from foo import any
x = any([i for i in range(10)])
"""
self.run_and_assert(tmpdir, original_code, expected)
7 changes: 7 additions & 0 deletions tests/samples/use_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def some(iterable):
for i in iterable:
yield i


x = sum([i for i in range(1000)])
y = some([i for i in range(1000)])
Loading