Skip to content

Commit

Permalink
Restrict walrus assignment operators
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Dec 7, 2023
1 parent be75dce commit 966f49d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/core_codemods/use_walrus_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,21 @@ def _is_valid_modification(self, node):
if parent := self.get_metadata(ParentNodeProvider, node):
if gparent := self.get_metadata(ParentNodeProvider, parent):
if (idx := gparent.children.index(parent)) >= 0:
return m.matches(
gparent.children[idx + 1],
m.If(test=(m.Name() | m.Comparison(left=m.Name()))),
)
conditional = gparent.children[idx + 1]
match conditional:
case cst.If(test=(cst.Name())):
return True
case cst.If(test=cst.Comparison(left=cst.Name()) as test):
match test.comparisons[0]:
case cst.ComparisonTarget(
operator=(
cst.Is()
| cst.IsNot()
| cst.Equal()
| cst.NotEqual()
)
):
return True
return False

def leave_Assign(self, original_node, updated_node):
Expand Down
18 changes: 18 additions & 0 deletions tests/codemods/test_walrus_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,21 @@ def test_dont_apply_walrus_expr(self, tmpdir):
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

@pytest.mark.parametrize("comparator", [">", "<", ">=", "<="])
def test_walrus_with_comparison(self, tmpdir, comparator):
input_code = f"""
def func(y):
x = foo()
y = bar(y)
if x {comparator} y:
print("whatever", y)
"""
expected_output = f"""
def func(y):
x = foo()
y = bar(y)
if x {comparator} y:
print("whatever", y)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

0 comments on commit 966f49d

Please sign in to comment.