Skip to content

Commit

Permalink
Some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Nov 14, 2023
1 parent 2f11ea9 commit 0189e06
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 104 deletions.
70 changes: 34 additions & 36 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,40 @@ class BaseType(Enum):
STRING = 3
BYTES = 4

@classmethod
# pylint: disable-next=R0911
def infer_expression_type(cls, node: cst.BaseExpression) -> Optional["BaseType"]:
"""
Tries to infer if the type of a given expression is one of the base literal types.
"""
# The current implementation could be enhanced with a few more cases
match node:
case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call(
func=cst.Name("int")
) | cst.Call(func=cst.Name("float")) | cst.Call(
func=cst.Name("abs")
) | cst.Call(
func=cst.Name("len")
):
return BaseType.NUMBER
case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp():
return BaseType.LIST
case cst.Call(func=cst.Name("str")) | cst.FormattedString():
return BaseType.STRING
case cst.SimpleString():
if "b" in node.prefix.lower():
return BaseType.BYTES
return BaseType.STRING
case cst.ConcatenatedString():
return cls.infer_expression_type(node.left)
case cst.BinaryOperation(operator=cst.Add()):
return cls.infer_expression_type(
node.left
) or cls.infer_expression_type(node.right)
case cst.IfExp():
if_true = cls.infer_expression_type(node.body)
or_else = cls.infer_expression_type(node.orelse)
if if_true == or_else:
return if_true
return None

# pylint: disable-next=R0911
def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]:
"""
Tries to infer if the resulting type of a given expression is one of the base literal types.
"""
# The current implementation covers some common cases and is in no way complete
match node:
case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call(
func=cst.Name("int")
) | cst.Call(func=cst.Name("float")) | cst.Call(
func=cst.Name("abs")
) | cst.Call(
func=cst.Name("len")
):
return BaseType.NUMBER
case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp():
return BaseType.LIST
case cst.Call(func=cst.Name("str")) | cst.FormattedString():
return BaseType.STRING
case cst.SimpleString():
if "b" in node.prefix.lower():
return BaseType.BYTES
return BaseType.STRING
case cst.ConcatenatedString():
return infer_expression_type(node.left)
case cst.BinaryOperation(operator=cst.Add()):
return infer_expression_type(node.left) or infer_expression_type(node.right)
case cst.IfExp():
if_true = infer_expression_type(node.body)
or_else = infer_expression_type(node.orelse)
if if_true == or_else:
return if_true
return None


class SequenceExtension:
Expand Down
119 changes: 60 additions & 59 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from typing import Any, Optional, Tuple
import itertools
import libcst as cst
from libcst import FormattedString, SimpleWhitespace, ensure_type, matchers
from libcst import (
FormattedString,
SimpleString,
SimpleWhitespace,
ensure_type,
matchers,
)
from libcst.codemod import (
Codemod,
CodemodContext,
Expand Down Expand Up @@ -32,6 +38,7 @@
BaseType,
ReplaceNodes,
get_function_name_node,
infer_expression_type,
)
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext
Expand Down Expand Up @@ -113,6 +120,13 @@ def _build_param_element(self, prepend, middle, append):
)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
# The transformation has four major steps:
# (1) FindQueryCalls - Find and gather all the SQL query execution calls. The result is a dict of call nodes and their associated list of nodes composing the query (i.e. step (2)).
# (2) LinearizeQuery - For each call, it gather all the string literals and expressions that composes the query. The result is a list of nodes whose concatenation is the query.
# (3) ExtractParameters - Detects which expressions are part of SQL string literals in the query. The result is a list of triples (a,b,c) such that a is the node that contains the start of the string literal, b is a list of expressions that composes that literal, and c is the node containing the end of the string literal. At least one node in b must be "injectable" (see).
# (4) SQLQueryParameterization - Executes steps (1)-(3) and gather a list of injection triples. For each triple (a,b,c) it makes the associated changes to insert the query parameter token. All the expressions in b are then concatenated in an expression and passed as a sequence of parameters to the execute call.

# Steps (1) and (2)
find_queries = FindQueryCalls(self.context)
tree.visit(find_queries)

Expand All @@ -123,10 +137,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
if not self.filter_by_path_includes_or_excludes(call_pos):
break

# Step (3)
ep = ExtractParameters(self.context, query)
tree.visit(ep)

# build tuple elements and fix injection
# Step (4) - build tuple elements and fix injection
params_elements: list[cst.Element] = []
for start, middle, end in ep.injection_patterns:
prepend, append = self._fix_injection(start, middle, end)
Expand Down Expand Up @@ -173,59 +188,29 @@ def _fix_injection(
else:
self.changed_nodes[expr] = cst.parse_expression('""')

prepend = append = None

# remove quote literal from start
current_start = self.changed_nodes.get(start) or start
prepend_raw_value = None
updated_start = self.changed_nodes.get(start) or start

t = _extract_prefix_raw_value(self, current_start)
t = _extract_prefix_raw_value(self, updated_start)
prefix, raw_value = t if t else ("", "")

# gather string after the quote
if "r" in prefix:
quote_span = list(raw_quote_pattern.finditer(raw_value))[-1]
else:
quote_span = list(quote_pattern.finditer(raw_value))[-1]

new_raw_value = raw_value[: quote_span.start()] + parameter_token
prepend_raw_value = raw_value[quote_span.end() :]

match current_start:
case cst.SimpleString():
# uses the same quote and prefixes to guarantee it will be correctly interpreted
if prepend_raw_value:
prepend = cst.SimpleString(
value=current_start.prefix
+ current_start.quote
+ prepend_raw_value
+ current_start.quote
)

new_value = (
current_start.prefix
+ current_start.quote
+ new_raw_value
+ current_start.quote
)
self.changed_nodes[start] = current_start.with_changes(value=new_value)

case cst.FormattedStringText():
if prepend_raw_value:
prepend = cst.SimpleString(
value=("r" if "r" in prefix else "")
+ "'"
+ prepend_raw_value
+ "'"
)

new_value = new_raw_value
self.changed_nodes[start] = current_start.with_changes(value=new_value)
prepend = self._remove_literal_and_gather_extra(
start, updated_start, prefix, new_raw_value, prepend_raw_value
)

# remove quote literal from end
current_end = self.changed_nodes.get(end) or end
append_raw_value = None
updated_end = self.changed_nodes.get(end) or end

t = _extract_prefix_raw_value(self, current_end)
t = _extract_prefix_raw_value(self, updated_end)
prefix, raw_value = t if t else ("", "")
if "r" in prefix:
quote_span = list(raw_quote_pattern.finditer(raw_value))[0]
Expand All @@ -234,36 +219,52 @@ def _fix_injection(

new_raw_value = raw_value[quote_span.end() :]
append_raw_value = raw_value[: quote_span.start()]
match current_end:

append = self._remove_literal_and_gather_extra(
end, updated_end, prefix, new_raw_value, append_raw_value
)

return (prepend, append)

# pylint: disable-next=too-many-arguments
def _remove_literal_and_gather_extra(
self, original_node, updated_node, prefix, new_raw_value, extra_raw_value
) -> Optional[SimpleString]:
extra = None
match updated_node:
case cst.SimpleString():
# gather string up to quote to parameter
if append_raw_value:
append = cst.SimpleString(
value=current_end.prefix
+ current_end.quote
+ append_raw_value
+ current_end.quote
# gather string after or before the quote
if extra_raw_value:
extra = cst.SimpleString(
value=updated_node.prefix
+ updated_node.quote
+ extra_raw_value
+ updated_node.quote
)

new_value = (
current_end.prefix
+ current_end.quote
updated_node.prefix
+ updated_node.quote
+ new_raw_value
+ current_end.quote
+ updated_node.quote
)
self.changed_nodes[original_node] = updated_node.with_changes(
value=new_value
)
self.changed_nodes[end] = current_end.with_changes(value=new_value)
case cst.FormattedStringText():
if append_raw_value:
append = cst.SimpleString(
if extra_raw_value:
extra = cst.SimpleString(
value=("r" if "r" in prefix else "")
+ "'"
+ append_raw_value
+ extra_raw_value
+ "'"
)

new_value = new_raw_value
self.changed_nodes[end] = current_end.with_changes(value=new_value)
return (prepend, append)
self.changed_nodes[original_node] = updated_node.with_changes(
value=new_value
)
return extra


class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin):
Expand Down Expand Up @@ -303,7 +304,7 @@ def on_visit(self, node: cst.CSTNode):
return False

def visit_BinaryOperation(self, node: cst.BinaryOperation) -> Optional[bool]:
maybe_type = BaseType.infer_expression_type(node)
maybe_type = infer_expression_type(node)
if not maybe_type or maybe_type == BaseType.STRING:
return True
self.leaves.append(node)
Expand Down Expand Up @@ -417,7 +418,7 @@ def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool:
return quote_pattern.fullmatch(raw_value) is None

def _is_injectable(self, expression: cst.BaseExpression) -> bool:
return not bool(BaseType.infer_expression_type(expression))
return not bool(infer_expression_type(expression))

def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool:
t = _extract_prefix_raw_value(self, node)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_basetype.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
import libcst as cst
from codemodder.codemods.utils import BaseType
from codemodder.codemods.utils import BaseType, infer_expression_type


class TestBaseType:
def test_binary_op_number(self):
e = cst.parse_expression("1 + float(2)")
assert BaseType.infer_expression_type(e) == BaseType.NUMBER
assert infer_expression_type(e) == BaseType.NUMBER

def test_binary_op_string_mixed(self):
e = cst.parse_expression('f"a"+foo()')
assert BaseType.infer_expression_type(e) == BaseType.STRING
assert infer_expression_type(e) == BaseType.STRING

def test_binary_op_list(self):
e = cst.parse_expression("[1,2] + [x for x in [3]] + list((4,5))")
assert BaseType.infer_expression_type(e) == BaseType.LIST
assert infer_expression_type(e) == BaseType.LIST

def test_binary_op_none(self):
e = cst.parse_expression("foo() + bar()")
assert BaseType.infer_expression_type(e) == None
assert infer_expression_type(e) == None

def test_bytes(self):
e = cst.parse_expression('b"123"')
assert BaseType.infer_expression_type(e) == BaseType.BYTES
assert infer_expression_type(e) == BaseType.BYTES

def test_if_mixed(self):
e = cst.parse_expression('1 if True else "a"')
assert BaseType.infer_expression_type(e) == None
assert infer_expression_type(e) == None

def test_if_numbers(self):
e = cst.parse_expression("abs(1) if True else 2")
assert BaseType.infer_expression_type(e) == BaseType.NUMBER
assert infer_expression_type(e) == BaseType.NUMBER

def test_if_numbers2(self):
e = cst.parse_expression("float(1) if True else len([1,2])")
assert BaseType.infer_expression_type(e) == BaseType.NUMBER
assert infer_expression_type(e) == BaseType.NUMBER

0 comments on commit 0189e06

Please sign in to comment.