From 4c55b12463fda31612e947221a65a34f1b8cac89 Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Fri, 17 Nov 2023 16:05:25 -0500 Subject: [PATCH] Do not modify body of abstractmethods for fix-mutable-params --- src/core_codemods/fix_mutable_params.py | 14 +++++++++++++- tests/codemods/test_fix_mutable_params.py | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index d95bcb3b..a1ebf898 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -140,6 +140,14 @@ def _build_new_body(self, new_var_decls, body): new_body.extend(body[offset:]) return new_body + def _is_abstractmethod(self, node: cst.FunctionDef) -> bool: + for decorator in node.decorators: + match decorator.decorator: + case cst.Name("abstractmethod"): + return True + + return False + def leave_FunctionDef( self, original_node: cst.FunctionDef, @@ -151,7 +159,11 @@ def leave_FunctionDef( new_var_decls, add_annotation, ) = self._gather_and_update_params(original_node, updated_node) - new_body = self._build_new_body(new_var_decls, updated_node.body.body) + new_body = ( + self._build_new_body(new_var_decls, updated_node.body.body) + if not self._is_abstractmethod(original_node) + else updated_node.body.body + ) if new_var_decls: # If we're adding statements to the body, we know a change took place self.add_change(original_node, self.CHANGE_DESCRIPTION) diff --git a/tests/codemods/test_fix_mutable_params.py b/tests/codemods/test_fix_mutable_params.py index 8854570e..8145bd3f 100644 --- a/tests/codemods/test_fix_mutable_params.py +++ b/tests/codemods/test_fix_mutable_params.py @@ -237,3 +237,22 @@ def func(foo=None): print(foo) """ self.run_and_assert(tmpdir, input_code, expected_output) + + def test_dont_modify_abstractmethod_body(self, tmpdir): + input_code = """ + from abc import ABC, abstractmethod + + class Foo(ABC): + @abstractmethod + def foo(self, bar=[]): + pass + """ + expected_output = """ + from abc import ABC, abstractmethod + + class Foo(ABC): + @abstractmethod + def foo(self, bar=None): + pass + """ + self.run_and_assert(tmpdir, input_code, expected_output)