def is_builtin_function(self, node: cst.Call):
    """
    Tell whether the given `Call` node is a call to a builtin function.

    The answer is true only when `find_single_assignment` yields a
    `BuiltinAssignment` for the node and the callee is a bare `Name`.
    """
    resolved = self.find_single_assignment(node)
    resolves_to_builtin = bool(resolved) and isinstance(resolved, BuiltinAssignment)
    if not resolves_to_builtin:
        return False
    # Only a plain name such as `sum` matches; any other callee shape does not
    return matchers.matches(node.func, matchers.Name())
count how many of them were green without putting any of them down. You'd probably find this quite challenging and you'd struggle to hold the pile of apples at all. Now imagine someone handed you the apples one at a time and asked you to just count the green ones. This would be a much easier task. + +In Python, when we use list comprehensions, it's like we've created the entire pile of apples and asked the interpreter to hold onto it. Sometimes, a better practice involves using generator expressions, which create iterators that yield objects one at a time. For large data sets, this can turn a slow, memory-intensive operation into a relatively fast one. + +Using generator expressions instead of list comprehensions can lead to better performance. This is especially true for functions such as `any`, where it's not always necessary to evaluate the entire list before returning. For other functions such as `max` or `sum`, it means that the program does not need to store the entire list in memory. These performance effects become more noticeable as the sizes of the lists involved grow large. + +This codemod replaces the use of a list comprehension expression with a generator expression within certain function calls. Generators allow for lazy evaluation of the iterator, which can have performance benefits. 
+ +The changes from this codemod look like this: +```diff +- result = sum([x for x in range(1000)]) ++ result = sum(x for x in range(1000)) +``` diff --git a/src/core_codemods/use_generator.py b/src/core_codemods/use_generator.py new file mode 100644 index 00000000..e9b3839a --- /dev/null +++ b/src/core_codemods/use_generator.py @@ -0,0 +1,51 @@ +import libcst as cst + +from codemodder.codemods.api import BaseCodemod, ReviewGuidance +from codemodder.codemods.utils_mixin import NameResolutionMixin + + +class UseGenerator(BaseCodemod, NameResolutionMixin): + NAME = "use-generator" + SUMMARY = "Use generators for lazy evaluation" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW + DESCRIPTION = "Replace list comprehension with generator expression" + REFERENCES = [ + { + "url": "https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/use-a-generator.html", + "description": "", + }, + { + "url": "https://docs.python.org/3/glossary.html#term-generator-expression", + "description": "", + }, + { + "url": "https://docs.python.org/3/glossary.html#term-list-comprehension", + "description": "", + }, + ] + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): + match original_node.func: + # NOTE: could also support things like `list` and `tuple` + # but it's a less compelling use case + case cst.Name("any" | "all" | "sum" | "min" | "max"): + if self.is_builtin_function(original_node): + match original_node.args[0].value: + case cst.ListComp(elt=elt, for_in=for_in): + self.add_change(original_node, self.CHANGE_DESCRIPTION) + return updated_node.with_changes( + args=[ + cst.Arg( + value=cst.GeneratorExp( + elt=elt, # type: ignore + for_in=for_in, # type: ignore + # No parens necessary since they are + # already included by the call expr itself + lpar=[], + rpar=[], + ) + ) + ], + ) + + return original_node diff --git a/tests/codemods/test_use_generator.py b/tests/codemods/test_use_generator.py new file mode 100644 index 00000000..b8cacf2e 
class TestUseGenerator(BaseCodemodTest):
    """Unit tests for the `UseGenerator` codemod."""

    codemod = UseGenerator

    @pytest.mark.parametrize("func", ["any", "all", "sum", "min", "max"])
    def test_list_comprehension(self, tmpdir, func):
        """A list comprehension passed to a supported builtin is rewritten."""
        before = f"""
        x = {func}([i for i in range(10)])
        """
        after = f"""
        x = {func}(i for i in range(10))
        """
        self.run_and_assert(tmpdir, before, after)

    def test_not_special_builtin(self, tmpdir):
        """A call to a function outside the supported set is left as is."""
        unchanged = """
        x = some([i for i in range(10)])
        """
        self.run_and_assert(tmpdir, unchanged, unchanged)

    def test_not_global_function(self, tmpdir):
        """An imported name that shadows the builtin is left as is."""
        unchanged = """
        from foo import any
        x = any([i for i in range(10)])
        """
        self.run_and_assert(tmpdir, unchanged, unchanged)