diff --git a/src/core_codemods/combine_startswith_endswith.py b/src/core_codemods/combine_startswith_endswith.py index 070497bc..2dc42d51 100644 --- a/src/core_codemods/combine_startswith_endswith.py +++ b/src/core_codemods/combine_startswith_endswith.py @@ -17,7 +17,7 @@ def leave_BooleanOperation( if not self.filter_by_path_includes_or_excludes( self.node_position(original_node) ): - return original_node + return updated_node if self.matches_startswith_endswith_or_pattern(original_node): left_call = cst.ensure_type(updated_node.left, cst.Call) diff --git a/src/core_codemods/django_receiver_on_top.py b/src/core_codemods/django_receiver_on_top.py index 6ebeab0a..5d65bf32 100644 --- a/src/core_codemods/django_receiver_on_top.py +++ b/src/core_codemods/django_receiver_on_top.py @@ -23,6 +23,8 @@ def leave_FunctionDef( ) -> Union[ cst.BaseStatement, cst.FlattenSentinel[cst.BaseStatement], cst.RemovalSentinel ]: + # TODO: add filter by include or exclude that works for nodes + # that that have different start/end numbers. maybe_receiver_with_index = None for i, decorator in enumerate(original_node.decorators): true_name = self.find_base_name(decorator.decorator) diff --git a/src/core_codemods/exception_without_raise.py b/src/core_codemods/exception_without_raise.py index 3a796a32..c293d0d6 100644 --- a/src/core_codemods/exception_without_raise.py +++ b/src/core_codemods/exception_without_raise.py @@ -27,6 +27,11 @@ def leave_SimpleStatementLine( ) -> Union[ cst.BaseStatement, cst.FlattenSentinel[cst.BaseStatement], cst.RemovalSentinel ]: + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return updated_node + match original_node: case cst.SimpleStatementLine( body=[cst.Expr(cst.Name() | cst.Attribute() as name)] diff --git a/src/core_codemods/fix_deprecated_abstractproperty.py b/src/core_codemods/fix_deprecated_abstractproperty.py index 3ed64a11..c4c0dd38 100644 --- a/src/core_codemods/fix_deprecated_abstractproperty.py +++ b/src/core_codemods/fix_deprecated_abstractproperty.py @@ -19,6 +19,11 @@ class FixDeprecatedAbstractproperty(BaseCodemod, NameResolutionMixin): def leave_Decorator( self, original_node: cst.Decorator, updated_node: cst.Decorator ): + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return updated_node + if ( base_name := self.find_base_name(original_node.decorator) ) and base_name == "abc.abstractproperty": diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index a1ebf898..ae991131 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -154,6 +154,8 @@ def leave_FunctionDef( updated_node: cst.FunctionDef, ): """Transforms function definitions with mutable default parameters""" + # TODO: add filter by include or exclude that works for nodes + # that that have different start/end numbers. ( updated_params, new_var_decls, diff --git a/src/core_codemods/remove_debug_breakpoint.py b/src/core_codemods/remove_debug_breakpoint.py index 7a35846f..27e9eeae 100644 --- a/src/core_codemods/remove_debug_breakpoint.py +++ b/src/core_codemods/remove_debug_breakpoint.py @@ -12,12 +12,14 @@ class RemoveDebugBreakpoint(BaseCodemod, NameResolutionMixin, AncestorPatternsMi REFERENCES: list = [] def leave_Expr( - self, original_node: cst.Expr, _ + self, + original_node: cst.Expr, + updated_node: cst.Expr, ) -> Union[cst.Expr, cst.RemovalSentinel]: if not self.filter_by_path_includes_or_excludes( self.node_position(original_node) ): - return original_node + return updated_node match call_node := original_node.value: case cst.Call(): @@ -29,4 +31,4 @@ def leave_Expr( self.report_change(original_node) return cst.RemovalSentinel.REMOVE - return original_node + return updated_node diff --git a/src/core_codemods/remove_module_global.py b/src/core_codemods/remove_module_global.py index 1b404093..b0f1d28b 100644 --- a/src/core_codemods/remove_module_global.py +++ b/src/core_codemods/remove_module_global.py @@ -13,12 +13,14 @@ class RemoveModuleGlobal(BaseCodemod, NameResolutionMixin): REFERENCES: list = [] def leave_Global( - self, original_node: cst.Global, _ + self, + original_node: cst.Global, + updated_node: cst.Global, ) -> Union[cst.Global, cst.RemovalSentinel,]: if not self.filter_by_path_includes_or_excludes( self.node_position(original_node) ): - return original_node + return updated_node scope = self.get_metadata(ScopeProvider, original_node) if isinstance(scope, GlobalScope): self.report_change(original_node) diff --git a/src/core_codemods/remove_unnecessary_f_str.py b/src/core_codemods/remove_unnecessary_f_str.py index c07e8246..d9998a7b 100644 --- a/src/core_codemods/remove_unnecessary_f_str.py +++ b/src/core_codemods/remove_unnecessary_f_str.py @@ -34,6 +34,11 @@ def _check_formatted_string( _original_node: cst.FormattedString, updated_node: cst.FormattedString, ): + if not self.filter_by_path_includes_or_excludes( + self.node_position(_original_node) + ): + return updated_node + transformed_node = super()._check_formatted_string(_original_node, updated_node) if not _original_node.deep_equals(transformed_node): self.report_change(_original_node) diff --git a/src/core_codemods/subprocess_shell_false.py b/src/core_codemods/subprocess_shell_false.py index 4286d5ab..536d651b 100644 --- a/src/core_codemods/subprocess_shell_false.py +++ b/src/core_codemods/subprocess_shell_false.py @@ -33,7 +33,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): if not self.filter_by_path_includes_or_excludes( self.node_position(original_node) ): - return original_node + return updated_node if self.find_base_name(original_node.func) in self.SUBPROCESS_FUNCS: for arg in original_node.args: diff --git a/src/core_codemods/use_generator.py b/src/core_codemods/use_generator.py index 0a507db4..6b80e196 100644 --- a/src/core_codemods/use_generator.py +++ b/src/core_codemods/use_generator.py @@ -25,6 +25,11 @@ class UseGenerator(BaseCodemod, NameResolutionMixin): ] def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return updated_node + match original_node.func: # NOTE: could also support things like `list` and `tuple` # but it's a less compelling use case diff --git a/src/core_codemods/use_walrus_if.py b/src/core_codemods/use_walrus_if.py index a66988fb..e2866df5 100644 --- a/src/core_codemods/use_walrus_if.py +++ b/src/core_codemods/use_walrus_if.py @@ -118,6 +118,8 @@ def visit_If(self, node: cst.If): ) def leave_If(self, original_node, updated_node): + # TODO: add filter by include or exclude that works for nodes + # that that have different start/end numbers. if (result := self._if_stack.pop()) is not None: position, named_expr = result is_name = m.matches(updated_node.test, m.Name()) diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 221aaa47..74aa3e06 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -33,6 +33,32 @@ def run_and_assert(self, tmpdir, input_code, expected): tmp_file_path = Path(tmpdir / "code.py") self.run_and_assert_filepath(tmpdir, tmp_file_path, input_code, expected) + def assert_no_change_line_excluded( + self, tmpdir, input_code, expected, lines_to_exclude + ): + tmp_file_path = Path(tmpdir / "code.py") + input_tree = cst.parse_module(dedent(input_code)) + self.execution_context = CodemodExecutionContext( + directory=tmpdir, + dry_run=True, + verbose=False, + registry=mock.MagicMock(), + repo_manager=mock.MagicMock(), + ) + + self.file_context = FileContext( + tmpdir, + tmp_file_path, + lines_to_exclude, + [], + [], + ) + codemod_instance = self.initialize_codemod(input_tree) + output_tree = codemod_instance.transform_module(input_tree) + + assert output_tree.code == dedent(expected) + assert len(self.file_context.codemod_changes) == 0 + def run_and_assert_filepath(self, root, file_path, input_code, expected): input_tree = cst.parse_module(dedent(input_code)) self.execution_context = CodemodExecutionContext( diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py index 19305f9d..6ce21923 100644 --- a/tests/codemods/test_combine_startswith_endswith.py +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -38,3 +38,13 @@ def test_combine(self, tmpdir, func): def test_no_change(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) assert len(self.file_context.codemod_changes) == 0 + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + x = "foo" + x.startswith("foo") or x.startswith("f") + """ + lines_to_exclude = [2] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + ) diff --git a/tests/codemods/test_exception_without_raise.py b/tests/codemods/test_exception_without_raise.py index 6ec37df9..dd1f8b3a 100644 --- a/tests/codemods/test_exception_without_raise.py +++ b/tests/codemods/test_exception_without_raise.py @@ -54,3 +54,13 @@ def test_raised_exception(self, tmpdir): """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) assert len(self.file_context.codemod_changes) == 0 + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + print(1) + ValueError("Bad value!") + """ + lines_to_exclude = [2] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + ) diff --git a/tests/codemods/test_fix_deprecated_abstractproperty.py b/tests/codemods/test_fix_deprecated_abstractproperty.py index 71150586..c4e7db1b 100644 --- a/tests/codemods/test_fix_deprecated_abstractproperty.py +++ b/tests/codemods/test_fix_deprecated_abstractproperty.py @@ -121,3 +121,17 @@ def foo(self): pass """ self.run_and_assert(tmpdir, original_code, new_code) + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + import abc + + class A: + @abc.abstractproperty + def foo(self): + pass + """ + lines_to_exclude = [4] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + ) diff --git a/tests/codemods/test_remove_debug_breakpoint.py b/tests/codemods/test_remove_debug_breakpoint.py index 27e10b83..f384c1c2 100644 --- a/tests/codemods/test_remove_debug_breakpoint.py +++ b/tests/codemods/test_remove_debug_breakpoint.py @@ -85,3 +85,13 @@ def something(): """ self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + x = "foo" + breakpoint() + """ + lines_to_exclude = [2] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + ) diff --git a/tests/codemods/test_remove_unnecessary_f_str.py b/tests/codemods/test_remove_unnecessary_f_str.py index 35c49c32..6e70eede 100644 --- a/tests/codemods/test_remove_unnecessary_f_str.py +++ b/tests/codemods/test_remove_unnecessary_f_str.py @@ -32,3 +32,12 @@ def test_change(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) assert len(self.file_context.codemod_changes) == 3 + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + bad: str = f"bad" + "bad" + """ + lines_to_exclude = [1] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + ) diff --git a/tests/codemods/test_subprocess_shell_false.py b/tests/codemods/test_subprocess_shell_false.py index b19c9513..4b178a48 100644 --- a/tests/codemods/test_subprocess_shell_false.py +++ b/tests/codemods/test_subprocess_shell_false.py @@ -56,3 +56,13 @@ def test_shell_False(self, tmpdir, func): """ self.run_and_assert(tmpdir, input_code, input_code) assert len(self.file_context.codemod_changes) == 0 + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + import subprocess + subprocess.run(args, shell=True) + """ + lines_to_exclude = [2] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + ) diff --git a/tests/codemods/test_use_generator.py b/tests/codemods/test_use_generator.py index b8cacf2e..709c76af 100644 --- a/tests/codemods/test_use_generator.py +++ b/tests/codemods/test_use_generator.py @@ -29,3 +29,12 @@ def test_not_global_function(self, tmpdir): x = any([i for i in range(10)]) """ self.run_and_assert(tmpdir, original_code, expected) + + def test_exclude_line(self, tmpdir): + input_code = expected = """\ + x = any([i for i in range(10)]) + """ + lines_to_exclude = [1] + self.assert_no_change_line_excluded( + tmpdir, input_code, expected, lines_to_exclude + )