From 41be954a8e84c59b0b070557cd6b9f99cce3ca72 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 6 Nov 2023 10:05:31 -0300 Subject: [PATCH 1/8] Expanded literal detection --- src/core_codemods/sql_parameterization.py | 130 +++++++++++++------- tests/codemods/test_sql_parameterization.py | 20 +++ 2 files changed, 106 insertions(+), 44 deletions(-) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 54be3829..c6b2e360 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -31,6 +31,9 @@ literal_string = matchers.SimpleString() literal = literal_number | literal_string +quote_pattern = re.compile(r"(? cst.BaseExpression: + def _build_param_element_recurse(self, middle, index: int) -> cst.BaseExpression: # TODO maybe a parameterized string would be better here # f-strings need python 3.6 though if index == 0: @@ -81,10 +84,16 @@ def _build_param_element(self, middle, index: int) -> cst.BaseExpression: ) return cst.BinaryOperation( operator=operator, - left=self._build_param_element(middle, index - 1), + left=self._build_param_element_recurse(middle, index - 1), right=middle[index], ) + def _build_param_element(self, prepend, middle, append): + new_middle = ( + ([prepend] if prepend else []) + middle + ([append] if append else []) + ) + return self._build_param_element_recurse(new_middle, len(new_middle) - 1) + def transform_module_impl(self, tree: cst.Module) -> cst.Module: find_queries = FindQueryCalls(self.context) tree.visit(find_queries) @@ -104,14 +113,14 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: for start, middle, end in ep.injection_patterns: # TODO support for elements from f-strings? # reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression - expr = self._build_param_element(middle, len(middle) - 1) + prepend, append = self._fix_injection(start, middle, end) + expr = self._build_param_element(prepend, middle, append) params_elements.append( cst.Element( value=expr, comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")), ) ) - self._fix_injection(start, middle, end) # TODO research if named parameters are widely supported # it could solve for the case of existing parameters @@ -137,43 +146,85 @@ def _fix_injection( ): for expr in middle: self.changed_nodes[expr] = cst.parse_expression('""') - # remove quote literal from end - match end: + + prepend = append = None + # remove quote literal from start + match start: case cst.SimpleString(): - current_end = self.changed_nodes.get(end) or end - if current_end.raw_value.startswith("\\'"): - new_raw_value = current_end.raw_value[2:] + current_start = self.changed_nodes.get(start) or start + + # find + if "r" in start.prefix.lower(): + quote_span = list( + raw_quote_pattern.finditer(current_start.raw_value) + )[-1] else: - new_raw_value = current_end.raw_value[1:] + quote_span = list(quote_pattern.finditer(current_start.raw_value))[ + -1 + ] + + # gather string after the quote + # uses the same quote and prefixes to guarantee it will be correctly interpreted + new_raw_value = ( + current_start.raw_value[: quote_span.start()] + parameter_token + ) + + prepend_raw_value = current_start.raw_value[quote_span.end() :] + if prepend_raw_value: + prepend = cst.SimpleString( + value=current_start.prefix + + current_start.quote + + prepend_raw_value + + current_start.quote + ) + new_value = ( - current_end.prefix - + current_end.quote + current_start.prefix + + current_start.quote + new_raw_value - + current_end.quote + + current_start.quote ) - self.changed_nodes[end] = current_end.with_changes(value=new_value) + + self.changed_nodes[start] = current_start.with_changes(value=new_value) case cst.FormattedStringText(): # TODO formatted string case pass - # remove quote literal from start - match start: + # remove quote literal from end + match end: case cst.SimpleString(): - current_start = self.changed_nodes.get(start) or start - if current_start.raw_value.endswith("\\'"): - new_raw_value = current_start.raw_value[:-2] + parameter_token + current_end = self.changed_nodes.get(end) or end + + if "r" in start.prefix.lower(): + quote_span = list( + raw_quote_pattern.finditer(current_end.raw_value) + )[0] else: - new_raw_value = current_start.raw_value[:-1] + parameter_token + quote_span = list(quote_pattern.finditer(current_end.raw_value))[0] + + new_raw_value = current_end.raw_value[quote_span.end() :] + + # gather string up to quote to parameter + append_raw_value = current_end.raw_value[: quote_span.start()] + if append_raw_value: + append = cst.SimpleString( + value=current_end.prefix + + current_end.quote + + append_raw_value + + current_end.quote + ) + new_value = ( - current_start.prefix - + current_start.quote + current_end.prefix + + current_end.quote + new_raw_value - + current_start.quote + + current_end.quote ) - self.changed_nodes[start] = current_start.with_changes(value=new_value) + self.changed_nodes[end] = current_end.with_changes(value=new_value) case cst.FormattedStringText(): # TODO formatted string case pass + return (prepend, append) class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): @@ -277,9 +328,6 @@ class ExtractParameters(ContextAwareVisitor): ParentNodeProvider, ) - quote_pattern = re.compile(r"(? None: self.query: list[cst.CSTNode] = query self.injection_patterns: list[ @@ -289,10 +337,10 @@ def __init__(self, context: CodemodContext, query: list[cst.CSTNode]) -> None: def leave_Module(self, original_node: cst.Module): leaves = list(reversed(self.query)) - # treat it as a stack modulo_2 = 1 + # treat it as a stack while leaves: - # search for the literal start + # search for the literal start, we detect the single quote start = leaves.pop() if not self._is_literal_start(start, modulo_2): continue @@ -323,10 +371,8 @@ def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: if "b" in prefix: return False if "r" in prefix: - return ( - self.raw_quote_pattern.fullmatch(expression.raw_value) is None - ) - return self.quote_pattern.fullmatch(expression.raw_value) is None + return raw_quote_pattern.fullmatch(expression.raw_value) is None + return quote_pattern.fullmatch(expression.raw_value) is None return True def _is_injectable(self, expression: cst.CSTNode) -> bool: @@ -353,23 +399,18 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool: return True def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: - # TODO limited for now, won't include cases like "name = 'username_" + name + "_tail'" match node: case cst.SimpleString(): prefix = node.prefix.lower() if "b" in prefix: return False if "r" in prefix: - matches = list(self.raw_quote_pattern.finditer(node.raw_value)) + matches = list(raw_quote_pattern.finditer(node.raw_value)) else: - matches = list(self.quote_pattern.finditer(node.raw_value)) + matches = list(quote_pattern.finditer(node.raw_value)) # avoid cases like: "where name = 'foo\\\'s name'" # don't count \\' as these are escaped in string literals - return ( - matches[-1].end() == len(node.raw_value) - if matches and len(matches) % 2 == modulo_2 - else False - ) + return (matches != None) and len(matches) % 2 == modulo_2 case cst.FormattedStringText(): # TODO may be in the middle i.e. f"name='home_{exp}'" # be careful of f"name='literal'", it needs one but not two @@ -377,15 +418,16 @@ def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: return False def _is_literal_end(self, node: cst.CSTNode) -> bool: + print(node) match node: case cst.SimpleString(): if "b" in node.prefix: return False if "r" in node.prefix: - matches = list(self.raw_quote_pattern.finditer(node.raw_value)) + matches = list(raw_quote_pattern.finditer(node.raw_value)) else: - matches = list(self.quote_pattern.finditer(node.raw_value)) - return matches[0].start() == 0 if matches else False + matches = list(quote_pattern.finditer(node.raw_value)) + return bool(matches) case cst.FormattedStringText(): # TODO may be in the middle i.e. f"'{exp}_home'" # be careful of f"name='literal'", it needs one but not two diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 425b16c7..a73bec1b 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -51,6 +51,26 @@ def test_multiple(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + def test_simple_with_quotes_in_middle(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(r"SELECT * from USERS WHERE name ='user_" + name + r"_system'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(r"SELECT * from USERS WHERE name =?", (r"user_" + name + r"_system", )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + def test_can_deal_with_multiple_variables(self, tmpdir): input_code = """\ import sqlite3 From ee40238c0227e82c09e70a13b36a4e88486341c5 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 7 Nov 2023 09:21:16 -0300 Subject: [PATCH 2/8] Initial support for f-strings --- .../remove_empty_string_concatenation.py | 19 +- src/core_codemods/sql_parameterization.py | 209 +++++++++++------- tests/codemods/test_sql_parameterization.py | 17 +- 3 files changed, 160 insertions(+), 85 deletions(-) diff --git a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py index 58bda6b0..c70d115d 100644 --- a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py +++ b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py @@ -1,5 +1,6 @@ +from typing import Union import libcst as cst -from libcst import CSTTransformer +from libcst import CSTTransformer, RemovalSentinel, SimpleString class RemoveEmptyStringConcatenation(CSTTransformer): @@ -7,6 +8,22 @@ class RemoveEmptyStringConcatenation(CSTTransformer): Removes concatenation with empty strings (e.g. "hello " + "") or "hello" "" """ + def leave_FormattedStringExpression( + self, + original_node: cst.FormattedStringExpression, + updated_node: cst.FormattedStringExpression, + ) -> Union[ + cst.BaseFormattedStringContent, + cst.FlattenSentinel[cst.BaseFormattedStringContent], + RemovalSentinel, + ]: + expr = original_node.expression + match expr: + case SimpleString(): # type: ignore + if expr.raw_value == "": + return RemovalSentinel.REMOVE + return updated_node + def leave_BinaryOperation( self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation ) -> cst.BaseExpression: diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index c6b2e360..2d47f430 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,8 +1,15 @@ import re from typing import Optional, Tuple +import itertools import libcst as cst -from libcst import SimpleWhitespace, matchers -from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor +from libcst import FormattedString, SimpleWhitespace, ensure_type, matchers +from libcst.codemod import ( + Codemod, + CodemodContext, + ContextAwareTransformer, + ContextAwareVisitor, +) +from libcst.codemod.commands.unnecessary_format_string import UnnecessaryFormatString from libcst.metadata import ( ClassScope, GlobalScope, @@ -137,7 +144,10 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.file_context.codemod_changes.append( Change(line_number, SQLQueryParameterization.CHANGE_DESCRIPTION) ) + # Normalization and cleanup result = result.visit(RemoveEmptyStringConcatenation()) + result = NormalizeFStrings(self.context).transform_module(result) + result = UnnecessaryFormatString(self.context).transform_module(result) return result @@ -148,28 +158,25 @@ def _fix_injection( self.changed_nodes[expr] = cst.parse_expression('""') prepend = append = None + # remove quote literal from start - match start: + current_start = self.changed_nodes.get(start) or start + prepend_raw_value = None + + t = _extract_prefix_raw_value(self, current_start) + prefix, raw_value = t if t else ("", "") + + # gather string after the quote + if "r" in prefix: + quote_span = list(raw_quote_pattern.finditer(raw_value))[-1] + else: + quote_span = list(quote_pattern.finditer(raw_value))[-1] + new_raw_value = raw_value[: quote_span.start()] + parameter_token + prepend_raw_value = raw_value[quote_span.end() :] + + match current_start: case cst.SimpleString(): - current_start = self.changed_nodes.get(start) or start - - # find - if "r" in start.prefix.lower(): - quote_span = list( - raw_quote_pattern.finditer(current_start.raw_value) - )[-1] - else: - quote_span = list(quote_pattern.finditer(current_start.raw_value))[ - -1 - ] - - # gather string after the quote # uses the same quote and prefixes to guarantee it will be correctly interpreted - new_raw_value = ( - current_start.raw_value[: quote_span.start()] + parameter_token - ) - - prepend_raw_value = current_start.raw_value[quote_span.end() :] if prepend_raw_value: prepend = cst.SimpleString( value=current_start.prefix @@ -184,28 +191,36 @@ def _fix_injection( + new_raw_value + current_start.quote ) - self.changed_nodes[start] = current_start.with_changes(value=new_value) + case cst.FormattedStringText(): - # TODO formatted string case - pass + if prepend_raw_value: + prepend = cst.SimpleString( + value=("r" if "r" in prefix else "") + + "'" + + prepend_raw_value + + "'" + ) + + new_value = new_raw_value + self.changed_nodes[start] = current_start.with_changes(value=new_value) # remove quote literal from end + current_end = self.changed_nodes.get(end) or end + append_raw_value = None + + t = _extract_prefix_raw_value(self, current_end) + prefix, raw_value = t if t else ("", "") + if "r" in prefix: + quote_span = list(raw_quote_pattern.finditer(raw_value))[0] + else: + quote_span = list(quote_pattern.finditer(raw_value))[0] + + new_raw_value = raw_value[quote_span.end() :] + append_raw_value = raw_value[: quote_span.start()] match end: case cst.SimpleString(): - current_end = self.changed_nodes.get(end) or end - - if "r" in start.prefix.lower(): - quote_span = list( - raw_quote_pattern.finditer(current_end.raw_value) - )[0] - else: - quote_span = list(quote_pattern.finditer(current_end.raw_value))[0] - - new_raw_value = current_end.raw_value[quote_span.end() :] - # gather string up to quote to parameter - append_raw_value = current_end.raw_value[: quote_span.start()] if append_raw_value: append = cst.SimpleString( value=current_end.prefix @@ -222,8 +237,16 @@ def _fix_injection( ) self.changed_nodes[end] = current_end.with_changes(value=new_value) case cst.FormattedStringText(): - # TODO formatted string case - pass + if append_raw_value: + append = cst.SimpleString( + value=("r" if "r" in prefix else "") + + "'" + + append_raw_value + + "'" + ) + + new_value = new_raw_value + self.changed_nodes[end] = current_end.with_changes(value=new_value) return (prepend, append) @@ -365,15 +388,15 @@ def leave_Module(self, original_node: cst.Module): modulo_2 = 1 def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: - match expression: - case cst.SimpleString(): - prefix = expression.prefix.lower() - if "b" in prefix: - return False - if "r" in prefix: - return raw_quote_pattern.fullmatch(expression.raw_value) is None - return quote_pattern.fullmatch(expression.raw_value) is None - return True + t = _extract_prefix_raw_value(self, expression) + if not t: + return True + prefix, raw_value = t + if "b" in prefix: + return False + if "r" in prefix: + return raw_quote_pattern.fullmatch(raw_value) is None + return quote_pattern.fullmatch(raw_value) is None def _is_injectable(self, expression: cst.CSTNode) -> bool: # TODO exceptions @@ -399,40 +422,53 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool: return True def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: - match node: - case cst.SimpleString(): - prefix = node.prefix.lower() - if "b" in prefix: - return False - if "r" in prefix: - matches = list(raw_quote_pattern.finditer(node.raw_value)) - else: - matches = list(quote_pattern.finditer(node.raw_value)) - # avoid cases like: "where name = 'foo\\\'s name'" - # don't count \\' as these are escaped in string literals - return (matches != None) and len(matches) % 2 == modulo_2 - case cst.FormattedStringText(): - # TODO may be in the middle i.e. f"name='home_{exp}'" - # be careful of f"name='literal'", it needs one but not two - return False - return False + t = _extract_prefix_raw_value(self, node) + if not t: + return False + prefix, raw_value = t + if "b" in prefix: + return False + if "r" in prefix: + matches = list(raw_quote_pattern.finditer(raw_value)) + else: + matches = list(quote_pattern.finditer(raw_value)) + # avoid cases like: "where name = 'foo\\\'s name'" + # don't count \\' as these are escaped in string literals + return (matches != None) and len(matches) % 2 == modulo_2 def _is_literal_end(self, node: cst.CSTNode) -> bool: - print(node) - match node: - case cst.SimpleString(): - if "b" in node.prefix: - return False - if "r" in node.prefix: - matches = list(raw_quote_pattern.finditer(node.raw_value)) - else: - matches = list(quote_pattern.finditer(node.raw_value)) - return bool(matches) - case cst.FormattedStringText(): - # TODO may be in the middle i.e. f"'{exp}_home'" - # be careful of f"name='literal'", it needs one but not two - return False - return False + t = _extract_prefix_raw_value(self, node) + if not t: + return False + prefix, raw_value = t + if "b" in prefix: + return False + if "r" in prefix: + matches = list(raw_quote_pattern.finditer(raw_value)) + else: + matches = list(quote_pattern.finditer(raw_value)) + return bool(matches) + + +class NormalizeFStrings(ContextAwareTransformer): + """ + Finds all the f-strings whose parts are only composed of FormattedStringText and concats all of them in a single part. + """ + + def leave_FormattedString( + self, original_node: cst.FormattedString, updated_node: cst.FormattedString + ) -> cst.BaseExpression: + all_parts = list( + itertools.takewhile( + lambda x: isinstance(x, cst.FormattedStringText), original_node.parts + ) + ) + if len(all_parts) != len(updated_node.parts): + return updated_node + new_part = cst.FormattedStringText( + value="".join(map(lambda x: x.value, all_parts)) + ) + return updated_node.with_changes(parts=[new_part]) class FindQueryCalls(ContextAwareVisitor): @@ -470,3 +506,18 @@ def leave_Call(self, original_node: cst.Call) -> None: expr.value ): self.calls[original_node] = query_visitor.leaves + + +def _extract_prefix_raw_value(self, node) -> Optional[Tuple[str, str]]: + match node: + case cst.SimpleString(): + return node.prefix.lower(), node.raw_value + case cst.FormattedStringText(): + try: + parent = self.get_metadata(ParentNodeProvider, node) + parent = ensure_type(parent, FormattedString) + except: + return None + return parent.start.lower(), node.value + case _: + return None diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index a73bec1b..2336ac04 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -158,10 +158,7 @@ def test_simple_concatenated_strings(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 - # negative tests below - def test_formatted_string_simple(self, tmpdir): - # TODO change when we add support for it input_code = """\ import sqlite3 @@ -170,8 +167,18 @@ def test_formatted_string_simple(self, tmpdir): cursor = connection.cursor() cursor.execute(f"SELECT * from USERS WHERE name='{name}'") """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name=?", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + # negative tests below def test_no_sql_keyword(self, tmpdir): input_code = """\ From c0d5b8d0342e3c019b72f2a3503711676ecae634 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Wed, 8 Nov 2023 11:12:24 -0300 Subject: [PATCH 3/8] Bugfixes and refactoring --- src/core_codemods/sql_parameterization.py | 139 +++++++++++++------- tests/codemods/test_sql_parameterization.py | 48 ++++++- 2 files changed, 136 insertions(+), 51 deletions(-) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 2d47f430..553a98f6 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -32,6 +32,8 @@ from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext +from enum import Enum + parameter_token = "?" literal_number = matchers.Integer() | matchers.Float() | matchers.Imaginary() @@ -42,6 +44,54 @@ raw_quote_pattern = re.compile(r"(? Optional["BaseType"]: + """ + Tries to infer if the type of a given expression is one of the base literal types. + """ + # The current implementation could be enhanced with a few more cases + match node: + case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( + func=cst.Name("int") + ) | cst.Call(func=cst.Name("float")) | cst.Call( + func=cst.Name("abs") + ) | cst.Call( + func=cst.Name("len") + ): + return BaseType.NUMBER + case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): + return BaseType.LIST + case cst.Call(func=cst.Name("str")) | cst.FormattedString(): + return BaseType.STRING + case cst.SimpleString(): + if "b" in node.prefix.lower(): + return BaseType.BYTES + return BaseType.STRING + case cst.ConcatenatedString(): + return cls.infer_expression_type(node.left) + case cst.BinaryOperation(operator=cst.Add()): + return cls.infer_expression_type( + node.left + ) or cls.infer_expression_type(node.right) + case cst.IfExp(): + if_true = cls.infer_expression_type(node.body) + or_else = cls.infer_expression_type(node.orelse) + if if_true == or_else: + return if_true + return None + + class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod): SUMMARY = "Parameterize SQL Queries" METADATA = CodemodMetadata( @@ -80,26 +130,36 @@ def __init__( UtilsMixin.__init__(self, []) Codemod.__init__(self, context) - def _build_param_element_recurse(self, middle, index: int) -> cst.BaseExpression: - # TODO maybe a parameterized string would be better here - # f-strings need python 3.6 though - if index == 0: - return middle[0] - operator = cst.Add( - whitespace_after=cst.SimpleWhitespace(" "), - whitespace_before=cst.SimpleWhitespace(" "), - ) - return cst.BinaryOperation( - operator=operator, - left=self._build_param_element_recurse(middle, index - 1), - right=middle[index], - ) - def _build_param_element(self, prepend, middle, append): new_middle = ( ([prepend] if prepend else []) + middle + ([append] if append else []) ) - return self._build_param_element_recurse(new_middle, len(new_middle) - 1) + format_pieces: list[str] = [] + format_expr_count = 0 + args = [] + if len(new_middle) == 1: + # TODO maybe handle conversion here? + return new_middle[0] + for e in new_middle: + exception = False + if isinstance(e, cst.SimpleString | cst.FormattedStringText): + t = _extract_prefix_raw_value(self, e) + if t: + prefix, raw_value = t + if not "b" in prefix and not "r" in prefix and not "u" in prefix: + format_pieces.append(raw_value) + exception = True + if not exception: + format_pieces.append(f"{{{format_expr_count}}}") + format_expr_count += 1 + args.append(cst.Arg(e)) + + format_string = "".join(format_pieces) + format_string_node = cst.SimpleString(f"'{format_string}'") + return cst.Call( + func=cst.Attribute(value=format_string_node, attr=cst.Name(value="format")), + args=args, + ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: find_queries = FindQueryCalls(self.context) @@ -118,8 +178,6 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # build tuple elements and fix injection params_elements: list[cst.Element] = [] for start, middle, end in ep.injection_patterns: - # TODO support for elements from f-strings? - # reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression prepend, append = self._fix_injection(start, middle, end) expr = self._build_param_element(prepend, middle, append) params_elements.append( @@ -155,7 +213,12 @@ def _fix_injection( self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode ): for expr in middle: - self.changed_nodes[expr] = cst.parse_expression('""') + if isinstance( + expr, cst.FormattedStringText | cst.FormattedStringExpression + ): + self.changed_nodes[expr] = cst.RemovalSentinel.REMOVE + else: + self.changed_nodes[expr] = cst.parse_expression('""') prepend = append = None @@ -262,10 +325,6 @@ def __init__(self, context) -> None: self.leaves: list[cst.CSTNode] = [] def on_visit(self, node: cst.CSTNode): - # TODO function to detect if BinaryExpression results in a number or list? - # will it only matter inside fstrings? (outside it, we expect query to be string) - # check if any is a string should be necessary - # We only care about expressions, ignore everything else # Mostly as a sanity check, this may not be necessary since we start the visit with an expression node if isinstance( @@ -280,8 +339,8 @@ def on_visit(self, node: cst.CSTNode): if not matchers.matches( node, matchers.Name() - | matchers.BinaryOperation() | matchers.FormattedString() + | matchers.BinaryOperation() | matchers.FormattedStringExpression() | matchers.ConcatenatedString(), ): @@ -290,6 +349,13 @@ def on_visit(self, node: cst.CSTNode): return super().on_visit(node) return False + def visit_BinaryOperation(self, node: cst.BinaryOperation) -> Optional[bool]: + maybe_type = BaseType.infer_expression_type(node) + if not maybe_type or maybe_type == BaseType.STRING: + return True + self.leaves.append(node) + return False + # recursive search def visit_Name(self, node: cst.Name) -> Optional[bool]: self.leaves.extend(self.recurse_Name(node)) @@ -380,7 +446,6 @@ def leave_Module(self, original_node: cst.Module): # end may contain the start of another literal, put it back # should not be a single quote - # TODO think of a better solution here if self._is_literal_start(end, 0) and self._is_not_a_single_quote(end): modulo_2 = 0 leaves.append(end) @@ -398,28 +463,8 @@ def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: return raw_quote_pattern.fullmatch(raw_value) is None return quote_pattern.fullmatch(raw_value) is None - def _is_injectable(self, expression: cst.CSTNode) -> bool: - # TODO exceptions - # tuple and list literals ??? - # BinaryExpression case - match expression: - case cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString(): - return False - case cst.Call(func=cst.Name(value="str"), args=[cst.Arg(value=arg), *_]): - # TODO - # treat str(encoding = 'utf-8', object=obj) - # ensure this is the built-in - if matchers.matches(arg, literal): # type: ignore - return False - case cst.FormattedStringExpression() if matchers.matches( - expression, literal - ): - return False - case cst.IfExp(): - return self._is_injectable(expression.body) or self._is_injectable( - expression.orelse - ) - return True + def _is_injectable(self, expression: cst.BaseExpression) -> bool: + return not bool(BaseType.infer_expression_type(expression)) def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: t = _extract_prefix_raw_value(self, node) diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 2336ac04..79ef36eb 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -58,7 +58,7 @@ def test_simple_with_quotes_in_middle(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute(r"SELECT * from USERS WHERE name ='user_" + name + r"_system'") + cursor.execute("SELECT * from USERS WHERE name ='user_" + name + r"_system'") """ expected = """\ import sqlite3 @@ -66,7 +66,7 @@ def test_simple_with_quotes_in_middle(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute(r"SELECT * from USERS WHERE name =?", (r"user_" + name + r"_system", )) + cursor.execute("SELECT * from USERS WHERE name =?", ('user_{0}{1}'.format(name, r"_system"), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 @@ -178,6 +178,46 @@ def test_formatted_string_simple(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + def test_formatted_string_quote_in_middle(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name='user_{name}_admin'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name=?", ('user_{0}_admin'.format(name), )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_formatted_string_with_literal(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name='{name}_{1+2}'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, 1+2), )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + # negative tests below def test_no_sql_keyword(self, tmpdir): @@ -209,7 +249,7 @@ def test_multiple_expressions_injection(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name =?", (name + "_username", )) + cursor.execute("SELECT * from USERS WHERE name =?", ('{0}_username'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 @@ -231,7 +271,7 @@ def test_wont_parameterize_literals_if(self, tmpdir): connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 123) + "'") + cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 'Lorelei') + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) assert len(self.file_context.codemod_changes) == 0 From a5d15bf635d416de663d6e6effe34a63bea66db7 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Thu, 9 Nov 2023 09:35:59 -0300 Subject: [PATCH 4/8] Bugfixes, refactoring and new tests --- .../remove_empty_string_concatenation.py | 33 +++---- src/codemodder/codemods/utils.py | 49 +++++++++++ src/core_codemods/sql_parameterization.py | 64 +++----------- tests/codemods/test_sql_parameterization.py | 86 ++++++++++++++++--- 4 files changed, 149 insertions(+), 83 deletions(-) diff --git a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py index c70d115d..1721c928 100644 --- a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py +++ b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py @@ -39,19 +39,22 @@ def leave_ConcatenatedString( def handle_node( self, updated_node: cst.BinaryOperation | cst.ConcatenatedString ) -> cst.BaseExpression: - match updated_node.left: - # TODO f-string cases - case cst.SimpleString() if updated_node.left.raw_value == "": - match updated_node.right: - case cst.SimpleString() if updated_node.right.raw_value == "": - return cst.SimpleString(value='""') - case _: - return updated_node.right - match updated_node.right: - case cst.SimpleString() if updated_node.right.raw_value == "": - match updated_node.left: - case cst.SimpleString() if updated_node.left.raw_value == "": - return cst.SimpleString(value='""') - case _: - return updated_node.left + left = updated_node.left + right = updated_node.right + if self._is_empty_string_literal(left): + if self._is_empty_string_literal(right): + return cst.SimpleString(value='""') + return right + if self._is_empty_string_literal(right): + if self._is_empty_string_literal(left): + return cst.SimpleString(value='""') + return left return updated_node + + def _is_empty_string_literal(self, node): + match node: + case cst.SimpleString() if node.raw_value == "": + return True + case cst.FormattedString() if not node.parts: + return True + return False diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index a4a9c0ab..a7c4be57 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -1,3 +1,4 @@ +from enum import Enum from pathlib import Path from typing import Optional, Any @@ -5,6 +6,54 @@ import libcst as cst +class BaseType(Enum): + """ + An enumeration representing the base literal types in Python. + """ + + NUMBER = 1 + LIST = 2 + STRING = 3 + BYTES = 4 + + @classmethod + # pylint: disable-next=R0911 + def infer_expression_type(cls, node: cst.BaseExpression) -> Optional["BaseType"]: + """ + Tries to infer if the type of a given expression is one of the base literal types. + """ + # The current implementation could be enhanced with a few more cases + match node: + case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( + func=cst.Name("int") + ) | cst.Call(func=cst.Name("float")) | cst.Call( + func=cst.Name("abs") + ) | cst.Call( + func=cst.Name("len") + ): + return BaseType.NUMBER + case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): + return BaseType.LIST + case cst.Call(func=cst.Name("str")) | cst.FormattedString(): + return BaseType.STRING + case cst.SimpleString(): + if "b" in node.prefix.lower(): + return BaseType.BYTES + return BaseType.STRING + case cst.ConcatenatedString(): + return cls.infer_expression_type(node.left) + case cst.BinaryOperation(operator=cst.Add()): + return cls.infer_expression_type( + node.left + ) or cls.infer_expression_type(node.right) + case cst.IfExp(): + if_true = cls.infer_expression_type(node.body) + or_else = cls.infer_expression_type(node.orelse) + if if_true == or_else: + return if_true + return None + + class SequenceExtension: def __init__(self, sequence: list[cst.CSTNode]) -> None: self.sequence = sequence diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 553a98f6..93a20da6 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -9,7 +9,6 @@ ContextAwareTransformer, ContextAwareVisitor, ) -from libcst.codemod.commands.unnecessary_format_string import UnnecessaryFormatString from libcst.metadata import ( ClassScope, GlobalScope, @@ -28,12 +27,15 @@ from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) -from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node +from codemodder.codemods.utils import ( + Append, + BaseType, + ReplaceNodes, + get_function_name_node, +) from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext -from enum import Enum - parameter_token = "?" literal_number = matchers.Integer() | matchers.Float() | matchers.Imaginary() @@ -44,54 +46,6 @@ raw_quote_pattern = re.compile(r"(? Optional["BaseType"]: - """ - Tries to infer if the type of a given expression is one of the base literal types. - """ - # The current implementation could be enhanced with a few more cases - match node: - case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( - func=cst.Name("int") - ) | cst.Call(func=cst.Name("float")) | cst.Call( - func=cst.Name("abs") - ) | cst.Call( - func=cst.Name("len") - ): - return BaseType.NUMBER - case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): - return BaseType.LIST - case cst.Call(func=cst.Name("str")) | cst.FormattedString(): - return BaseType.STRING - case cst.SimpleString(): - if "b" in node.prefix.lower(): - return BaseType.BYTES - return BaseType.STRING - case cst.ConcatenatedString(): - return cls.infer_expression_type(node.left) - case cst.BinaryOperation(operator=cst.Add()): - return cls.infer_expression_type( - node.left - ) or cls.infer_expression_type(node.right) - case cst.IfExp(): - if_true = cls.infer_expression_type(node.body) - or_else = cls.infer_expression_type(node.orelse) - if if_true == or_else: - return if_true - return None - - class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod): SUMMARY = "Parameterize SQL Queries" METADATA = CodemodMetadata( @@ -205,7 +159,9 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # Normalization and cleanup result = result.visit(RemoveEmptyStringConcatenation()) result = NormalizeFStrings(self.context).transform_module(result) - result = UnnecessaryFormatString(self.context).transform_module(result) + # TODO The transform below may break nested f-strings: f"{f"1"}" -> f"{"1"}" + # May be a bug... + # result = UnnecessaryFormatString(self.context).transform_module(result) return result @@ -561,7 +517,7 @@ def _extract_prefix_raw_value(self, node) -> Optional[Tuple[str, str]]: try: parent = self.get_metadata(ParentNodeProvider, node) parent = ensure_type(parent, FormattedString) - except: + except Exception: return None return parent.start.lower(), node.value case _: diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 79ef36eb..e7c72184 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -37,7 +37,7 @@ def test_multiple(self, tmpdir): phone = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name ='" + name + "' AND phone ='" + phone + "'" ) + cursor.execute("SELECT * from USERS WHERE name ='" + name + r"' AND phone ='" + phone + "'" ) """ expected = """\ import sqlite3 @@ -46,7 +46,7 @@ def test_multiple(self, tmpdir): phone = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name =?" + " AND phone =?", (name, phone, )) + cursor.execute("SELECT * from USERS WHERE name =?" + r" AND phone =?", (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 @@ -158,6 +158,10 @@ def test_simple_concatenated_strings(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + +class TestSQLQueryParameterizationFormattedString(BaseCodemodTest): + codemod = SQLQueryParameterization + def test_formatted_string_simple(self, tmpdir): input_code = """\ import sqlite3 @@ -173,7 +177,7 @@ def test_formatted_string_simple(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name=?", (name, )) + cursor.execute(f"SELECT * from USERS WHERE name=?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 @@ -193,7 +197,7 @@ def test_formatted_string_quote_in_middle(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name=?", ('user_{0}_admin'.format(name), )) + cursor.execute(f"SELECT * from USERS WHERE name=?", ('user_{0}_admin'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 @@ -213,26 +217,50 @@ def test_formatted_string_with_literal(self, tmpdir): name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, 1+2), )) + cursor.execute(f"SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, 1+2), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 - # negative tests below + def test_formatted_string_nested(self, tmpdir): + input_code = """\ + import sqlite3 - def test_no_sql_keyword(self, tmpdir): + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name={f"'{name}'"}") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name={f"?"}", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_formatted_string_concat_mixed(self, tmpdir): input_code = """\ import sqlite3 - def foo(self, cursor, name, phone): + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name='" + f"{name}_{b'123'}" "'") + """ + expected = """\ + import sqlite3 - a = "COLLECT * from USERS " - b = "WHERE name = '" + name - c = "' AND phone = '" + phone + "'" - return cursor.execute(a + b + c) + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, b'123'), )) """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 def test_multiple_expressions_injection(self, tmpdir): input_code = """\ @@ -254,6 +282,36 @@ def test_multiple_expressions_injection(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + +class TestSQLQueryParameterizationNegative(BaseCodemodTest): + codemod = SQLQueryParameterization + + # negative tests below + def test_no_sql_keyword(self, tmpdir): + input_code = """\ + import sqlite3 + + def foo(self, cursor, name, phone): + + a = "COLLECT * from USERS " + b = "WHERE name = '" + name + c = "' AND phone = '" + phone + "'" + return cursor.execute(a + b + c) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_wont_mess_with_byte_strings(self, tmpdir): + input_code = """\ + import sqlite3 + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE " + b"name ='" + str(1234) + b"'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + def test_wont_parameterize_literals(self, tmpdir): input_code = """\ import sqlite3 From 4480ad13de478506ca6639a2fda40fa2a52fc5fa Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:33:42 -0300 Subject: [PATCH 5/8] Added separated tests for BaseType --- tests/test_basetype.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_basetype.py diff --git a/tests/test_basetype.py b/tests/test_basetype.py new file mode 100644 index 00000000..e72907ec --- /dev/null +++ b/tests/test_basetype.py @@ -0,0 +1,36 @@ +import libcst as cst +from codemodder.codemods.utils import BaseType + + +class TestBaseType: + def test_binary_op_number(self): + e = cst.parse_expression("1 + float(2)") + assert BaseType.infer_expression_type(e) == BaseType.NUMBER + + def test_binary_op_string_mixed(self): + e = cst.parse_expression('f"a"+foo()') + assert BaseType.infer_expression_type(e) == BaseType.STRING + + def test_binary_op_list(self): + e = cst.parse_expression("[1,2] + [x for x in [3]] + list((4,5))") + assert BaseType.infer_expression_type(e) == BaseType.LIST + + def test_binary_op_none(self): + e = cst.parse_expression("foo() + bar()") + assert BaseType.infer_expression_type(e) == None + + def test_bytes(self): + e = cst.parse_expression('b"123"') + assert BaseType.infer_expression_type(e) == BaseType.BYTES + + def test_if_mixed(self): + e = cst.parse_expression('1 if True else "a"') + assert BaseType.infer_expression_type(e) == None + + def test_if_numbers(self): + e = cst.parse_expression("abs(1) if True else 2") + assert BaseType.infer_expression_type(e) == BaseType.NUMBER + + def test_if_numbers2(self): + e = cst.parse_expression("float(1) if True else len([1,2])") + assert BaseType.infer_expression_type(e) == BaseType.NUMBER From 2687508cb126d24d998cdd157426f5ea22faf1a2 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:51:24 -0300 Subject: [PATCH 6/8] Small typing and variables cleanup --- src/codemodder/codemods/utils.py | 8 +------- src/core_codemods/sql_parameterization.py | 11 ++++------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index a7c4be57..ce7b5ac5 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -76,13 +76,7 @@ def __init__( self, replacements: dict[ cst.CSTNode, - dict[ - str, - cst.CSTNode - | cst.FlattenSentinel - | cst.RemovalSentinel - | dict[str, Any], - ], + cst.CSTNode | cst.FlattenSentinel | cst.RemovalSentinel | dict[str, Any], ], ): self.replacements = replacements diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 93a20da6..f5d85cf6 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,5 +1,5 @@ import re -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import itertools import libcst as cst from libcst import FormattedString, SimpleWhitespace, ensure_type, matchers @@ -38,10 +38,6 @@ parameter_token = "?" -literal_number = matchers.Integer() | matchers.Float() | matchers.Imaginary() -literal_string = matchers.SimpleString() -literal = literal_number | literal_string - quote_pattern = re.compile(r"(? None: self.changed_nodes: dict[ - cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel + cst.CSTNode, + cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel | dict[str, Any], ] = {} BaseCodemod.__init__(self, file_context, *codemod_args) UtilsMixin.__init__(self, []) @@ -237,7 +234,7 @@ def _fix_injection( new_raw_value = raw_value[quote_span.end() :] append_raw_value = raw_value[: quote_span.start()] - match end: + match current_end: case cst.SimpleString(): # gather string up to quote to parameter if append_raw_value: From 8b293e011a2c55d15ba7e20521afa56614a3f2ea Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 13 Nov 2023 08:50:53 -0300 Subject: [PATCH 7/8] Simplified case match --- .../transformations/remove_empty_string_concatenation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py index 1721c928..fbec19a9 100644 --- a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py +++ b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py @@ -19,9 +19,8 @@ def leave_FormattedStringExpression( ]: expr = original_node.expression match expr: - case SimpleString(): # type: ignore - if expr.raw_value == "": - return RemovalSentinel.REMOVE + case SimpleString() if expr.raw_value == "": # type: ignore + return RemovalSentinel.REMOVE return updated_node def leave_BinaryOperation( From 672c4cd8f310d981debca4ad29ce98d9e3896afc Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 14 Nov 2023 10:26:00 -0300 Subject: [PATCH 8/8] Some refactoring --- src/codemodder/codemods/utils.py | 70 +++++++------ src/core_codemods/sql_parameterization.py | 119 +++++++++++----------- tests/test_basetype.py | 18 ++-- 3 files changed, 103 insertions(+), 104 deletions(-) diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index ce7b5ac5..1b29298c 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -16,42 +16,40 @@ class BaseType(Enum): STRING = 3 BYTES = 4 - @classmethod - # pylint: disable-next=R0911 - def infer_expression_type(cls, node: cst.BaseExpression) -> Optional["BaseType"]: - """ - Tries to infer if the type of a given expression is one of the base literal types. - """ - # The current implementation could be enhanced with a few more cases - match node: - case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( - func=cst.Name("int") - ) | cst.Call(func=cst.Name("float")) | cst.Call( - func=cst.Name("abs") - ) | cst.Call( - func=cst.Name("len") - ): - return BaseType.NUMBER - case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): - return BaseType.LIST - case cst.Call(func=cst.Name("str")) | cst.FormattedString(): - return BaseType.STRING - case cst.SimpleString(): - if "b" in node.prefix.lower(): - return BaseType.BYTES - return BaseType.STRING - case cst.ConcatenatedString(): - return cls.infer_expression_type(node.left) - case cst.BinaryOperation(operator=cst.Add()): - return cls.infer_expression_type( - node.left - ) or cls.infer_expression_type(node.right) - case cst.IfExp(): - if_true = cls.infer_expression_type(node.body) - or_else = cls.infer_expression_type(node.orelse) - if if_true == or_else: - return if_true - return None + +# pylint: disable-next=R0911 +def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]: + """ + Tries to infer if the resulting type of a given expression is one of the base literal types. + """ + # The current implementation covers some common cases and is in no way complete + match node: + case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( + func=cst.Name("int") + ) | cst.Call(func=cst.Name("float")) | cst.Call( + func=cst.Name("abs") + ) | cst.Call( + func=cst.Name("len") + ): + return BaseType.NUMBER + case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): + return BaseType.LIST + case cst.Call(func=cst.Name("str")) | cst.FormattedString(): + return BaseType.STRING + case cst.SimpleString(): + if "b" in node.prefix.lower(): + return BaseType.BYTES + return BaseType.STRING + case cst.ConcatenatedString(): + return infer_expression_type(node.left) + case cst.BinaryOperation(operator=cst.Add()): + return infer_expression_type(node.left) or infer_expression_type(node.right) + case cst.IfExp(): + if_true = infer_expression_type(node.body) + or_else = infer_expression_type(node.orelse) + if if_true == or_else: + return if_true + return None class SequenceExtension: diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index f5d85cf6..3619111f 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -2,7 +2,13 @@ from typing import Any, Optional, Tuple import itertools import libcst as cst -from libcst import FormattedString, SimpleWhitespace, ensure_type, matchers +from libcst import ( + FormattedString, + SimpleString, + SimpleWhitespace, + ensure_type, + matchers, +) from libcst.codemod import ( Codemod, CodemodContext, @@ -32,6 +38,7 @@ BaseType, ReplaceNodes, get_function_name_node, + infer_expression_type, ) from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext @@ -113,6 +120,13 @@ def _build_param_element(self, prepend, middle, append): ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: + # The transformation has four major steps: + # (1) FindQueryCalls - Find and gather all the SQL query execution calls. The result is a dict of call nodes and their associated list of nodes composing the query (i.e. step (2)). + # (2) LinearizeQuery - For each call, it gather all the string literals and expressions that composes the query. The result is a list of nodes whose concatenation is the query. + # (3) ExtractParameters - Detects which expressions are part of SQL string literals in the query. The result is a list of triples (a,b,c) such that a is the node that contains the start of the string literal, b is a list of expressions that composes that literal, and c is the node containing the end of the string literal. At least one node in b must be "injectable" (see). + # (4) SQLQueryParameterization - Executes steps (1)-(3) and gather a list of injection triples. For each triple (a,b,c) it makes the associated changes to insert the query parameter token. All the expressions in b are then concatenated in an expression and passed as a sequence of parameters to the execute call. + + # Steps (1) and (2) find_queries = FindQueryCalls(self.context) tree.visit(find_queries) @@ -123,10 +137,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: if not self.filter_by_path_includes_or_excludes(call_pos): break + # Step (3) ep = ExtractParameters(self.context, query) tree.visit(ep) - # build tuple elements and fix injection + # Step (4) - build tuple elements and fix injection params_elements: list[cst.Element] = [] for start, middle, end in ep.injection_patterns: prepend, append = self._fix_injection(start, middle, end) @@ -173,13 +188,10 @@ def _fix_injection( else: self.changed_nodes[expr] = cst.parse_expression('""') - prepend = append = None - # remove quote literal from start - current_start = self.changed_nodes.get(start) or start - prepend_raw_value = None + updated_start = self.changed_nodes.get(start) or start - t = _extract_prefix_raw_value(self, current_start) + t = _extract_prefix_raw_value(self, updated_start) prefix, raw_value = t if t else ("", "") # gather string after the quote @@ -187,45 +199,18 @@ def _fix_injection( quote_span = list(raw_quote_pattern.finditer(raw_value))[-1] else: quote_span = list(quote_pattern.finditer(raw_value))[-1] + new_raw_value = raw_value[: quote_span.start()] + parameter_token prepend_raw_value = raw_value[quote_span.end() :] - match current_start: - case cst.SimpleString(): - # uses the same quote and prefixes to guarantee it will be correctly interpreted - if prepend_raw_value: - prepend = cst.SimpleString( - value=current_start.prefix - + current_start.quote - + prepend_raw_value - + current_start.quote - ) - - new_value = ( - current_start.prefix - + current_start.quote - + new_raw_value - + current_start.quote - ) - self.changed_nodes[start] = current_start.with_changes(value=new_value) - - case cst.FormattedStringText(): - if prepend_raw_value: - prepend = cst.SimpleString( - value=("r" if "r" in prefix else "") - + "'" - + prepend_raw_value - + "'" - ) - - new_value = new_raw_value - self.changed_nodes[start] = current_start.with_changes(value=new_value) + prepend = self._remove_literal_and_gather_extra( + start, updated_start, prefix, new_raw_value, prepend_raw_value + ) # remove quote literal from end - current_end = self.changed_nodes.get(end) or end - append_raw_value = None + updated_end = self.changed_nodes.get(end) or end - t = _extract_prefix_raw_value(self, current_end) + t = _extract_prefix_raw_value(self, updated_end) prefix, raw_value = t if t else ("", "") if "r" in prefix: quote_span = list(raw_quote_pattern.finditer(raw_value))[0] @@ -234,36 +219,52 @@ def _fix_injection( new_raw_value = raw_value[quote_span.end() :] append_raw_value = raw_value[: quote_span.start()] - match current_end: + + append = self._remove_literal_and_gather_extra( + end, updated_end, prefix, new_raw_value, append_raw_value + ) + + return (prepend, append) + + # pylint: disable-next=too-many-arguments + def _remove_literal_and_gather_extra( + self, original_node, updated_node, prefix, new_raw_value, extra_raw_value + ) -> Optional[SimpleString]: + extra = None + match updated_node: case cst.SimpleString(): - # gather string up to quote to parameter - if append_raw_value: - append = cst.SimpleString( - value=current_end.prefix - + current_end.quote - + append_raw_value - + current_end.quote + # gather string after or before the quote + if extra_raw_value: + extra = cst.SimpleString( + value=updated_node.prefix + + updated_node.quote + + extra_raw_value + + updated_node.quote ) new_value = ( - current_end.prefix - + current_end.quote + updated_node.prefix + + updated_node.quote + new_raw_value - + current_end.quote + + updated_node.quote + ) + self.changed_nodes[original_node] = updated_node.with_changes( + value=new_value ) - self.changed_nodes[end] = current_end.with_changes(value=new_value) case cst.FormattedStringText(): - if append_raw_value: - append = cst.SimpleString( + if extra_raw_value: + extra = cst.SimpleString( value=("r" if "r" in prefix else "") + "'" - + append_raw_value + + extra_raw_value + "'" ) new_value = new_raw_value - self.changed_nodes[end] = current_end.with_changes(value=new_value) - return (prepend, append) + self.changed_nodes[original_node] = updated_node.with_changes( + value=new_value + ) + return extra class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): @@ -303,7 +304,7 @@ def on_visit(self, node: cst.CSTNode): return False def visit_BinaryOperation(self, node: cst.BinaryOperation) -> Optional[bool]: - maybe_type = BaseType.infer_expression_type(node) + maybe_type = infer_expression_type(node) if not maybe_type or maybe_type == BaseType.STRING: return True self.leaves.append(node) @@ -417,7 +418,7 @@ def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: return quote_pattern.fullmatch(raw_value) is None def _is_injectable(self, expression: cst.BaseExpression) -> bool: - return not bool(BaseType.infer_expression_type(expression)) + return not bool(infer_expression_type(expression)) def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: t = _extract_prefix_raw_value(self, node) diff --git a/tests/test_basetype.py b/tests/test_basetype.py index e72907ec..2096a31e 100644 --- a/tests/test_basetype.py +++ b/tests/test_basetype.py @@ -1,36 +1,36 @@ import libcst as cst -from codemodder.codemods.utils import BaseType +from codemodder.codemods.utils import BaseType, infer_expression_type class TestBaseType: def test_binary_op_number(self): e = cst.parse_expression("1 + float(2)") - assert BaseType.infer_expression_type(e) == BaseType.NUMBER + assert infer_expression_type(e) == BaseType.NUMBER def test_binary_op_string_mixed(self): e = cst.parse_expression('f"a"+foo()') - assert BaseType.infer_expression_type(e) == BaseType.STRING + assert infer_expression_type(e) == BaseType.STRING def test_binary_op_list(self): e = cst.parse_expression("[1,2] + [x for x in [3]] + list((4,5))") - assert BaseType.infer_expression_type(e) == BaseType.LIST + assert infer_expression_type(e) == BaseType.LIST def test_binary_op_none(self): e = cst.parse_expression("foo() + bar()") - assert BaseType.infer_expression_type(e) == None + assert infer_expression_type(e) == None def test_bytes(self): e = cst.parse_expression('b"123"') - assert BaseType.infer_expression_type(e) == BaseType.BYTES + assert infer_expression_type(e) == BaseType.BYTES def test_if_mixed(self): e = cst.parse_expression('1 if True else "a"') - assert BaseType.infer_expression_type(e) == None + assert infer_expression_type(e) == None def test_if_numbers(self): e = cst.parse_expression("abs(1) if True else 2") - assert BaseType.infer_expression_type(e) == BaseType.NUMBER + assert infer_expression_type(e) == BaseType.NUMBER def test_if_numbers2(self): e = cst.parse_expression("float(1) if True else len([1,2])") - assert BaseType.infer_expression_type(e) == BaseType.NUMBER + assert infer_expression_type(e) == BaseType.NUMBER