From 229e7ea1c2ef7b4b0c0910f94bb46f44e2fd4dd1 Mon Sep 17 00:00:00 2001 From: Logan Hunt <39638017+dosisod@users.noreply.github.com> Date: Fri, 20 Oct 2023 20:13:15 -0700 Subject: [PATCH] Add better detection of list/set comprehensions (#300): See https://github.com/dosisod/refurb/issues/298#issuecomment-1769238564 This PR adds better detection of list/set comprehensions. Previously Refurb checks ran depth first, meaning that the leaf nodes where the first nodes to be hit. Now the checks are ran before traversing, meaning checks will run on the root nodes first. This allows checks more flexibility, though for this PR, it allows them to ignore nodes that they have already seen which prevents multiple errors being emitted in certain circumstances. Also bump version for new release. --- pyproject.toml | 2 +- .../itertools/use_chain_from_iterable.py | 32 +++++++++- refurb/checks/itertools/use_starmap.py | 60 ++++++++++++++++++- refurb/visitor/visitor.py | 4 +- test/data/err_140.txt | 4 +- test/data/err_179.txt | 2 +- 6 files changed, 94 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb6977d..cf7cdad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "refurb" -version = "1.22.0" +version = "1.22.1" description = "A tool for refurbish and modernize Python codebases" authors = ["dosisod"] license = "GPL-3.0-only" diff --git a/refurb/checks/itertools/use_chain_from_iterable.py b/refurb/checks/itertools/use_chain_from_iterable.py index 6fb6a5d..0952a31 100644 --- a/refurb/checks/itertools/use_chain_from_iterable.py +++ b/refurb/checks/itertools/use_chain_from_iterable.py @@ -4,6 +4,7 @@ ArgKind, CallExpr, GeneratorExpr, + ListComprehension, ListExpr, NameExpr, RefExpr, @@ -60,7 +61,7 @@ class ErrorInfo(Error): code = 179 -def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None: +def is_flatten_generator(node: GeneratorExpr) -> bool: match node: case GeneratorExpr( left_expr=RefExpr(fullname=expr), @@ -69,6 +70,35 @@ def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None: is_async=[False, False], condlists=[[], []], ) if expr == inner and inner_source == outer: + return True + + return False + + +# List of nodes we have already emitted errors for, since list comprehensions +# and their inner generators will emit 2 errors. +ignore = set[int]() + + +def check( + node: ListComprehension | GeneratorExpr | CallExpr, + errors: list[Error], +) -> None: + if id(node) in ignore: + return + + match node: + case ListComprehension(generator=g) if is_flatten_generator(g): + old = "[... for ... in x for ... in ...]" + new = "list(chain.from_iterable(x))" + + msg = f"Replace `{old}` with `{new}`" + + errors.append(ErrorInfo.from_node(node, msg)) + + ignore.add(id(g)) + + case GeneratorExpr() if is_flatten_generator(node): old = "... for ... in x for ... in ..." new = "chain.from_iterable(x)" diff --git a/refurb/checks/itertools/use_starmap.py b/refurb/checks/itertools/use_starmap.py index 5914beb..60203ea 100644 --- a/refurb/checks/itertools/use_starmap.py +++ b/refurb/checks/itertools/use_starmap.py @@ -1,6 +1,14 @@ from dataclasses import dataclass -from mypy.nodes import ArgKind, CallExpr, GeneratorExpr, NameExpr, TupleExpr +from mypy.nodes import ( + ArgKind, + CallExpr, + GeneratorExpr, + ListComprehension, + NameExpr, + SetComprehension, + TupleExpr, +) from refurb.error import Error @@ -49,7 +57,15 @@ def passed_test(score: int, passing_score: int) -> bool: categories = ("itertools", "performance") -def check(node: GeneratorExpr, errors: list[Error]) -> None: +ignore = set[int]() + + +def check_generator( + node: GeneratorExpr, + errors: list[Error], + old_wrapper: str = "{}", + new_wrapper: str = "{}", +) -> None: match node: case GeneratorExpr( left_expr=CallExpr(args=args, arg_kinds=arg_kinds), @@ -67,4 +83,42 @@ def check(node: GeneratorExpr, errors: list[Error]) -> None: ): return - errors.append(ErrorInfo.from_node(node)) + ignore.add(id(node)) + + old = "f(...) for ... in x" + old = old_wrapper.format(old) + + new = "starmap(f, x)" + new = new_wrapper.format(new) + + msg = f"Replace `{old}` with `{new}`" + + errors.append(ErrorInfo.from_node(node, msg)) + + +def check( + node: GeneratorExpr | ListComprehension | SetComprehension, + errors: list[Error], +) -> None: + if id(node) in ignore: + return + + match node: + case GeneratorExpr(): + check_generator(node, errors) + + case ListComprehension(generator=g): + check_generator( + g, + errors, + old_wrapper="[{}]", + new_wrapper="list({})", + ) + + case SetComprehension(generator=g): + check_generator( + g, + errors, + old_wrapper="{{{}}}", + new_wrapper="set({})", + ) diff --git a/refurb/visitor/visitor.py b/refurb/visitor/visitor.py index 1265505..9091208 100644 --- a/refurb/visitor/visitor.py +++ b/refurb/visitor/visitor.py @@ -15,11 +15,11 @@ def build_visitor(name: str, ty: type[Node], checks: Checks) -> VisitorMethod: def inner(self: RefurbVisitor, o: Node) -> None: - getattr(TraverserVisitor, name)(self, o) - for check in checks[ty]: self.run_check(o, check) + getattr(TraverserVisitor, name)(self, o) + inner.__name__ = name inner.__annotations__["o"] = ty return inner diff --git a/test/data/err_140.txt b/test/data/err_140.txt index 55dd721..3a1b7d9 100644 --- a/test/data/err_140.txt +++ b/test/data/err_140.txt @@ -1,3 +1,3 @@ -test/data/err_140.py:7:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)` +test/data/err_140.py:7:1 [FURB140]: Replace `[f(...) for ... in x]` with `list(starmap(f, x))` test/data/err_140.py:9:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)` -test/data/err_140.py:11:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)` +test/data/err_140.py:11:1 [FURB140]: Replace `{f(...) for ... in x}` with `set(starmap(f, x))` diff --git a/test/data/err_179.txt b/test/data/err_179.txt index 2074853..ad67478 100644 --- a/test/data/err_179.txt +++ b/test/data/err_179.txt @@ -1,5 +1,5 @@ test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` -test/data/err_179.py:16:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` +test/data/err_179.py:16:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))` test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)`