diff --git a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py index 58bda6b0..fbec19a9 100644 --- a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py +++ b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py @@ -1,5 +1,6 @@ +from typing import Union import libcst as cst -from libcst import CSTTransformer +from libcst import CSTTransformer, RemovalSentinel, SimpleString class RemoveEmptyStringConcatenation(CSTTransformer): @@ -7,6 +8,21 @@ class RemoveEmptyStringConcatenation(CSTTransformer): Removes concatenation with empty strings (e.g. "hello " + "") or "hello" "" """ + def leave_FormattedStringExpression( + self, + original_node: cst.FormattedStringExpression, + updated_node: cst.FormattedStringExpression, + ) -> Union[ + cst.BaseFormattedStringContent, + cst.FlattenSentinel[cst.BaseFormattedStringContent], + RemovalSentinel, + ]: + expr = original_node.expression + match expr: + case SimpleString() if expr.raw_value == "": # type: ignore + return RemovalSentinel.REMOVE + return updated_node + def leave_BinaryOperation( self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation ) -> cst.BaseExpression: @@ -22,19 +38,22 @@ def leave_ConcatenatedString( def handle_node( self, updated_node: cst.BinaryOperation | cst.ConcatenatedString ) -> cst.BaseExpression: - match updated_node.left: - # TODO f-string cases - case cst.SimpleString() if updated_node.left.raw_value == "": - match updated_node.right: - case cst.SimpleString() if updated_node.right.raw_value == "": - return cst.SimpleString(value='""') - case _: - return updated_node.right - match updated_node.right: - case cst.SimpleString() if updated_node.right.raw_value == "": - match updated_node.left: - case cst.SimpleString() if updated_node.left.raw_value == "": - return cst.SimpleString(value='""') - case _: - return updated_node.left + left = updated_node.left + right = updated_node.right + if self._is_empty_string_literal(left): + if self._is_empty_string_literal(right): + return cst.SimpleString(value='""') + return right + if self._is_empty_string_literal(right): + if self._is_empty_string_literal(left): + return cst.SimpleString(value='""') + return left return updated_node + + def _is_empty_string_literal(self, node): + match node: + case cst.SimpleString() if node.raw_value == "": + return True + case cst.FormattedString() if not node.parts: + return True + return False diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index a4a9c0ab..1b29298c 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -1,3 +1,4 @@ +from enum import Enum from pathlib import Path from typing import Optional, Any @@ -5,6 +6,52 @@ import libcst as cst +class BaseType(Enum): + """ + An enumeration representing the base literal types in Python. + """ + + NUMBER = 1 + LIST = 2 + STRING = 3 + BYTES = 4 + + +# pylint: disable-next=R0911 +def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]: + """ + Tries to infer if the resulting type of a given expression is one of the base literal types. + """ + # The current implementation covers some common cases and is in no way complete + match node: + case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( + func=cst.Name("int") + ) | cst.Call(func=cst.Name("float")) | cst.Call( + func=cst.Name("abs") + ) | cst.Call( + func=cst.Name("len") + ): + return BaseType.NUMBER + case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): + return BaseType.LIST + case cst.Call(func=cst.Name("str")) | cst.FormattedString(): + return BaseType.STRING + case cst.SimpleString(): + if "b" in node.prefix.lower(): + return BaseType.BYTES + return BaseType.STRING + case cst.ConcatenatedString(): + return infer_expression_type(node.left) + case cst.BinaryOperation(operator=cst.Add()): + return infer_expression_type(node.left) or infer_expression_type(node.right) + case cst.IfExp(): + if_true = infer_expression_type(node.body) + or_else = infer_expression_type(node.orelse) + if if_true == or_else: + return if_true + return None + + class SequenceExtension: def __init__(self, sequence: list[cst.CSTNode]) -> None: self.sequence = sequence @@ -27,13 +74,7 @@ def __init__( self, replacements: dict[ cst.CSTNode, - dict[ - str, - cst.CSTNode - | cst.FlattenSentinel - | cst.RemovalSentinel - | dict[str, Any], - ], + cst.CSTNode | cst.FlattenSentinel | cst.RemovalSentinel | dict[str, Any], ], ): self.replacements = replacements diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 54be3829..3619111f 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,8 +1,20 @@ import re -from typing import Optional, Tuple +from typing import Any, Optional, Tuple +import itertools import libcst as cst -from libcst import SimpleWhitespace, matchers -from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor +from libcst import ( + FormattedString, + SimpleString, + SimpleWhitespace, + ensure_type, + matchers, +) +from libcst.codemod import ( + Codemod, + CodemodContext, + ContextAwareTransformer, + ContextAwareVisitor, +) from libcst.metadata import ( ClassScope, GlobalScope, @@ -21,15 +33,20 @@ from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) -from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node +from codemodder.codemods.utils import ( + Append, + BaseType, + ReplaceNodes, + get_function_name_node, + infer_expression_type, +) from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext parameter_token = "?" -literal_number = matchers.Integer() | matchers.Float() | matchers.Imaginary() -literal_string = matchers.SimpleString() -literal = literal_number | literal_string +quote_pattern = re.compile(r"(? None: self.changed_nodes: dict[ - cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel + cst.CSTNode, + cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel | dict[str, Any], ] = {} BaseCodemod.__init__(self, file_context, *codemod_args) UtilsMixin.__init__(self, []) Codemod.__init__(self, context) - def _build_param_element(self, middle, index: int) -> cst.BaseExpression: - # TODO maybe a parameterized string would be better here - # f-strings need python 3.6 though - if index == 0: - return middle[0] - operator = cst.Add( - whitespace_after=cst.SimpleWhitespace(" "), - whitespace_before=cst.SimpleWhitespace(" "), + def _build_param_element(self, prepend, middle, append): + new_middle = ( + ([prepend] if prepend else []) + middle + ([append] if append else []) ) - return cst.BinaryOperation( - operator=operator, - left=self._build_param_element(middle, index - 1), - right=middle[index], + format_pieces: list[str] = [] + format_expr_count = 0 + args = [] + if len(new_middle) == 1: + # TODO maybe handle conversion here? + return new_middle[0] + for e in new_middle: + exception = False + if isinstance(e, cst.SimpleString | cst.FormattedStringText): + t = _extract_prefix_raw_value(self, e) + if t: + prefix, raw_value = t + if not "b" in prefix and not "r" in prefix and not "u" in prefix: + format_pieces.append(raw_value) + exception = True + if not exception: + format_pieces.append(f"{{{format_expr_count}}}") + format_expr_count += 1 + args.append(cst.Arg(e)) + + format_string = "".join(format_pieces) + format_string_node = cst.SimpleString(f"'{format_string}'") + return cst.Call( + func=cst.Attribute(value=format_string_node, attr=cst.Name(value="format")), + args=args, ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: + # The transformation has four major steps: + # (1) FindQueryCalls - Find and gather all the SQL query execution calls. The result is a dict of call nodes and their associated list of nodes composing the query (i.e. step (2)). + # (2) LinearizeQuery - For each call, it gather all the string literals and expressions that composes the query. The result is a list of nodes whose concatenation is the query. + # (3) ExtractParameters - Detects which expressions are part of SQL string literals in the query. The result is a list of triples (a,b,c) such that a is the node that contains the start of the string literal, b is a list of expressions that composes that literal, and c is the node containing the end of the string literal. At least one node in b must be "injectable" (see). + # (4) SQLQueryParameterization - Executes steps (1)-(3) and gather a list of injection triples. For each triple (a,b,c) it makes the associated changes to insert the query parameter token. All the expressions in b are then concatenated in an expression and passed as a sequence of parameters to the execute call. + + # Steps (1) and (2) find_queries = FindQueryCalls(self.context) tree.visit(find_queries) @@ -96,22 +137,21 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: if not self.filter_by_path_includes_or_excludes(call_pos): break + # Step (3) ep = ExtractParameters(self.context, query) tree.visit(ep) - # build tuple elements and fix injection + # Step (4) - build tuple elements and fix injection params_elements: list[cst.Element] = [] for start, middle, end in ep.injection_patterns: - # TODO support for elements from f-strings? - # reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression - expr = self._build_param_element(middle, len(middle) - 1) + prepend, append = self._fix_injection(start, middle, end) + expr = self._build_param_element(prepend, middle, append) params_elements.append( cst.Element( value=expr, comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")), ) ) - self._fix_injection(start, middle, end) # TODO research if named parameters are widely supported # it could solve for the case of existing parameters @@ -128,7 +168,12 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.file_context.codemod_changes.append( Change(line_number, SQLQueryParameterization.CHANGE_DESCRIPTION) ) + # Normalization and cleanup result = result.visit(RemoveEmptyStringConcatenation()) + result = NormalizeFStrings(self.context).transform_module(result) + # TODO The transform below may break nested f-strings: f"{f"1"}" -> f"{"1"}" + # May be a bug... + # result = UnnecessaryFormatString(self.context).transform_module(result) return result @@ -136,44 +181,90 @@ def _fix_injection( self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode ): for expr in middle: - self.changed_nodes[expr] = cst.parse_expression('""') + if isinstance( + expr, cst.FormattedStringText | cst.FormattedStringExpression + ): + self.changed_nodes[expr] = cst.RemovalSentinel.REMOVE + else: + self.changed_nodes[expr] = cst.parse_expression('""') + + # remove quote literal from start + updated_start = self.changed_nodes.get(start) or start + + t = _extract_prefix_raw_value(self, updated_start) + prefix, raw_value = t if t else ("", "") + + # gather string after the quote + if "r" in prefix: + quote_span = list(raw_quote_pattern.finditer(raw_value))[-1] + else: + quote_span = list(quote_pattern.finditer(raw_value))[-1] + + new_raw_value = raw_value[: quote_span.start()] + parameter_token + prepend_raw_value = raw_value[quote_span.end() :] + + prepend = self._remove_literal_and_gather_extra( + start, updated_start, prefix, new_raw_value, prepend_raw_value + ) + # remove quote literal from end - match end: + updated_end = self.changed_nodes.get(end) or end + + t = _extract_prefix_raw_value(self, updated_end) + prefix, raw_value = t if t else ("", "") + if "r" in prefix: + quote_span = list(raw_quote_pattern.finditer(raw_value))[0] + else: + quote_span = list(quote_pattern.finditer(raw_value))[0] + + new_raw_value = raw_value[quote_span.end() :] + append_raw_value = raw_value[: quote_span.start()] + + append = self._remove_literal_and_gather_extra( + end, updated_end, prefix, new_raw_value, append_raw_value + ) + + return (prepend, append) + + # pylint: disable-next=too-many-arguments + def _remove_literal_and_gather_extra( + self, original_node, updated_node, prefix, new_raw_value, extra_raw_value + ) -> Optional[SimpleString]: + extra = None + match updated_node: case cst.SimpleString(): - current_end = self.changed_nodes.get(end) or end - if current_end.raw_value.startswith("\\'"): - new_raw_value = current_end.raw_value[2:] - else: - new_raw_value = current_end.raw_value[1:] + # gather string after or before the quote + if extra_raw_value: + extra = cst.SimpleString( + value=updated_node.prefix + + updated_node.quote + + extra_raw_value + + updated_node.quote + ) + new_value = ( - current_end.prefix - + current_end.quote + updated_node.prefix + + updated_node.quote + new_raw_value - + current_end.quote + + updated_node.quote + ) + self.changed_nodes[original_node] = updated_node.with_changes( + value=new_value ) - self.changed_nodes[end] = current_end.with_changes(value=new_value) case cst.FormattedStringText(): - # TODO formatted string case - pass + if extra_raw_value: + extra = cst.SimpleString( + value=("r" if "r" in prefix else "") + + "'" + + extra_raw_value + + "'" + ) - # remove quote literal from start - match start: - case cst.SimpleString(): - current_start = self.changed_nodes.get(start) or start - if current_start.raw_value.endswith("\\'"): - new_raw_value = current_start.raw_value[:-2] + parameter_token - else: - new_raw_value = current_start.raw_value[:-1] + parameter_token - new_value = ( - current_start.prefix - + current_start.quote - + new_raw_value - + current_start.quote + new_value = new_raw_value + self.changed_nodes[original_node] = updated_node.with_changes( + value=new_value ) - self.changed_nodes[start] = current_start.with_changes(value=new_value) - case cst.FormattedStringText(): - # TODO formatted string case - pass + return extra class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): @@ -188,10 +279,6 @@ def __init__(self, context) -> None: self.leaves: list[cst.CSTNode] = [] def on_visit(self, node: cst.CSTNode): - # TODO function to detect if BinaryExpression results in a number or list? - # will it only matter inside fstrings? (outside it, we expect query to be string) - # check if any is a string should be necessary - # We only care about expressions, ignore everything else # Mostly as a sanity check, this may not be necessary since we start the visit with an expression node if isinstance( @@ -206,8 +293,8 @@ def on_visit(self, node: cst.CSTNode): if not matchers.matches( node, matchers.Name() - | matchers.BinaryOperation() | matchers.FormattedString() + | matchers.BinaryOperation() | matchers.FormattedStringExpression() | matchers.ConcatenatedString(), ): @@ -216,6 +303,13 @@ def on_visit(self, node: cst.CSTNode): return super().on_visit(node) return False + def visit_BinaryOperation(self, node: cst.BinaryOperation) -> Optional[bool]: + maybe_type = infer_expression_type(node) + if not maybe_type or maybe_type == BaseType.STRING: + return True + self.leaves.append(node) + return False + # recursive search def visit_Name(self, node: cst.Name) -> Optional[bool]: self.leaves.extend(self.recurse_Name(node)) @@ -277,9 +371,6 @@ class ExtractParameters(ContextAwareVisitor): ParentNodeProvider, ) - quote_pattern = re.compile(r"(? None: self.query: list[cst.CSTNode] = query self.injection_patterns: list[ @@ -289,10 +380,10 @@ def __init__(self, context: CodemodContext, query: list[cst.CSTNode]) -> None: def leave_Module(self, original_node: cst.Module): leaves = list(reversed(self.query)) - # treat it as a stack modulo_2 = 1 + # treat it as a stack while leaves: - # search for the literal start + # search for the literal start, we detect the single quote start = leaves.pop() if not self._is_literal_start(start, modulo_2): continue @@ -309,7 +400,6 @@ def leave_Module(self, original_node: cst.Module): # end may contain the start of another literal, put it back # should not be a single quote - # TODO think of a better solution here if self._is_literal_start(end, 0) and self._is_not_a_single_quote(end): modulo_2 = 0 leaves.append(end) @@ -317,80 +407,67 @@ def leave_Module(self, original_node: cst.Module): modulo_2 = 1 def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: - match expression: - case cst.SimpleString(): - prefix = expression.prefix.lower() - if "b" in prefix: - return False - if "r" in prefix: - return ( - self.raw_quote_pattern.fullmatch(expression.raw_value) is None - ) - return self.quote_pattern.fullmatch(expression.raw_value) is None - return True - - def _is_injectable(self, expression: cst.CSTNode) -> bool: - # TODO exceptions - # tuple and list literals ??? - # BinaryExpression case - match expression: - case cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString(): - return False - case cst.Call(func=cst.Name(value="str"), args=[cst.Arg(value=arg), *_]): - # TODO - # treat str(encoding = 'utf-8', object=obj) - # ensure this is the built-in - if matchers.matches(arg, literal): # type: ignore - return False - case cst.FormattedStringExpression() if matchers.matches( - expression, literal - ): - return False - case cst.IfExp(): - return self._is_injectable(expression.body) or self._is_injectable( - expression.orelse - ) - return True + t = _extract_prefix_raw_value(self, expression) + if not t: + return True + prefix, raw_value = t + if "b" in prefix: + return False + if "r" in prefix: + return raw_quote_pattern.fullmatch(raw_value) is None + return quote_pattern.fullmatch(raw_value) is None + + def _is_injectable(self, expression: cst.BaseExpression) -> bool: + return not bool(infer_expression_type(expression)) def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: - # TODO limited for now, won't include cases like "name = 'username_" + name + "_tail'" - match node: - case cst.SimpleString(): - prefix = node.prefix.lower() - if "b" in prefix: - return False - if "r" in prefix: - matches = list(self.raw_quote_pattern.finditer(node.raw_value)) - else: - matches = list(self.quote_pattern.finditer(node.raw_value)) - # avoid cases like: "where name = 'foo\\\'s name'" - # don't count \\' as these are escaped in string literals - return ( - matches[-1].end() == len(node.raw_value) - if matches and len(matches) % 2 == modulo_2 - else False - ) - case cst.FormattedStringText(): - # TODO may be in the middle i.e. f"name='home_{exp}'" - # be careful of f"name='literal'", it needs one but not two - return False - return False + t = _extract_prefix_raw_value(self, node) + if not t: + return False + prefix, raw_value = t + if "b" in prefix: + return False + if "r" in prefix: + matches = list(raw_quote_pattern.finditer(raw_value)) + else: + matches = list(quote_pattern.finditer(raw_value)) + # avoid cases like: "where name = 'foo\\\'s name'" + # don't count \\' as these are escaped in string literals + return (matches != None) and len(matches) % 2 == modulo_2 def _is_literal_end(self, node: cst.CSTNode) -> bool: - match node: - case cst.SimpleString(): - if "b" in node.prefix: - return False - if "r" in node.prefix: - matches = list(self.raw_quote_pattern.finditer(node.raw_value)) - else: - matches = list(self.quote_pattern.finditer(node.raw_value)) - return matches[0].start() == 0 if matches else False - case cst.FormattedStringText(): - # TODO may be in the middle i.e. f"'{exp}_home'" - # be careful of f"name='literal'", it needs one but not two - return False - return False + t = _extract_prefix_raw_value(self, node) + if not t: + return False + prefix, raw_value = t + if "b" in prefix: + return False + if "r" in prefix: + matches = list(raw_quote_pattern.finditer(raw_value)) + else: + matches = list(quote_pattern.finditer(raw_value)) + return bool(matches) + + +class NormalizeFStrings(ContextAwareTransformer): + """ + Finds all the f-strings whose parts are only composed of FormattedStringText and concats all of them in a single part. + """ + + def leave_FormattedString( + self, original_node: cst.FormattedString, updated_node: cst.FormattedString + ) -> cst.BaseExpression: + all_parts = list( + itertools.takewhile( + lambda x: isinstance(x, cst.FormattedStringText), original_node.parts + ) + ) + if len(all_parts) != len(updated_node.parts): + return updated_node + new_part = cst.FormattedStringText( + value="".join(map(lambda x: x.value, all_parts)) + ) + return updated_node.with_changes(parts=[new_part]) class FindQueryCalls(ContextAwareVisitor): @@ -428,3 +505,18 @@ def leave_Call(self, original_node: cst.Call) -> None: expr.value ): self.calls[original_node] = query_visitor.leaves + + +def _extract_prefix_raw_value(self, node) -> Optional[Tuple[str, str]]: + match node: + case cst.SimpleString(): + return node.prefix.lower(), node.raw_value + case cst.FormattedStringText(): + try: + parent = self.get_metadata(ParentNodeProvider, node) + parent = ensure_type(parent, FormattedString) + except Exception: + return None + return parent.start.lower(), node.value + case _: + return None diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 425b16c7..e7c72184 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -37,7 +37,7 @@ def test_multiple(self, tmpdir): phone = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name ='" + name + "' AND phone ='" + phone + "'" ) + cursor.execute("SELECT * from USERS WHERE name ='" + name + r"' AND phone ='" + phone + "'" ) """ expected = """\ import sqlite3 @@ -46,7 +46,27 @@ def test_multiple(self, tmpdir): phone = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name =?" + " AND phone =?", (name, phone, )) + cursor.execute("SELECT * from USERS WHERE name =?" + r" AND phone =?", (name, phone, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_simple_with_quotes_in_middle(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='user_" + name + r"_system'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name =?", ('user_{0}{1}'.format(name, r"_system"), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 @@ -138,10 +158,11 @@ def test_simple_concatenated_strings(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 - # negative tests below + +class TestSQLQueryParameterizationFormattedString(BaseCodemodTest): + codemod = SQLQueryParameterization def test_formatted_string_simple(self, tmpdir): - # TODO change when we add support for it input_code = """\ import sqlite3 @@ -150,22 +171,96 @@ def test_formatted_string_simple(self, tmpdir): cursor = connection.cursor() cursor.execute(f"SELECT * from USERS WHERE name='{name}'") """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 + expected = """\ + import sqlite3 - def test_no_sql_keyword(self, tmpdir): + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name=?", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_formatted_string_quote_in_middle(self, tmpdir): input_code = """\ import sqlite3 - def foo(self, cursor, name, phone): + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name='user_{name}_admin'") + """ + expected = """\ + import sqlite3 - a = "COLLECT * from USERS " - b = "WHERE name = '" + name - c = "' AND phone = '" + phone + "'" - return cursor.execute(a + b + c) + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name=?", ('user_{0}_admin'.format(name), )) """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_formatted_string_with_literal(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name='{name}_{1+2}'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, 1+2), )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_formatted_string_nested(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name={f"'{name}'"}") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name={f"?"}", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_formatted_string_concat_mixed(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name='" + f"{name}_{b'123'}" "'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, b'123'), )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 def test_multiple_expressions_injection(self, tmpdir): input_code = """\ @@ -182,11 +277,41 @@ def test_multiple_expressions_injection(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name =?", (name + "_username", )) + cursor.execute("SELECT * from USERS WHERE name =?", ('{0}_username'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + +class TestSQLQueryParameterizationNegative(BaseCodemodTest): + codemod = SQLQueryParameterization + + # negative tests below + def test_no_sql_keyword(self, tmpdir): + input_code = """\ + import sqlite3 + + def foo(self, cursor, name, phone): + + a = "COLLECT * from USERS " + b = "WHERE name = '" + name + c = "' AND phone = '" + phone + "'" + return cursor.execute(a + b + c) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_wont_mess_with_byte_strings(self, tmpdir): + input_code = """\ + import sqlite3 + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE " + b"name ='" + str(1234) + b"'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + def test_wont_parameterize_literals(self, tmpdir): input_code = """\ import sqlite3 @@ -204,7 +329,7 @@ def test_wont_parameterize_literals_if(self, tmpdir): connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 123) + "'") + cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 'Lorelei') + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/test_basetype.py b/tests/test_basetype.py new file mode 100644 index 00000000..2096a31e --- /dev/null +++ b/tests/test_basetype.py @@ -0,0 +1,36 @@ +import libcst as cst +from codemodder.codemods.utils import BaseType, infer_expression_type + + +class TestBaseType: + def test_binary_op_number(self): + e = cst.parse_expression("1 + float(2)") + assert infer_expression_type(e) == BaseType.NUMBER + + def test_binary_op_string_mixed(self): + e = cst.parse_expression('f"a"+foo()') + assert infer_expression_type(e) == BaseType.STRING + + def test_binary_op_list(self): + e = cst.parse_expression("[1,2] + [x for x in [3]] + list((4,5))") + assert infer_expression_type(e) == BaseType.LIST + + def test_binary_op_none(self): + e = cst.parse_expression("foo() + bar()") + assert infer_expression_type(e) == None + + def test_bytes(self): + e = cst.parse_expression('b"123"') + assert infer_expression_type(e) == BaseType.BYTES + + def test_if_mixed(self): + e = cst.parse_expression('1 if True else "a"') + assert infer_expression_type(e) == None + + def test_if_numbers(self): + e = cst.parse_expression("abs(1) if True else 2") + assert infer_expression_type(e) == BaseType.NUMBER + + def test_if_numbers2(self): + e = cst.parse_expression("float(1) if True else len([1,2])") + assert infer_expression_type(e) == BaseType.NUMBER