Skip to content

Commit

Permalink
Fixed a bug in pragma parsing for RemoveUnusedImports
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Oct 26, 2023
1 parent 4663b59 commit aa6d0f5
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 27 deletions.
4 changes: 3 additions & 1 deletion src/codemodder/project_analysis/python_repo_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ def package_stores(self) -> list[PackageStore]:
def _parse_all_stores(self) -> list[PackageStore]:
discovered_pkg_stores: list[PackageStore] = []
for store in self._potential_stores:
discovered_pkg_stores.extend(store(self.parent_directory).parse())
discovered_pkg_stores.extend(
store(self.parent_directory).parse() # type: ignore
)
return discovered_pkg_stores
74 changes: 48 additions & 26 deletions src/core_codemods/remove_unused_imports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from libcst import ensure_type, matchers
from libcst import CSTVisitor, ensure_type, matchers
from libcst.codemod.visitors import GatherUnusedImportsVisitor
from libcst.metadata import (
PositionProvider,
Expand All @@ -20,7 +20,7 @@
import re
from pylint.utils.pragma_parser import parse_pragma

NOQA_PATTERN = re.compile(r"^#\s*noqa")
NOQA_PATTERN = re.compile(r"^#\s*noqa", re.IGNORECASE)


class RemoveUnusedImports(BaseCodemod, Codemod):
Expand Down Expand Up @@ -78,18 +78,23 @@ def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool:
if parent and matchers.matches(parent, matchers.SimpleStatementLine()):
stmt = ensure_type(parent, cst.SimpleStatementLine)

# has a trailing comment string
trailing_comment_string = (
stmt.trailing_whitespace.comment.value
if stmt.trailing_whitespace.comment
else None
)
if trailing_comment_string and NOQA_PATTERN.match(trailing_comment_string):
return True
if trailing_comment_string and _is_pylint_disable_unused_imports(
trailing_comment_string
):
return True
# has a trailing comment string anywhere in the node
comments_visitor = GatherCommentNodes()
stmt.body[0].visit(comments_visitor)
# has a trailing comment string anywhere in the node
if stmt.trailing_whitespace.comment:
comments_visitor.comments.append(stmt.trailing_whitespace.comment)

for comment in comments_visitor.comments:
trailing_comment_string = comment.value
if trailing_comment_string and NOQA_PATTERN.match(
trailing_comment_string
):
return True
if trailing_comment_string and _is_pylint_disable_unused_imports(
trailing_comment_string
):
return True

# has a comment right above it
if matchers.matches(
Expand All @@ -111,25 +116,42 @@ def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool:
return False


class GatherCommentNodes(CSTVisitor):
def __init__(self) -> None:
self.comments: list[cst.Comment] = []
super().__init__()

def leave_Comment(self, original_node: cst.Comment) -> None:
self.comments.append(original_node)


def match_line(pos, line):
return pos.start.line == line and pos.end.line == line


def _is_pylint_disable_unused_imports(comment: str) -> bool:
parsed = parse_pragma(comment)
for p in parsed:
if p.action == "disable" and (
"unused-import" in p.messages or "W0611" in p.messages
):
return True
# If pragma parse fails, ignore
try:
parsed = parse_pragma(comment)
for p in parsed:
if p.action == "disable" and (
"unused-import" in p.messages or "W0611" in p.messages
):
return True
except Exception:
pass
return False


def _is_pylint_disable_next_unused_imports(comment: str) -> bool:
parsed = parse_pragma(comment)
for p in parsed:
if p.action == "disable-next" and (
"unused-import" in p.messages or "W0611" in p.messages
):
return True
# If pragma parse fails, ignore
try:
parsed = parse_pragma(comment)
for p in parsed:
if p.action == "disable-next" and (
"unused-import" in p.messages or "W0611" in p.messages
):
return True
except Exception:
pass
return False
12 changes: 12 additions & 0 deletions tests/codemods/test_remove_unused_imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from core_codemods.remove_unused_imports import RemoveUnusedImports
from tests.codemods.base_codemod_test import BaseCodemodTest
from textwrap import dedent


class TestRemoveUnusedImports(BaseCodemodTest):
Expand Down Expand Up @@ -90,6 +91,17 @@ def test_dont_remove_if_noqa_trailing(self, tmpdir):
self.run_and_assert(tmpdir, before, before)
assert len(self.file_context.codemod_changes) == 0

def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir):
before = dedent(
"""\
from _pytest.assertion.util import ( # noqa: F401
format_explanation as _format_explanation,
)"""
)

self.run_and_assert(tmpdir, before, before)
assert len(self.file_context.codemod_changes) == 0

def test_dont_remove_if_pylint_disable(self, tmpdir):
before = "import a\nimport b # pylint: disable=W0611\na()"
self.run_and_assert(tmpdir, before, before)
Expand Down

0 comments on commit aa6d0f5

Please sign in to comment.