Skip to content

Commit

Permalink
do not convert cases with different variables or functions
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Jan 3, 2024
1 parent af92775 commit dd2f3e3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
27 changes: 20 additions & 7 deletions src/core_codemods/combine_startswith_endswith.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class CombineStartswithEndswith(BaseCodemod, NameResolutionMixin):
def leave_BooleanOperation(
self, original_node: cst.BooleanOperation, updated_node: cst.BooleanOperation
) -> cst.CSTNode:
if not self.filter_by_path_includes_or_excludes(
self.node_position(original_node)
):
return original_node

if self.matches_startswith_endswith_or_pattern(original_node):
left_call = cst.ensure_type(updated_node.left, cst.Call)
right_call = cst.ensure_type(updated_node.right, cst.Call)
Expand All @@ -38,12 +43,20 @@ def matches_startswith_endswith_or_pattern(
) -> bool:
# Match the pattern: x.startswith("...") or x.startswith("...")
# and the same but with endswith
call = m.Call(
func=m.Attribute(
value=m.Name(), attr=m.Name("startswith") | m.Name("endswith")
),
args=[m.Arg(value=m.SimpleString())],
args = [m.Arg(value=m.SimpleString())]
startswith = m.Call(
func=m.Attribute(value=m.Name(), attr=m.Name("startswith")),
args=args,
)
endswith = m.Call(
func=m.Attribute(value=m.Name(), attr=m.Name("endswith")),
args=args,
)
startswith_or = m.BooleanOperation(
left=startswith, operator=m.Or(), right=startswith
)
return m.matches(
node, m.BooleanOperation(left=call, operator=m.Or(), right=call)
endswith_or = m.BooleanOperation(left=endswith, operator=m.Or(), right=endswith)
return (
m.matches(node, startswith_or | endswith_or)
and node.left.func.value == node.right.func.value
)
5 changes: 3 additions & 2 deletions tests/codemods/test_combine_startswith_endswith.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from tests.codemods.base_codemod_test import BaseCodemodTest
from core_codemods.combine_startswith_endswith import CombineStartswithEndswith
from textwrap import dedent

each_func = pytest.mark.parametrize("func", ["startswith", "endswith"])

Expand All @@ -22,7 +21,7 @@ def test_combine(self, tmpdir, func):
x = "foo"
x.{func}(("foo", "f"))
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

@pytest.mark.parametrize(
Expand All @@ -32,6 +31,8 @@ def test_combine(self, tmpdir, func):
"x.startswith(('f', 'foo'))",
"x.startswith('foo') and x.startswith('f')",
"x.startswith('foo') and x.startswith('f') or True",
"x.startswith('foo') or x.endswith('f')",
"x.startswith('foo') or y.startswith('f')",
],
)
def test_no_change(self, tmpdir, code):
Expand Down

0 comments on commit dd2f3e3

Please sign in to comment.