From f79d674f3e5abf595b83394b4144840a4babd2e6 Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Fri, 17 Nov 2023 15:08:08 -0500 Subject: [PATCH] Preserve Optional type annotation if already present --- src/core_codemods/fix_mutable_params.py | 11 +++++++++-- tests/codemods/test_fix_mutable_params.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index f6ab3dd4..d95bcb3b 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -30,18 +30,25 @@ def __init__(self, *args, **kwargs): self._matches_builtin = m.Call(func=m.Name("list") | m.Name("dict")) def _create_annotation(self, orig: cst.Param, updated: cst.Param): + match orig.annotation: + case cst.Annotation(annotation=cst.Subscript(sub)): + match sub: # type: ignore + case cst.Name("Optional"): + # Already an Optional, so we can just preserve the original annotation + return updated.annotation + return ( updated.annotation.with_changes( annotation=cst.Subscript( value=cst.Name("Optional"), slice=[ cst.SubscriptElement( - slice=cst.Index(value=orig.annotation.annotation) + slice=cst.Index(value=updated.annotation.annotation) ) ], ) ) - if updated.annotation is not None + if orig.annotation is not None and updated.annotation is not None else None ) diff --git a/tests/codemods/test_fix_mutable_params.py b/tests/codemods/test_fix_mutable_params.py index 267b82b9..8854570e 100644 --- a/tests/codemods/test_fix_mutable_params.py +++ b/tests/codemods/test_fix_mutable_params.py @@ -194,6 +194,22 @@ def foo(x = None, y: Optional[List[int]] = None, z: Optional[Dict[str, int]] = N """ self.run_and_assert(tmpdir, input_code, expected_output) + def test_fix_type_already_optional(self, tmpdir): + input_code = """ + from typing import Optional, List + + def foo(x: Optional[List[int]] = []): + print(x) + """ + expected_output = """ + from typing import Optional, List + + def foo(x: Optional[List[int]] = None): + x = [] if x is None else x + print(x) + """ + self.run_and_assert(tmpdir, input_code, expected_output) + def test_fix_respect_docstring(self, tmpdir): input_code = ''' def func(foo=[]):