Skip to content

Commit

Permalink
Implement use-walrus-if without semgrep
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Dec 7, 2023
1 parent 966f49d commit fec0705
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 73 deletions.
160 changes: 91 additions & 69 deletions src/core_codemods/use_walrus_if.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import namedtuple
import itertools
from typing import List, Tuple, Optional

import libcst as cst
Expand All @@ -6,11 +8,20 @@
from libcst.metadata import ParentNodeProvider, ScopeProvider

from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api import BaseCodemod


class UseWalrusIf(SemgrepCodemod):
METADATA_DEPENDENCIES = SemgrepCodemod.METADATA_DEPENDENCIES + (
FoundAssign = namedtuple("FoundAssign", ["assign", "target", "value"])


def pairwise(iterable):
a, b = itertools.tee(iterable)
next(b, None)
return zip(a, b)


class UseWalrusIf(BaseCodemod):
METADATA_DEPENDENCIES = BaseCodemod.METADATA_DEPENDENCIES + (
ParentNodeProvider,
ScopeProvider,
)
Expand All @@ -27,52 +38,89 @@ class UseWalrusIf(SemgrepCodemod):
}
]

@classmethod
def rule(cls):
return """
rules:
- patterns:
- pattern: |
$ASSIGN
if $COND:
$BODY
- metavariable-pattern:
metavariable: $ASSIGN
patterns:
- pattern: $VAR = $RHS
- metavariable-pattern:
metavariable: $COND
patterns:
- pattern: $VAR
- metavariable-pattern:
metavariable: $BODY
pattern: $VAR
- focus-metavariable: $ASSIGN
"""

_modify_next_if: List[Tuple[CodeRange, cst.Assign]]
_if_stack: List[Optional[Tuple[CodeRange, cst.Assign]]]
_modify_next_if: List[Tuple[CodeRange, cst.NamedExpr]]
_if_stack: List[Optional[Tuple[CodeRange, cst.NamedExpr]]]
assigns: dict[cst.Assign, cst.NamedExpr]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modify_next_if = []
self._if_stack = []
self.assigns = {}

def _build_named_expr(self, target, value, parens=True):
return cst.NamedExpr(
target=target,
value=value,
lpar=[cst.LeftParen()] if parens else [],
rpar=[cst.RightParen()] if parens else [],
)

def visit_If(self, node):
def _filter_assigns(self, node: cst.CSTNode) -> FoundAssign | None:
match node:
case cst.SimpleStatementLine(
body=[
cst.Assign(
targets=[
cst.AssignTarget(target=cst.Name() as target),
],
value=value,
) as assign
]
):
return FoundAssign(assign, target, value)
return None

def _filter_if(self, node: cst.CSTNode) -> cst.BaseExpression | None:
match node:
case cst.If(test=test):
return test
return None

def on_visit(self, node: cst.CSTNode) -> Optional[bool]:
if len(node.children) < 2:
return super().on_visit(node)

for a, b in pairwise(node.children):
if not (found_assign := self._filter_assigns(a)):
continue
if not (if_test := self._filter_if(b)):
continue

assign, target, value = found_assign
# If test can be a comparison expression
match if_test:
case cst.Comparison(
left=cst.Name() as left,
comparisons=[
cst.ComparisonTarget(
operator=(
cst.Is() | cst.IsNot() | cst.Equal() | cst.NotEqual()
)
)
],
):
if left.value == target.value:
named_expr = self._build_named_expr(target, value, parens=True)
self.assigns[assign] = named_expr
# If test can also be a bare name
case cst.Name() as name:
if name.value == target.value:
named_expr = self._build_named_expr(target, value, parens=False)
self.assigns[assign] = named_expr

return super().on_visit(node)

def visit_If(self, node: cst.If):
del node
self._if_stack.append(
self._modify_next_if.pop() if len(self._modify_next_if) else None
)

def leave_If(self, original_node, updated_node):
if (result := self._if_stack.pop()) is not None:
position, if_node = result
position, named_expr = result
is_name = m.matches(updated_node.test, m.Name())
named_expr = cst.NamedExpr(
target=if_node.targets[0].target,
value=if_node.value,
lpar=[] if is_name else [cst.LeftParen()],
rpar=[] if is_name else [cst.RightParen()],
)
self.add_change_from_position(position, self.CHANGE_DESCRIPTION)
return (
updated_node.with_changes(test=named_expr)
Expand All @@ -84,38 +132,12 @@ def leave_If(self, original_node, updated_node):

return original_node

def _is_valid_modification(self, node):
"""
Restricts the kind of modifications we can make to the AST.
This is necessary since the semgrep rule can't fully encode this restriction.
"""
if parent := self.get_metadata(ParentNodeProvider, node):
if gparent := self.get_metadata(ParentNodeProvider, parent):
if (idx := gparent.children.index(parent)) >= 0:
conditional = gparent.children[idx + 1]
match conditional:
case cst.If(test=(cst.Name())):
return True
case cst.If(test=cst.Comparison(left=cst.Name()) as test):
match test.comparisons[0]:
case cst.ComparisonTarget(
operator=(
cst.Is()
| cst.IsNot()
| cst.Equal()
| cst.NotEqual()
)
):
return True
return False

def leave_Assign(self, original_node, updated_node):
if self.node_is_selected(original_node):
if self._is_valid_modification(original_node):
position = self.node_position(original_node)
self._modify_next_if.append((position, updated_node))
return cst.RemoveFromParent()
def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign):
del updated_node
if named_expr := self.assigns.get(original_node):
position = self.node_position(original_node)
self._modify_next_if.append((position, named_expr))
return cst.RemoveFromParent()

return original_node

Expand Down Expand Up @@ -147,7 +169,7 @@ def leave_SimpleStatementLine(self, original_node, updated_node):
# state management to fit within the visitor pattern. We should
# revisit this at some point later.
return cst.FlattenSentinel(
original_node.leading_lines + trailing_whitespace
tuple(original_node.leading_lines) + trailing_whitespace
)

return updated_node
8 changes: 4 additions & 4 deletions tests/codemods/test_walrus_if.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest

from core_codemods.use_walrus_if import UseWalrusIf
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest
from tests.codemods.base_codemod_test import BaseCodemodTest


class TestUseWalrusIf(BaseSemgrepCodemodTest):
class TestUseWalrusIf(BaseCodemodTest):
codemod = UseWalrusIf

@pytest.mark.parametrize(
Expand Down Expand Up @@ -168,15 +168,15 @@ def test_dont_apply_walrus_expr(self, tmpdir):
def test_walrus_with_comparison(self, tmpdir, comparator):
input_code = f"""
def func(y):
x = foo()
y = bar(y)
x = foo()
if x {comparator} y:
print("whatever", y)
"""
expected_output = f"""
def func(y):
x = foo()
y = bar(y)
x = foo()
if x {comparator} y:
print("whatever", y)
"""
Expand Down

0 comments on commit fec0705

Please sign in to comment.