diff --git a/src/codemodder/project_analysis/python_repo_manager.py b/src/codemodder/project_analysis/python_repo_manager.py index c73e4656..36332dfd 100644 --- a/src/codemodder/project_analysis/python_repo_manager.py +++ b/src/codemodder/project_analysis/python_repo_manager.py @@ -26,5 +26,7 @@ def package_stores(self) -> list[PackageStore]: def _parse_all_stores(self) -> list[PackageStore]: discovered_pkg_stores: list[PackageStore] = [] for store in self._potential_stores: - discovered_pkg_stores.extend(store(self.parent_directory).parse()) + discovered_pkg_stores.extend( + store(self.parent_directory).parse() # type: ignore + ) return discovered_pkg_stores diff --git a/src/core_codemods/remove_unused_imports.py b/src/core_codemods/remove_unused_imports.py index 6d3c43c1..2fb25300 100644 --- a/src/core_codemods/remove_unused_imports.py +++ b/src/core_codemods/remove_unused_imports.py @@ -1,4 +1,4 @@ -from libcst import ensure_type, matchers +from libcst import CSTVisitor, ensure_type, matchers from libcst.codemod.visitors import GatherUnusedImportsVisitor from libcst.metadata import ( PositionProvider, @@ -20,7 +20,7 @@ import re from pylint.utils.pragma_parser import parse_pragma -NOQA_PATTERN = re.compile(r"^#\s*noqa") +NOQA_PATTERN = re.compile(r"^#\s*noqa", re.IGNORECASE) class RemoveUnusedImports(BaseCodemod, Codemod): @@ -78,18 +78,23 @@ def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool: if parent and matchers.matches(parent, matchers.SimpleStatementLine()): stmt = ensure_type(parent, cst.SimpleStatementLine) - # has a trailing comment string - trailing_comment_string = ( - stmt.trailing_whitespace.comment.value - if stmt.trailing_whitespace.comment - else None - ) - if trailing_comment_string and NOQA_PATTERN.match(trailing_comment_string): - return True - if trailing_comment_string and _is_pylint_disable_unused_imports( - trailing_comment_string - ): - return True + # has a trailing comment string anywhere in the node + comments_visitor = GatherCommentNodes() + stmt.body[0].visit(comments_visitor) + # has a trailing comment string anywhere in the node + if stmt.trailing_whitespace.comment: + comments_visitor.comments.append(stmt.trailing_whitespace.comment) + + for comment in comments_visitor.comments: + trailing_comment_string = comment.value + if trailing_comment_string and NOQA_PATTERN.match( + trailing_comment_string + ): + return True + if trailing_comment_string and _is_pylint_disable_unused_imports( + trailing_comment_string + ): + return True # has a comment right above it if matchers.matches( @@ -111,25 +116,42 @@ def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool: return False +class GatherCommentNodes(CSTVisitor): + def __init__(self) -> None: + self.comments: list[cst.Comment] = [] + super().__init__() + + def leave_Comment(self, original_node: cst.Comment) -> None: + self.comments.append(original_node) + + def match_line(pos, line): return pos.start.line == line and pos.end.line == line def _is_pylint_disable_unused_imports(comment: str) -> bool: - parsed = parse_pragma(comment) - for p in parsed: - if p.action == "disable" and ( - "unused-import" in p.messages or "W0611" in p.messages - ): - return True + # If pragma parse fails, ignore + try: + parsed = parse_pragma(comment) + for p in parsed: + if p.action == "disable" and ( + "unused-import" in p.messages or "W0611" in p.messages + ): + return True + except Exception: + pass return False def _is_pylint_disable_next_unused_imports(comment: str) -> bool: - parsed = parse_pragma(comment) - for p in parsed: - if p.action == "disable-next" and ( - "unused-import" in p.messages or "W0611" in p.messages - ): - return True + # If pragma parse fails, ignore + try: + parsed = parse_pragma(comment) + for p in parsed: + if p.action == "disable-next" and ( + "unused-import" in p.messages or "W0611" in p.messages + ): + return True + except Exception: + pass return False diff --git a/tests/codemods/test_remove_unused_imports.py b/tests/codemods/test_remove_unused_imports.py index 43ee5e55..1498a251 100644 --- a/tests/codemods/test_remove_unused_imports.py +++ b/tests/codemods/test_remove_unused_imports.py @@ -1,5 +1,6 @@ from core_codemods.remove_unused_imports import RemoveUnusedImports from tests.codemods.base_codemod_test import BaseCodemodTest +from textwrap import dedent class TestRemoveUnusedImports(BaseCodemodTest): @@ -90,6 +91,17 @@ def test_dont_remove_if_noqa_trailing(self, tmpdir): self.run_and_assert(tmpdir, before, before) assert len(self.file_context.codemod_changes) == 0 + def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir): + before = dedent( + """\ + from _pytest.assertion.util import ( # noqa: F401 + format_explanation as _format_explanation, + )""" + ) + + self.run_and_assert(tmpdir, before, before) + assert len(self.file_context.codemod_changes) == 0 + def test_dont_remove_if_pylint_disable(self, tmpdir): before = "import a\nimport b # pylint: disable=W0611\na()" self.run_and_assert(tmpdir, before, before)