Skip to content

Commit

Permalink
Refactored RemoveEmptyStringsConcatenation
Browse files Browse the repository at this point in the history
Added separated tests
  • Loading branch information
andrecsilva committed Oct 23, 2023
1 parent d394942 commit c60b9b1
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import libcst as cst
from libcst import CSTTransformer


class RemoveEmptyStringConcatenation(CSTTransformer):
    """
    Removes concatenation with empty strings, e.g. "hello " + "" or "hello" "".

    Only ``+`` binary operations and implicit (adjacent literal) concatenations
    are rewritten; any other binary operator is left untouched. Empty
    f-strings are a different node type and are not recognised yet.
    """

    def leave_BinaryOperation(
        self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
    ) -> cst.BaseExpression:
        """Simplify ``a + b`` when one of the operands is an empty string literal."""
        return self.handle_node(updated_node)

    def leave_ConcatenatedString(
        self,
        original_node: cst.ConcatenatedString,
        updated_node: cst.ConcatenatedString,
    ) -> cst.BaseExpression:
        """Simplify ``"a" "b"`` when one of the literals is empty."""
        return self.handle_node(updated_node)

    def handle_node(
        self, updated_node: cst.BinaryOperation | cst.ConcatenatedString
    ) -> cst.BaseExpression:
        """
        Return the node that should replace ``updated_node``.

        - If exactly one side is an empty ``SimpleString``, the other side is
          returned.
        - If both sides are empty, a fresh empty literal is returned (``b""``
          when the left operand is a bytes literal, ``""`` otherwise).
        - Otherwise ``updated_node`` is returned unchanged.

        Parentheses wrapping ``updated_node`` are carried over to the result.
        """
        # BinaryOperation also models -, *, %, ...; dropping an empty operand
        # is only meaningful for +, so leave every other operator alone.
        if isinstance(updated_node, cst.BinaryOperation) and not isinstance(
            updated_node.operator, cst.Add
        ):
            return updated_node

        left, right = updated_node.left, updated_node.right
        # TODO f-string cases
        left_empty = self._is_empty_string(left)
        right_empty = self._is_empty_string(right)

        kept: cst.BaseExpression
        if left_empty and right_empty:
            # Keep bytes as bytes: b"" + b"" must not turn into the str "".
            literal = 'b""' if "b" in left.prefix.lower() else '""'
            kept = cst.SimpleString(value=literal)
        elif left_empty:
            kept = right
        elif right_empty:
            kept = left
        else:
            return updated_node
        return self._adopt_parentheses(kept, updated_node)

    @staticmethod
    def _is_empty_string(node: cst.CSTNode) -> bool:
        """Tell whether ``node`` is a plain string literal with no content."""
        return isinstance(node, cst.SimpleString) and node.raw_value == ""

    @staticmethod
    def _adopt_parentheses(
        kept: cst.BaseExpression,
        replaced: cst.BinaryOperation | cst.ConcatenatedString,
    ) -> cst.BaseExpression:
        """
        Wrap ``kept`` in the parentheses that surrounded ``replaced``.

        Without this, ``("" + a * 2)[0]`` would become ``a * 2[0]``.
        """
        if not replaced.lpar and not replaced.rpar:
            return kept
        # lpar is ordered outermost-first and rpar innermost-first.
        return kept.with_changes(
            lpar=[*replaced.lpar, *kept.lpar],
            rpar=[*kept.rpar, *replaced.rpar],
        )
58 changes: 17 additions & 41 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, Optional, Tuple
import libcst as cst
from libcst import CSTTransformer, SimpleWhitespace, matchers
from libcst import SimpleWhitespace, matchers
from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor
from libcst.metadata import (
ClassScope,
Expand All @@ -17,6 +17,9 @@
CodemodMetadata,
ReviewGuidance,
)
from codemodder.codemods.transformations.remove_empty_string_concatenation import (
RemoveEmptyStringConcatenation,
)
from codemodder.codemods.utils_mixin import NameResolutionMixin

parameter_token = "?"
Expand All @@ -39,7 +42,7 @@ def __init__(self, sequence: list[cst.CSTNode]) -> None:
class ReplaceNodes(cst.CSTTransformer):
"""
Replace nodes with their corresponding values in a given dict.
You can replace the entire node, some attributes of it via a dict(). Addionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, repectivelly a sequence.
You can replace the entire node or some attributes of it via a dict(). Additionally, if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively.
"""

def __init__(
Expand Down Expand Up @@ -220,6 +223,10 @@ def _fix_injection(


class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin):
"""
Gather all the expressions that are concatenated to build the query.
"""

METADATA_DEPENDENCIES = (ParentNodeProvider,)

def __init__(self, context) -> None:
Expand Down Expand Up @@ -307,6 +314,10 @@ def _find_gparent(self, n: cst.CSTNode) -> Optional[cst.CSTNode]:


class ExtractParameters(ContextAwareVisitor):
"""
Detects injections and gathers the expressions that are injectable.
"""

METADATA_DEPENDENCIES = (
ScopeProvider,
ParentNodeProvider,
Expand Down Expand Up @@ -428,6 +439,10 @@ def _is_literal_end(self, node: cst.CSTNode) -> bool:


class FindQueryCalls(ContextAwareVisitor):
"""
Find all the execute calls with a sql query as an argument.
"""

# right now it works by looking into some sql keywords in any pieces of the query
# Ideally we should infer what driver we are using
sql_keywords: list[str] = ["insert", "select", "delete", "create", "alter", "drop"]
Expand Down Expand Up @@ -467,42 +482,3 @@ def _get_function_name_node(call: cst.Call) -> Optional[cst.Name]:
case cst.Attribute():
return call.func.attr
return None


class RemoveEmptyStringConcatenation(CSTTransformer):
    """
    Removes concatenation with empty strings, e.g. "hello " + "" or "hello" "".

    Only ``+`` binary operations and implicit (adjacent literal) concatenations
    are rewritten; any other binary operator is left untouched.
    """

    # TODO What about empty f-strings? they are a different type of node
    # may not be necessary if handled correctly
    def leave_BinaryOperation(
        self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
    ) -> cst.BaseExpression:
        """Simplify ``a + b`` when one of the operands is an empty string literal."""
        return self.handle_node(updated_node)

    def leave_ConcatenatedString(
        self,
        original_node: cst.ConcatenatedString,
        updated_node: cst.ConcatenatedString,
    ) -> cst.BaseExpression:
        """Simplify ``"a" "b"`` when one of the literals is empty."""
        return self.handle_node(updated_node)

    def handle_node(
        self, updated_node: cst.BinaryOperation | cst.ConcatenatedString
    ) -> cst.BaseExpression:
        """
        Return the node that should replace ``updated_node``.

        - If exactly one side is an empty ``SimpleString``, the other side is
          returned.
        - If both sides are empty, a fresh empty literal is returned (``b""``
          when the left operand is a bytes literal, ``""`` otherwise).
        - Otherwise ``updated_node`` is returned unchanged.

        Parentheses wrapping ``updated_node`` are carried over to the result.
        """
        # BinaryOperation also models -, *, %, ...; dropping an empty operand
        # is only meaningful for +, so leave every other operator alone.
        if isinstance(updated_node, cst.BinaryOperation) and not isinstance(
            updated_node.operator, cst.Add
        ):
            return updated_node

        left, right = updated_node.left, updated_node.right
        left_empty = isinstance(left, cst.SimpleString) and left.raw_value == ""
        right_empty = isinstance(right, cst.SimpleString) and right.raw_value == ""

        kept: cst.BaseExpression
        if left_empty and right_empty:
            # Keep bytes as bytes: b"" + b"" must not turn into the str "".
            literal = 'b""' if "b" in left.prefix.lower() else '""'
            kept = cst.SimpleString(value=literal)
        elif left_empty:
            kept = right
        elif right_empty:
            kept = left
        else:
            return updated_node

        if not updated_node.lpar and not updated_node.rpar:
            return kept
        # Re-wrap the survivor in the parentheses of the node it replaces, so
        # ("" + a * 2)[0] does not become a * 2[0]. lpar is ordered
        # outermost-first and rpar innermost-first.
        return kept.with_changes(
            lpar=[*updated_node.lpar, *kept.lpar],
            rpar=[*kept.rpar, *updated_node.rpar],
        )
33 changes: 33 additions & 0 deletions tests/codemods/test_sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,41 @@ def test_multiple_escaped_quote(self, tmpdir):
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_simple_concatenated_strings(self, tmpdir):
    """A query built from adjacent literals plus ``+`` is parameterized once."""
    vulnerable = dedent(
        """\
        import sqlite3
        name = input()
        connection = sqlite3.connect("my_db.db")
        cursor = connection.cursor()
        cursor.execute("SELECT * from USERS" "WHERE name ='" + name + "'")
        """
    )
    parameterized = dedent(
        """\
        import sqlite3
        name = input()
        connection = sqlite3.connect("my_db.db")
        cursor = connection.cursor()
        cursor.execute("SELECT * from USERS" "WHERE name =?", (name, ))
        """
    )
    self.run_and_assert(tmpdir, vulnerable, parameterized)
    change_count = len(self.file_context.codemod_changes)
    assert change_count == 1

# negative tests below

def test_formatted_string_simple(self, tmpdir):
    """An f-string query is currently expected to be left untouched."""
    # TODO change when we add support for it
    unchanged = dedent(
        """\
        import sqlite3
        name = input()
        connection = sqlite3.connect("my_db.db")
        cursor = connection.cursor()
        cursor.execute(f"SELECT * from USERS WHERE name='{name}'")
        """
    )
    # Input and expected output are the same text: no rewrite should happen.
    self.run_and_assert(tmpdir, unchanged, unchanged)
    change_count = len(self.file_context.codemod_changes)
    assert change_count == 0

def test_no_sql_keyword(self, tmpdir):
input_code = """\
import sqlite3
Expand Down
82 changes: 82 additions & 0 deletions tests/transformations/test_remove_empty_string_concatenation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import libcst as cst
from libcst.codemod import Codemod, CodemodTest

from codemodder.codemods.transformations.remove_empty_string_concatenation import (
RemoveEmptyStringConcatenation,
)


class RemoveEmptyStringConcatenationCodemod(Codemod):
    """Codemod wrapper that runs RemoveEmptyStringConcatenation over a module."""

    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """Visit ``tree`` with a fresh transformer and return the resulting module."""
        transformer = RemoveEmptyStringConcatenation()
        rewritten = tree.visit(transformer)
        return rewritten


class TestRemoveEmptyStringConcatenation(CodemodTest):
    """Unit tests for RemoveEmptyStringConcatenation, run through a thin codemod wrapper."""

    TRANSFORM = RemoveEmptyStringConcatenationCodemod

    def test_left(self):
        """An empty left operand of ``+`` is dropped."""
        before = """
        "" + "world"
        """

        after = """
        "world"
        """

        self.assertCodemod(before, after)

    def test_right(self):
        """An empty right operand of ``+`` is dropped."""
        before = """
        "hello" + ""
        """

        after = """
        "hello"
        """

        self.assertCodemod(before, after)

    def test_both(self):
        """Two empty operands collapse into a single empty literal."""
        before = """
        "" + ""
        """

        after = """
        ""
        """

        self.assertCodemod(before, after)

    def test_concatenated_string_right(self):
        """An empty trailing literal in an implicit concatenation is dropped."""
        before = """
        "hello" ""
        """

        after = """
        "hello"
        """

        self.assertCodemod(before, after)

    def test_concatenated_string_left(self):
        """An empty leading literal in an implicit concatenation is dropped."""
        # The input must contain the empty left literal; previously it was
        # just "world", so the transformation was never exercised.
        before = """
        "" "world"
        """

        after = """
        "world"
        """

        self.assertCodemod(before, after)

    def test_multiple_mixed(self):
        """Mixed quote styles, ``+`` and implicit concatenation all collapse."""
        # Assemble the fixture from pieces so that the empty triple-quoted
        # literal really ends up in the code under test. Writing it inside
        # adjacent triple-quoted Python literals silently drops it.
        pieces = ['""', "+", "''", '""""""', "+", "r''''''"]
        before = " ".join(pieces)

        after = '""'

        self.assertCodemod(before, after)

0 comments on commit c60b9b1

Please sign in to comment.