diff --git a/pyproject.toml b/pyproject.toml index bb6977d..cf7cdad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "refurb" -version = "1.22.0" +version = "1.22.1" description = "A tool for refurbish and modernize Python codebases" authors = ["dosisod"] license = "GPL-3.0-only" diff --git a/refurb/checks/itertools/use_chain_from_iterable.py b/refurb/checks/itertools/use_chain_from_iterable.py index 6fb6a5d..0952a31 100644 --- a/refurb/checks/itertools/use_chain_from_iterable.py +++ b/refurb/checks/itertools/use_chain_from_iterable.py @@ -4,6 +4,7 @@ ArgKind, CallExpr, GeneratorExpr, + ListComprehension, ListExpr, NameExpr, RefExpr, @@ -60,7 +61,7 @@ class ErrorInfo(Error): code = 179 -def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None: +def is_flatten_generator(node: GeneratorExpr) -> bool: match node: case GeneratorExpr( left_expr=RefExpr(fullname=expr), @@ -69,6 +70,35 @@ def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None: is_async=[False, False], condlists=[[], []], ) if expr == inner and inner_source == outer: + return True + + return False + + +# List of nodes we have already emitted errors for, since list comprehensions +# and their inner generators will emit 2 errors. +ignore = set[int]() + + +def check( + node: ListComprehension | GeneratorExpr | CallExpr, + errors: list[Error], +) -> None: + if id(node) in ignore: + return + + match node: + case ListComprehension(generator=g) if is_flatten_generator(g): + old = "[... for ... in x for ... in ...]" + new = "list(chain.from_iterable(x))" + + msg = f"Replace `{old}` with `{new}`" + + errors.append(ErrorInfo.from_node(node, msg)) + + ignore.add(id(g)) + + case GeneratorExpr() if is_flatten_generator(node): old = "... for ... in x for ... in ..." new = "chain.from_iterable(x)" diff --git a/refurb/checks/itertools/use_starmap.py b/refurb/checks/itertools/use_starmap.py index 5914beb..60203ea 100644 --- a/refurb/checks/itertools/use_starmap.py +++ b/refurb/checks/itertools/use_starmap.py @@ -1,6 +1,14 @@ from dataclasses import dataclass -from mypy.nodes import ArgKind, CallExpr, GeneratorExpr, NameExpr, TupleExpr +from mypy.nodes import ( + ArgKind, + CallExpr, + GeneratorExpr, + ListComprehension, + NameExpr, + SetComprehension, + TupleExpr, +) from refurb.error import Error @@ -49,7 +57,15 @@ def passed_test(score: int, passing_score: int) -> bool: categories = ("itertools", "performance") -def check(node: GeneratorExpr, errors: list[Error]) -> None: +ignore = set[int]() + + +def check_generator( + node: GeneratorExpr, + errors: list[Error], + old_wrapper: str = "{}", + new_wrapper: str = "{}", +) -> None: match node: case GeneratorExpr( left_expr=CallExpr(args=args, arg_kinds=arg_kinds), @@ -67,4 +83,42 @@ def check(node: GeneratorExpr, errors: list[Error]) -> None: ): return - errors.append(ErrorInfo.from_node(node)) + ignore.add(id(node)) + + old = "f(...) for ... in x" + old = old_wrapper.format(old) + + new = "starmap(f, x)" + new = new_wrapper.format(new) + + msg = f"Replace `{old}` with `{new}`" + + errors.append(ErrorInfo.from_node(node, msg)) + + +def check( + node: GeneratorExpr | ListComprehension | SetComprehension, + errors: list[Error], +) -> None: + if id(node) in ignore: + return + + match node: + case GeneratorExpr(): + check_generator(node, errors) + + case ListComprehension(generator=g): + check_generator( + g, + errors, + old_wrapper="[{}]", + new_wrapper="list({})", + ) + + case SetComprehension(generator=g): + check_generator( + g, + errors, + old_wrapper="{{{}}}", + new_wrapper="set({})", + ) diff --git a/refurb/visitor/visitor.py b/refurb/visitor/visitor.py index 1265505..9091208 100644 --- a/refurb/visitor/visitor.py +++ b/refurb/visitor/visitor.py @@ -15,11 +15,11 @@ def build_visitor(name: str, ty: type[Node], checks: Checks) -> VisitorMethod: def inner(self: RefurbVisitor, o: Node) -> None: - getattr(TraverserVisitor, name)(self, o) - for check in checks[ty]: self.run_check(o, check) + getattr(TraverserVisitor, name)(self, o) + inner.__name__ = name inner.__annotations__["o"] = ty return inner diff --git a/test/data/err_140.txt b/test/data/err_140.txt index 55dd721..3a1b7d9 100644 --- a/test/data/err_140.txt +++ b/test/data/err_140.txt @@ -1,3 +1,3 @@ -test/data/err_140.py:7:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)` +test/data/err_140.py:7:1 [FURB140]: Replace `[f(...) for ... in x]` with `list(starmap(f, x))` test/data/err_140.py:9:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)` -test/data/err_140.py:11:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)` +test/data/err_140.py:11:1 [FURB140]: Replace `{f(...) for ... in x}` with `set(starmap(f, x))` diff --git a/test/data/err_179.txt b/test/data/err_179.txt index 2074853..ad67478 100644 --- a/test/data/err_179.txt +++ b/test/data/err_179.txt @@ -1,5 +1,5 @@ test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` -test/data/err_179.py:16:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` +test/data/err_179.py:16:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))` test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)`