Skip to content

Commit

Permalink
Add better detection of list/set comprehensions (#300):
Browse files Browse the repository at this point in the history
See #298 (comment)

This PR adds better detection of list/set comprehensions. Previously Refurb
checks ran depth first, meaning that the leaf nodes where the first nodes to
be hit. Now the checks are ran before traversing, meaning checks will run on
the root nodes first. This allows checks more flexibility, though for this
PR, it allows them to ignore nodes that they have already seen which prevents
multiple errors being emitted in certain circumstances.

Also bump version for new release.
  • Loading branch information
dosisod authored Oct 21, 2023
1 parent 5eb602b commit 229e7ea
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "refurb"
version = "1.22.0"
version = "1.22.1"
description = "A tool for refurbish and modernize Python codebases"
authors = ["dosisod"]
license = "GPL-3.0-only"
Expand Down
32 changes: 31 additions & 1 deletion refurb/checks/itertools/use_chain_from_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
ArgKind,
CallExpr,
GeneratorExpr,
ListComprehension,
ListExpr,
NameExpr,
RefExpr,
Expand Down Expand Up @@ -60,7 +61,7 @@ class ErrorInfo(Error):
code = 179


def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None:
def is_flatten_generator(node: GeneratorExpr) -> bool:
match node:
case GeneratorExpr(
left_expr=RefExpr(fullname=expr),
Expand All @@ -69,6 +70,35 @@ def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None:
is_async=[False, False],
condlists=[[], []],
) if expr == inner and inner_source == outer:
return True

return False


# List of nodes we have already emitted errors for, since list comprehensions
# and their inner generators will emit 2 errors.
ignore = set[int]()


def check(
node: ListComprehension | GeneratorExpr | CallExpr,
errors: list[Error],
) -> None:
if id(node) in ignore:
return

match node:
case ListComprehension(generator=g) if is_flatten_generator(g):
old = "[... for ... in x for ... in ...]"
new = "list(chain.from_iterable(x))"

msg = f"Replace `{old}` with `{new}`"

errors.append(ErrorInfo.from_node(node, msg))

ignore.add(id(g))

case GeneratorExpr() if is_flatten_generator(node):
old = "... for ... in x for ... in ..."
new = "chain.from_iterable(x)"

Expand Down
60 changes: 57 additions & 3 deletions refurb/checks/itertools/use_starmap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from dataclasses import dataclass

from mypy.nodes import ArgKind, CallExpr, GeneratorExpr, NameExpr, TupleExpr
from mypy.nodes import (
ArgKind,
CallExpr,
GeneratorExpr,
ListComprehension,
NameExpr,
SetComprehension,
TupleExpr,
)

from refurb.error import Error

Expand Down Expand Up @@ -49,7 +57,15 @@ def passed_test(score: int, passing_score: int) -> bool:
categories = ("itertools", "performance")


def check(node: GeneratorExpr, errors: list[Error]) -> None:
ignore = set[int]()


def check_generator(
node: GeneratorExpr,
errors: list[Error],
old_wrapper: str = "{}",
new_wrapper: str = "{}",
) -> None:
match node:
case GeneratorExpr(
left_expr=CallExpr(args=args, arg_kinds=arg_kinds),
Expand All @@ -67,4 +83,42 @@ def check(node: GeneratorExpr, errors: list[Error]) -> None:
):
return

errors.append(ErrorInfo.from_node(node))
ignore.add(id(node))

old = "f(...) for ... in x"
old = old_wrapper.format(old)

new = "starmap(f, x)"
new = new_wrapper.format(new)

msg = f"Replace `{old}` with `{new}`"

errors.append(ErrorInfo.from_node(node, msg))


def check(
node: GeneratorExpr | ListComprehension | SetComprehension,
errors: list[Error],
) -> None:
if id(node) in ignore:
return

match node:
case GeneratorExpr():
check_generator(node, errors)

case ListComprehension(generator=g):
check_generator(
g,
errors,
old_wrapper="[{}]",
new_wrapper="list({})",
)

case SetComprehension(generator=g):
check_generator(
g,
errors,
old_wrapper="{{{}}}",
new_wrapper="set({})",
)
4 changes: 2 additions & 2 deletions refurb/visitor/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

def build_visitor(name: str, ty: type[Node], checks: Checks) -> VisitorMethod:
def inner(self: RefurbVisitor, o: Node) -> None:
getattr(TraverserVisitor, name)(self, o)

for check in checks[ty]:
self.run_check(o, check)

getattr(TraverserVisitor, name)(self, o)

inner.__name__ = name
inner.__annotations__["o"] = ty
return inner
Expand Down
4 changes: 2 additions & 2 deletions test/data/err_140.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
test/data/err_140.py:7:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)`
test/data/err_140.py:7:1 [FURB140]: Replace `[f(...) for ... in x]` with `list(starmap(f, x))`
test/data/err_140.py:9:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)`
test/data/err_140.py:11:1 [FURB140]: Replace `f(...) for ... in x` with `starmap(f, x)`
test/data/err_140.py:11:1 [FURB140]: Replace `{f(...) for ... in x}` with `set(starmap(f, x))`
2 changes: 1 addition & 1 deletion test/data/err_179.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:16:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:16:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))`
test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)`
Expand Down

0 comments on commit 229e7ea

Please sign in to comment.