From 4e496e77538ae63ef65a62371a56aa69f39b614a Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:46:26 -0300 Subject: [PATCH 01/13] Initial implementation of SQLParameterization codemod --- src/core_codemods/__init__.py | 2 + src/core_codemods/sql_parameterization.py | 424 ++++++++++++++++++++++ tests/samples/sql_injection.py | 11 + 3 files changed, 437 insertions(+) create mode 100644 src/core_codemods/sql_parameterization.py create mode 100644 tests/samples/sql_injection.py diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index fed02aa5..b13081a0 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -1,4 +1,5 @@ from codemodder.registry import CodemodCollection +from core_codemods.sql_parameterization import SQLQueryParameterization from .django_debug_flag_on import DjangoDebugFlagOn from .django_session_cookie_secure_off import DjangoSessionCookieSecureOff @@ -52,5 +53,6 @@ UrlSandbox, UseWalrusIf, WithThreadingLock, + SQLQueryParameterization, ], ) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py new file mode 100644 index 00000000..1be8b270 --- /dev/null +++ b/src/core_codemods/sql_parameterization.py @@ -0,0 +1,424 @@ +import re +from typing import Optional, Tuple, Union +import libcst as cst +from libcst import CSTTransformer, SimpleWhitespace, matchers +from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor +from libcst.metadata import ( + ClassScope, + GlobalScope, + ParentNodeProvider, + PositionProvider, + ScopeProvider, +) + +from codemodder.codemods.base_codemod import ( + BaseCodemod, + CodemodMetadata, + ReviewGuidance, +) +from codemodder.codemods.utils import ReplaceNodes +from codemodder.codemods.utils_mixin import NameResolutionMixin + +parameter_token = "?" 
+ +literal_number = matchers.Integer() | matchers.Float() | matchers.Imaginary() +literal_string = matchers.SimpleString() +literal = literal_number | literal_string + + +class SQLQueryParameterization(BaseCodemod, Codemod): + METADATA = CodemodMetadata( + DESCRIPTION=("Parameterize SQL queries."), + NAME="sql-parameterization", + REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + REFERENCES=[], + ) + SUMMARY = "Parameterize SQL queries." + CHANGE_DESCRIPTION = "" + + METADATA_DEPENDENCIES = ( + PositionProvider, + ScopeProvider, + ParentNodeProvider, + ) + METADATA_DEPENDENCIES = ( + ScopeProvider, + ParentNodeProvider, + ) + + def __init__(self, context: CodemodContext, *codemod_args) -> None: + self.changed_nodes = {} + self.parameters = [] + Codemod.__init__(self, context) + BaseCodemod.__init__(self, *codemod_args) + + def _build_param_element(self, middle, index: int): + # TODO maybe a parameterized string would be better here + # f-strings need python 3.6 though + if index == 0: + return middle[0] + operator = cst.Add( + whitespace_after=cst.SimpleWhitespace(" "), + whitespace_before=cst.SimpleWhitespace(" "), + ) + return cst.BinaryOperation( + operator=operator, + left=self._build_param_element(middle, index - 1), + right=middle[index], + ) + + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + find_queries = FindQueryCalls(self.context) + tree.visit(find_queries) + + for call, query in find_queries.calls.items(): + ep = ExtractParameters(self.context, query) + tree.visit(ep) + params_elements: list[cst.Element] = [] + for start, middle, end in ep.injection_patterns: + if len(middle) == 1: + element_wrap = cst.Element( + value=middle[0], + comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")), + ) + params_elements.append(element_wrap) + else: + # TODO support for elements from f-strings? 
+ # reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression + expr = self._build_param_element(middle, len(middle) - 1) + params_elements.append(cst.Element(value=expr, comma=cst.Comma())) + self._fix_injection(start, middle, end) + # TODO research if named parameters are widely supported + # it could solve for the case of existing parameters + tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) + self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) + if self.changed_nodes: + result = tree.visit(ReplaceNodes(self.changed_nodes)) + return result.visit(RemoveEmptyStringConcatenation()) + return tree + + def _fix_injection( + self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode + ): + for expr in middle: + self.changed_nodes[expr] = cst.parse_expression('""') + # remove quote literal from end + match end: + # TODO test with escaped strings here... + case cst.SimpleString(): + current_end = self.changed_nodes.get(end) or end + new_raw_value = current_end.raw_value[1:] + new_value = ( + current_end.prefix + + current_end.quote + + new_raw_value + + current_end.quote + ) + self.changed_nodes[end] = current_end.with_changes(value=new_value) + case cst.FormattedStringText(): + # TODO formatted string case + pass + + # remove quote literal from start + match start: + case cst.SimpleString(): + current_start = self.changed_nodes.get(start) or start + new_raw_value = current_start.raw_value[:-1] + parameter_token + new_value = ( + current_start.prefix + + current_start.quote + + new_raw_value + + current_start.quote + ) + self.changed_nodes[start] = current_start.with_changes(value=new_value) + case cst.FormattedStringText(): + # TODO formatted string case + pass + + +class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): + METADATA_DEPENDENCIES = (ParentNodeProvider,) + + def __init__(self, context) -> None: + super().__init__(context) + self.leaves: 
list[cst.CSTNode] = [] + + def on_visit(self, node: cst.CSTNode): + # TODO function to detect if BinaryExpression results in a number or list? + # will it only matter inside fstrings? (outside it, we expect query to be string) + # check if any is a string should be necessary + + # We only care about expressions, ignore everything else + # Mostly as a sanity check, this may not be necessary since we start the visit with an expression node + if isinstance( + node, + ( + cst.BaseExpression, + cst.FormattedStringExpression, + cst.FormattedStringText, + ), + ): + # These will be the only types that will be properly visited + if not matchers.matches( + node, + matchers.Name() + | matchers.BinaryOperation() + | matchers.FormattedString() + | matchers.FormattedStringExpression() + | matchers.ConcatenatedString(), + ): + self.leaves.append(node) + else: + return super().on_visit(node) + return False + + # recursive search + def visit_Name(self, node: cst.Name) -> Optional[bool]: + self.leaves.extend(self.recurse_Name(node)) + + def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: + self.leaves.append(node) + + def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]: + assignment = self.find_single_assignment(node) + if assignment: + base_scope = assignment.scope + # TODO make this check in detect injection, to be more precise + + # Ensure that this variable is not used anywhere else + # variables used in the global scope / class scope may be referenced in other files + if ( + not isinstance(base_scope, GlobalScope) + and not isinstance(base_scope, ClassScope) + and len(assignment.references) == 1 + ): + maybe_gparent = self._find_gparent(assignment.node) + if gparent := maybe_gparent: + match gparent: + case cst.AnnAssign() | cst.Assign(): + if gparent.value: + gparent_scope = self.get_metadata( + ScopeProvider, gparent + ) + if gparent_scope and gparent_scope == base_scope: + visitor = LinearizeQuery(self.context) + gparent.value.visit(visitor) + return 
visitor.leaves + return [node] + + def recurse_Attribute(self, node: cst.Attribute) -> list[cst.CSTNode]: + # TODO attributes may have been assigned, should those be modified? + return [node] + + def _find_gparent(self, n: cst.CSTNode) -> Optional[cst.CSTNode]: + gparent = None + try: + parent = self.get_metadata(ParentNodeProvider, n) + gparent = self.get_metadata(ParentNodeProvider, parent) + except Exception: + pass + return gparent + + +class ExtractParameters(ContextAwareVisitor): + METADATA_DEPENDENCIES = ( + ScopeProvider, + ParentNodeProvider, + ) + + quote_pattern = re.compile(r"(? None: + self.query: list[cst.CSTNode] = query + self.injection_patterns: list[ + Tuple[cst.CSTNode, list[cst.CSTNode], cst.CSTNode] + ] = [] + super().__init__(context) + + def leave_Module(self, tree: cst.Module): + leaves = list(reversed(self.query)) + # treat it as a stack + modulo_2 = 1 + while leaves: + # search for the literal start + start = leaves.pop() + if not self._is_literal_start(start, modulo_2): + continue + middle = [] + # gather expressions until the literal ends + while leaves and not self._is_literal_end(leaves[-1]): + middle.append(leaves.pop()) + # could not find the literal end + if not leaves: + break + end = leaves.pop() + if any(map(self._is_injectable, middle)): + self.injection_patterns.append((start, middle, end)) + # end may contain the start of anothe literal, put it back + # should not be a single quote + # TODO think of a better solution here + if self._is_not_a_single_quote(end): + modulo_2 = 0 + leaves.append(end) + else: + modulo_2 = 1 + + # TODO use changed nodes to detect if start has already been modified before + # this can happen if start = end of another expression + + def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: + match expression: + case cst.SimpleString(): + prefix = expression.prefix.lower() + if "b" in prefix: + return False + if "r" in prefix: + return ( + self.raw_quote_pattern.fullmatch(expression.raw_value) 
is None + ) + return self.quote_pattern.fullmatch(expression.raw_value) is None + return True + + def _is_injectable(self, expression: cst.CSTNode) -> bool: + # TODO exceptions + # tuple and list literals ??? + # BinaryExpression case + match expression: + case cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString(): + return False + case cst.Call(func=cst.Name(value="str"), args=[arg, *_]): + # TODO + # treat str(encoding = 'utf-8', object=obj) + # ensure this is the built-in + if matchers.matches(arg, literal_number): + return False + case cst.FormattedStringExpression() if matchers.matches( + expression, literal + ): + return False + case cst.IfExp(): + return self._is_injectable(expression.body) or self._is_injectable( + expression.orelse + ) + return True + + def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: + match node: + case cst.SimpleString(): + prefix = node.prefix.lower() + if "b" in prefix: + return False + if "r" in prefix: + matches = list(self.raw_quote_pattern.finditer(node.raw_value)) + else: + matches = list(self.quote_pattern.finditer(node.raw_value)) + # avoid cases like: "where name = 'foo\\\'s name'" + # don't count \\' as these are escaped in string literals + return ( + matches[-1].end() == len(node.raw_value) + if matches and len(matches) % 2 == modulo_2 + else False + ) + case cst.FormattedStringText(): + # TODO may be in the middle i.e. f"name='home_{exp}'" + # be careful of f"name='literal'", it needs one but not two + return False + return False + + def _is_literal_end(self, node: cst.CSTNode) -> bool: + match node: + case cst.SimpleString(): + if "b" in node.prefix: + return False + if "r" in node.prefix: + matches = list(self.raw_quote_pattern.finditer(node.raw_value)) + else: + matches = list(self.quote_pattern.finditer(node.raw_value)) + return matches[0].start() == 0 if matches else False + case cst.FormattedStringText(): + # TODO may be in the middle i.e. 
f"'{exp}_home'" + # be careful of f"name='literal'", it needs one but not two + return False + return False + + +class FindQueryCalls(ContextAwareVisitor): + # right now it works by looking into some sql keywords in any pieces of the query + # Ideally we should infer what driver we are using + sql_keywords: list[str] = ["insert", "select", "delete", "create", "alter", "drop"] + + def __init__(self, context: CodemodContext) -> None: + self.calls: dict = {} + super().__init__(context) + + def _has_keyword(self, string: str) -> bool: + for keyword in self.sql_keywords: + if keyword in string.lower(): + return True + return False + + def leave_Call(self, original_node: cst.Call) -> None: + maybe_call_name = _get_function_name_node(original_node) + if maybe_call_name and maybe_call_name.value == "execute": + # TODO don't parameterize if there are parameters already + # may be temporary until I figure out if named parameter will work on most drivers + if len(original_node.args) > 0 and len(original_node.args) < 2: + first_arg = original_node.args[0] if original_node.args else None + if first_arg: + query_visitor = LinearizeQuery(self.context) + first_arg.value.visit(query_visitor) + for expr in query_visitor.leaves: + match expr: + case cst.SimpleString() | cst.FormattedStringText() if self._has_keyword( + expr.value + ): + self.calls[original_node] = query_visitor.leaves + + +def _get_function_name_node(call: cst.Call) -> Optional[cst.Name]: + match call.func: + case cst.Name(): + return call.func + case cst.Attribute(): + return call.func.attr + return None + + +class RemoveEmptyStringConcatenation(CSTTransformer): + """ + Removes concatenation with empty strings (e.g. "hello " + "") or "hello" "" + """ + + # TODO What about empty f-strings? 
they are a different type of node + # may not be necessary if handled correctly + def leave_BinaryOperation( + self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation + ) -> cst.BaseExpression: + return self.handle_node(updated_node) + + def leave_ConcatenatedString( + self, + original_node: cst.ConcatenatedString, + updated_node: cst.ConcatenatedString, + ) -> cst.BaseExpression: + return self.handle_node(updated_node) + + def handle_node( + self, updated_node: Union[cst.BinaryOperation, cst.ConcatenatedString] + ) -> cst.BaseExpression: + match updated_node.left: + case cst.SimpleString() if updated_node.left.raw_value == "": + match updated_node.right: + case cst.SimpleString() if updated_node.right.raw_value == "": + return cst.SimpleString(value='""') + case _: + return updated_node.right + match updated_node.right: + case cst.SimpleString() if updated_node.right.raw_value == "": + match updated_node.left: + case cst.SimpleString() if updated_node.left.raw_value == "": + return cst.SimpleString(value='""') + case _: + return updated_node.left + return updated_node diff --git a/tests/samples/sql_injection.py b/tests/samples/sql_injection.py new file mode 100644 index 00000000..2ec8e716 --- /dev/null +++ b/tests/samples/sql_injection.py @@ -0,0 +1,11 @@ +import sqlite3 + +connection = sqlite3.connect("my_db.db") + +def foo(cursor: sqlite3.Cursor, name: str, phone:str): + a = "SELECT * FROM Users" + b = "WHERE name ='" + name + c = "' AND phone = '" + phone + "'" + cursor.execute(a + b + c) + +foo(connection.cursor(), 'Jenny', '867-5309') From 9dfc54f850bb230fa21c28c334747cef302d7f5b Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:49:35 -0300 Subject: [PATCH 02/13] Linting --- src/core_codemods/sql_parameterization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 
1be8b270..d1ed2e9e 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -236,7 +236,7 @@ def __init__(self, context: CodemodContext, query: list[cst.CSTNode]) -> None: ] = [] super().__init__(context) - def leave_Module(self, tree: cst.Module): + def leave_Module(self, original_node: cst.Module): leaves = list(reversed(self.query)) # treat it as a stack modulo_2 = 1 From fd150343d92b4ea7212091eb4c59153b3991db45 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Fri, 20 Oct 2023 08:16:06 -0300 Subject: [PATCH 03/13] Added integration test for SQLQueryParameterization --- .../test_sql_parameterization.py | 39 +++++++++++++++++ src/core_codemods/sql_parameterization.py | 41 +++++++++++------- tests/samples/my_db.db | Bin 0 -> 8192 bytes 3 files changed, 64 insertions(+), 16 deletions(-) create mode 100644 integration_tests/test_sql_parameterization.py create mode 100644 tests/samples/my_db.db diff --git a/integration_tests/test_sql_parameterization.py b/integration_tests/test_sql_parameterization.py new file mode 100644 index 00000000..c059bd3d --- /dev/null +++ b/integration_tests/test_sql_parameterization.py @@ -0,0 +1,39 @@ +from core_codemods.sql_parameterization import SQLQueryParameterization +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestSQLQueryParameterization(BaseIntegrationTest): + codemod = SQLQueryParameterization + code_path = "tests/samples/sql_injection.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, + [ + (6, """ b = " WHERE name =?"\n"""), + (7, """ c = " AND phone = ?"\n""" ), + (8, """ r = cursor.execute(a + b + c, (name, phone, ))\n"""), + ], + ) + + expected_diff ="""\ +--- ++++ +@@ -4,9 +4,9 @@ + + def foo(cursor: sqlite3.Cursor, name: str, phone:str): + a = "SELECT * FROM Users" +- b = " WHERE name ='" + name +- c = "' AND phone = 
'" + phone + "'" +- r = cursor.execute(a + b + c) ++ b = " WHERE name =?" ++ c = " AND phone = ?" ++ r = cursor.execute(a + b + c, (name, phone, )) + print(r.fetchone()) + + foo(connection.cursor(), 'Jenny', '867-5309') +""" + expected_line_change = "9" + change_description = SQLQueryParameterization.CHANGE_DESCRIPTION + num_changed_files = 1 diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index d1ed2e9e..97d7cc73 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -10,6 +10,7 @@ PositionProvider, ScopeProvider, ) +from codemodder.change import Change from codemodder.codemods.base_codemod import ( BaseCodemod, @@ -44,6 +45,7 @@ class SQLQueryParameterization(BaseCodemod, Codemod): METADATA_DEPENDENCIES = ( ScopeProvider, ParentNodeProvider, + PositionProvider, ) def __init__(self, context: CodemodContext, *codemod_args) -> None: @@ -71,31 +73,38 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: find_queries = FindQueryCalls(self.context) tree.visit(find_queries) + result = tree for call, query in find_queries.calls.items(): ep = ExtractParameters(self.context, query) tree.visit(ep) + + # build tuple elements and fix injection params_elements: list[cst.Element] = [] for start, middle, end in ep.injection_patterns: - if len(middle) == 1: - element_wrap = cst.Element( - value=middle[0], - comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")), - ) - params_elements.append(element_wrap) - else: - # TODO support for elements from f-strings? - # reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression - expr = self._build_param_element(middle, len(middle) - 1) - params_elements.append(cst.Element(value=expr, comma=cst.Comma())) + # TODO support for elements from f-strings? 
+ # reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression + expr = self._build_param_element(middle, len(middle) - 1) + params_elements.append(cst.Element(value=expr, comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")))) self._fix_injection(start, middle, end) + # TODO research if named parameters are widely supported # it could solve for the case of existing parameters tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) - if self.changed_nodes: - result = tree.visit(ReplaceNodes(self.changed_nodes)) - return result.visit(RemoveEmptyStringConcatenation()) - return tree + + # made changes + if self.changed_nodes: + result = result.visit(ReplaceNodes(self.changed_nodes)) + self.changed_nodes = {} + line_number = self.get_metadata(PositionProvider, call).start.line + self.file_context.codemod_changes.append( + Change( + str(line_number), SQLQueryParameterization.CHANGE_DESCRIPTION + ).to_json() + ) + result = result.visit(RemoveEmptyStringConcatenation()) + + return result def _fix_injection( self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode @@ -258,7 +267,7 @@ def leave_Module(self, original_node: cst.Module): # end may contain the start of anothe literal, put it back # should not be a single quote # TODO think of a better solution here - if self._is_not_a_single_quote(end): + if self._is_literal_start(end, 0) and self._is_not_a_single_quote(end): modulo_2 = 0 leaves.append(end) else: diff --git a/tests/samples/my_db.db b/tests/samples/my_db.db new file mode 100644 index 0000000000000000000000000000000000000000..9a18d91146cc9f9a981c9c7312704910f3d639e0 GIT binary patch literal 8192 zcmeI#y$ZrG5C`x}6*pqxA_@hlQW%hfVb z;wsZ!EK|1|1YT2|?fK4 literal 0 HcmV?d00001 From d1af4faf6d8f84eb4ec8d4614845bc64c9137b7f Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> 
Date: Fri, 20 Oct 2023 08:42:04 -0300 Subject: [PATCH 04/13] Linting --- src/core_codemods/sql_parameterization.py | 24 +++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 97d7cc73..af6554fe 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,5 +1,5 @@ import re -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import libcst as cst from libcst import CSTTransformer, SimpleWhitespace, matchers from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor @@ -49,12 +49,13 @@ class SQLQueryParameterization(BaseCodemod, Codemod): ) def __init__(self, context: CodemodContext, *codemod_args) -> None: - self.changed_nodes = {} - self.parameters = [] + self.changed_nodes: dict[ + cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel + ] = {} Codemod.__init__(self, context) BaseCodemod.__init__(self, *codemod_args) - def _build_param_element(self, middle, index: int): + def _build_param_element(self, middle, index: int) -> cst.BaseExpression: # TODO maybe a parameterized string would be better here # f-strings need python 3.6 though if index == 0: @@ -84,14 +85,19 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # TODO support for elements from f-strings? 
# reminder that python has no implicit conversion while concatenating with +, might need to use str() for a particular expression expr = self._build_param_element(middle, len(middle) - 1) - params_elements.append(cst.Element(value=expr, comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")))) + params_elements.append( + cst.Element( + value=expr, + comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")), + ) + ) self._fix_injection(start, middle, end) # TODO research if named parameters are widely supported # it could solve for the case of existing parameters tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) - + # made changes if self.changed_nodes: result = result.visit(ReplaceNodes(self.changed_nodes)) @@ -184,9 +190,11 @@ def on_visit(self, node: cst.CSTNode): # recursive search def visit_Name(self, node: cst.Name) -> Optional[bool]: self.leaves.extend(self.recurse_Name(node)) + return False def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: self.leaves.append(node) + return False def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]: assignment = self.find_single_assignment(node) @@ -300,7 +308,7 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool: # TODO # treat str(encoding = 'utf-8', object=obj) # ensure this is the built-in - if matchers.matches(arg, literal_number): + if matchers.matches(arg, literal_number): # type: ignore return False case cst.FormattedStringExpression() if matchers.matches( expression, literal @@ -414,7 +422,7 @@ def leave_ConcatenatedString( return self.handle_node(updated_node) def handle_node( - self, updated_node: Union[cst.BinaryOperation, cst.ConcatenatedString] + self, updated_node: cst.BinaryOperation | cst.ConcatenatedString ) -> cst.BaseExpression: match updated_node.left: case cst.SimpleString() if updated_node.left.raw_value == "": From 975b37aa1e851c3a9f11687fc18e4149af3816bf Mon Sep 17 00:00:00 2001 From: 
andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Fri, 20 Oct 2023 10:03:21 -0300 Subject: [PATCH 05/13] Fixed SQLQueryParameterization integration test --- .../test_sql_parameterization.py | 44 ++++++++++--------- tests/samples/sql_injection.py | 13 +++--- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/integration_tests/test_sql_parameterization.py b/integration_tests/test_sql_parameterization.py index c059bd3d..6e8cd9da 100644 --- a/integration_tests/test_sql_parameterization.py +++ b/integration_tests/test_sql_parameterization.py @@ -11,29 +11,31 @@ class TestSQLQueryParameterization(BaseIntegrationTest): original_code, expected_new_code = original_and_expected_from_code_path( code_path, [ - (6, """ b = " WHERE name =?"\n"""), - (7, """ c = " AND phone = ?"\n""" ), - (8, """ r = cursor.execute(a + b + c, (name, phone, ))\n"""), + (7, """ b = " WHERE name =?"\n"""), + (8, """ c = " AND phone = ?"\n"""), + (9, """ r = cursor.execute(a + b + c, (name, phone, ))\n"""), ], ) - expected_diff ="""\ ---- -+++ -@@ -4,9 +4,9 @@ - - def foo(cursor: sqlite3.Cursor, name: str, phone:str): - a = "SELECT * FROM Users" -- b = " WHERE name ='" + name -- c = "' AND phone = '" + phone + "'" -- r = cursor.execute(a + b + c) -+ b = " WHERE name =?" -+ c = " AND phone = ?" 
-+ r = cursor.execute(a + b + c, (name, phone, )) - print(r.fetchone()) - - foo(connection.cursor(), 'Jenny', '867-5309') -""" - expected_line_change = "9" + # fmt: off + expected_diff =( + """--- \n""" + """+++ \n""" + """@@ -5,9 +5,9 @@\n""" + """ \n""" + """ def foo(cursor: sqlite3.Cursor, name: str, phone: str):\n""" + """ a = "SELECT * FROM Users"\n""" + """- b = " WHERE name ='" + name\n""" + """- c = "' AND phone = '" + phone + "'"\n""" + """- r = cursor.execute(a + b + c)\n""" + """+ b = " WHERE name =?"\n""" + """+ c = " AND phone = ?"\n""" + """+ r = cursor.execute(a + b + c, (name, phone, ))\n""" + """ print(r.fetchone())\n""" + """ \n""" + """ \n""") + # fmt: on + + expected_line_change = "10" change_description = SQLQueryParameterization.CHANGE_DESCRIPTION num_changed_files = 1 diff --git a/tests/samples/sql_injection.py b/tests/samples/sql_injection.py index 2ec8e716..483aa8ee 100644 --- a/tests/samples/sql_injection.py +++ b/tests/samples/sql_injection.py @@ -1,11 +1,14 @@ import sqlite3 -connection = sqlite3.connect("my_db.db") +connection = sqlite3.connect("tests/samples/my_db.db") -def foo(cursor: sqlite3.Cursor, name: str, phone:str): + +def foo(cursor: sqlite3.Cursor, name: str, phone: str): a = "SELECT * FROM Users" - b = "WHERE name ='" + name + b = " WHERE name ='" + name c = "' AND phone = '" + phone + "'" - cursor.execute(a + b + c) + r = cursor.execute(a + b + c) + print(r.fetchone()) + -foo(connection.cursor(), 'Jenny', '867-5309') +foo(connection.cursor(), "Jenny", "867-5309") From 769d437956d9861583713635035abef702ac5c69 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Fri, 20 Oct 2023 11:31:09 -0300 Subject: [PATCH 06/13] Fixed a bug where the query wasn't being updated Added a couple unit tests --- src/core_codemods/sql_parameterization.py | 60 +++++++++++++++++++-- tests/codemods/test_sql_parameterization.py | 30 +++++++++++ 2 files changed, 87 insertions(+), 3 deletions(-) create mode 
100644 tests/codemods/test_sql_parameterization.py diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index af6554fe..7f82c42a 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,5 +1,5 @@ import re -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import libcst as cst from libcst import CSTTransformer, SimpleWhitespace, matchers from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor @@ -17,7 +17,6 @@ CodemodMetadata, ReviewGuidance, ) -from codemodder.codemods.utils import ReplaceNodes from codemodder.codemods.utils_mixin import NameResolutionMixin parameter_token = "?" @@ -27,6 +26,57 @@ literal = literal_number | literal_string +class Append: + def __init__(self, sequence: list[cst.CSTNode]) -> None: + self.sequence = sequence + + +class Prepend: + def __init__(self, sequence: list[cst.CSTNode]) -> None: + self.sequence = sequence + + +class ReplaceNodes(cst.CSTTransformer): + """ + Replace nodes with their corresponding values in a given dict. + You can replace the entire node, some attributes of it via a dict(). Addionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, repectivelly a sequence. 
+ """ + + def __init__( + self, + replacements: dict[ + cst.CSTNode, + dict[str, cst.CSTNode | dict[str, Any]], + ], + ): + self.replacements = replacements + + def on_leave(self, original_node, updated_node): + if original_node in self.replacements.keys(): + replacement = self.replacements[original_node] + match replacement: + case dict(): + changes_dict = {} + for key, value in replacement.items(): + match value: + case Prepend(): + changes_dict[key] = value.sequence + [ + *getattr(updated_node, key) + ] + + case Append(): + changes_dict[key] = [ + *getattr(updated_node, key) + ] + value.sequence + print(changes_dict) + case _: + changes_dict[key] = value + return updated_node.with_changes(**changes_dict) + case cst.CSTNode(): + return replacement + return updated_node + + class SQLQueryParameterization(BaseCodemod, Codemod): METADATA = CodemodMetadata( DESCRIPTION=("Parameterize SQL queries."), @@ -96,7 +146,8 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # TODO research if named parameters are widely supported # it could solve for the case of existing parameters tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) - self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) + # self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) + self.changed_nodes[call] = {"args": Append([tuple_arg])} # made changes if self.changed_nodes: @@ -112,6 +163,9 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: return result + # Replace node entirelly -> CSTNode + # update some attribute -> dict[str,any] + # apend or prepend an argument -> Append([]) / Preprend([]) def _fix_injection( self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode ): diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py new file mode 100644 index 00000000..9eddd8a1 --- /dev/null +++ b/tests/codemods/test_sql_parameterization.py @@ -0,0 +1,30 @@ +from 
core_codemods.sql_parameterization import SQLQueryParameterization +from tests.codemods.base_codemod_test import BaseCodemodTest +from textwrap import dedent + + +class TestSQLQueryParameterization(BaseCodemodTest): + codemod = SQLQueryParameterization + + def test_name(self): + assert self.codemod.name() == "sql-parameterization" + + def test_simple(self, tmpdir): + input_code = """\ + import sqlite3 + from a import name + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS where name ='" + name + "'") + """ + expected = """\ + import sqlite3 + from a import name + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS where name =?", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 From d9bb311753534142443843145ba5296bea8fed4c Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 23 Oct 2023 09:19:53 -0300 Subject: [PATCH 07/13] Tests, docs and bugfixes --- .../docs/pixee_python_sql-parameterization.md | 23 ++ src/core_codemods/sql_parameterization.py | 39 ++-- tests/codemods/test_sql_parameterization.py | 207 +++++++++++++++++- 3 files changed, 252 insertions(+), 17 deletions(-) create mode 100644 src/core_codemods/docs/pixee_python_sql-parameterization.md diff --git a/src/core_codemods/docs/pixee_python_sql-parameterization.md b/src/core_codemods/docs/pixee_python_sql-parameterization.md new file mode 100644 index 00000000..846d0147 --- /dev/null +++ b/src/core_codemods/docs/pixee_python_sql-parameterization.md @@ -0,0 +1,23 @@ +This codemod refactors SQL statements to be parameterized, rather than built by hand. + +Without parameterization, developers must remember to escape string inputs using the rules for that column type and database. This usually results in bugs -- and sometimes vulnerabilities. 
Although it's not clear if this code is exploitable today, this change will make the code more robust in case the conditions which prevent exploitation today ever go away. + +Our changes look something like this: + +```diff +import sqlite3 + +name = input() +connection = sqlite3.connect("my_db.db") +cursor = connection.cursor() +- cursor.execute("SELECT * from USERS WHERE name ='" + name + "'") ++ cursor.execute("SELECT * from USERS WHERE name =?", (name, )) +``` + +If you have feedback on this codemod, [please let us know](mailto:feedback@pixee.ai)! + +## F.A.Q. + +### Why is this codemod marked as Merge With Cursory Review + +Python has a wealth of database drivers that all use the same interface. Different drivers may require different string tokens used for parameterization, and Python's dynamic typing makes it quite hard, and sometimes impossible, to detect which driver is being used just by looking at the code. diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 7f82c42a..7ee49bb9 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -68,7 +68,6 @@ def on_leave(self, original_node, updated_node): changes_dict[key] = [ *getattr(updated_node, key) ] + value.sequence - print(changes_dict) case _: changes_dict[key] = value return updated_node.with_changes(**changes_dict) @@ -82,7 +81,16 @@ class SQLQueryParameterization(BaseCodemod, Codemod): DESCRIPTION=("Parameterize SQL queries."), NAME="sql-parameterization", REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, - REFERENCES=[], + REFERENCES=[ + { + "url": "https://cwe.mitre.org/data/definitions/89.html", + "description": "", + }, + { + "url": "https://owasp.org/www-community/attacks/SQL_Injection", + "description": "", + }, + ], ) SUMMARY = "Parameterize SQL queries." 
CHANGE_DESCRIPTION = "" @@ -145,9 +153,10 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # TODO research if named parameters are widely supported # it could solve for the case of existing parameters - tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) - # self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) - self.changed_nodes[call] = {"args": Append([tuple_arg])} + if params_elements: + tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) + # self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) + self.changed_nodes[call] = {"args": Append([tuple_arg])} # made changes if self.changed_nodes: @@ -173,10 +182,12 @@ def _fix_injection( self.changed_nodes[expr] = cst.parse_expression('""') # remove quote literal from end match end: - # TODO test with escaped strings here... case cst.SimpleString(): current_end = self.changed_nodes.get(end) or end - new_raw_value = current_end.raw_value[1:] + if current_end.raw_value.startswith("\\'"): + new_raw_value = current_end.raw_value[2:] + else: + new_raw_value = current_end.raw_value[1:] new_value = ( current_end.prefix + current_end.quote @@ -192,7 +203,10 @@ def _fix_injection( match start: case cst.SimpleString(): current_start = self.changed_nodes.get(start) or start - new_raw_value = current_start.raw_value[:-1] + parameter_token + if current_start.raw_value.endswith("\\'"): + new_raw_value = current_start.raw_value[:-2] + parameter_token + else: + new_raw_value = current_start.raw_value[:-1] + parameter_token new_value = ( current_start.prefix + current_start.quote @@ -279,6 +293,7 @@ def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]: def recurse_Attribute(self, node: cst.Attribute) -> list[cst.CSTNode]: # TODO attributes may have been assigned, should those be modified? 
+ # research how to detect attribute assigns in libcst return [node] def _find_gparent(self, n: cst.CSTNode) -> Optional[cst.CSTNode]: @@ -328,6 +343,7 @@ def leave_Module(self, original_node: cst.Module): self.injection_patterns.append((start, middle, end)) # end may contain the start of anothe literal, put it back # should not be a single quote + # TODO think of a better solution here if self._is_literal_start(end, 0) and self._is_not_a_single_quote(end): modulo_2 = 0 @@ -335,9 +351,6 @@ def leave_Module(self, original_node: cst.Module): else: modulo_2 = 1 - # TODO use changed nodes to detect if start has already been modified before - # this can happen if start = end of another expression - def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: match expression: case cst.SimpleString(): @@ -358,11 +371,11 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool: match expression: case cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString(): return False - case cst.Call(func=cst.Name(value="str"), args=[arg, *_]): + case cst.Call(func=cst.Name(value="str"), args=[cst.Arg(value=arg), *_]): # TODO # treat str(encoding = 'utf-8', object=obj) # ensure this is the built-in - if matchers.matches(arg, literal_number): # type: ignore + if matchers.matches(arg, literal): # type: ignore return False case cst.FormattedStringExpression() if matchers.matches( expression, literal diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 9eddd8a1..69d711bd 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -12,19 +12,218 @@ def test_name(self): def test_simple(self, tmpdir): input_code = """\ import sqlite3 - from a import name + name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS where name ='" + name + "'") + cursor.execute("SELECT * from USERS WHERE name ='" + name + 
"'") """ expected = """\ import sqlite3 - from a import name + name = input() connection = sqlite3.connect("my_db.db") cursor = connection.cursor() - cursor.execute("SELECT * from USERS where name =?", (name, )) + cursor.execute("SELECT * from USERS WHERE name =?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + + def test_multiple(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + phone = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='" + name + "' AND phone ='" + phone + "'" ) + """ + expected = """\ + import sqlite3 + + name = input() + phone = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name =?" + " AND phone =?", (name, phone, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_can_deal_with_multiple_variables(self, tmpdir): + input_code = """\ + import sqlite3 + + def foo(self, cursor, name, phone): + + a = "SELECT * from USERS " + b = "WHERE name = '" + name + c = "' AND phone = '" + phone + "'" + return cursor.execute(a + b + c) + """ + + expected = """\ + import sqlite3 + + def foo(self, cursor, name, phone): + + a = "SELECT * from USERS " + b = "WHERE name = ?" + c = " AND phone = ?" 
+ return cursor.execute(a + b + c, (name, phone, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_simple_if(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else name) + "'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name =?", (('Jenny' if True else name), )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_multiple_escaped_quote(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + phone = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute('SELECT * from USERS WHERE name =\\'' + name + '\\' AND phone =\\'' + phone + '\\'' ) + """ + expected = """\ + import sqlite3 + + name = input() + phone = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute('SELECT * from USERS WHERE name =?' 
+ ' AND phone =?', (name, phone, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + # negative tests below + + def test_no_sql_keyword(self, tmpdir): + input_code = """\ + import sqlite3 + + def foo(self, cursor, name, phone): + + a = "COLLECT * from USERS " + b = "WHERE name = '" + name + c = "' AND phone = '" + phone + "'" + return cursor.execute(a + b + c) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_multiple_expressions_injection(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='" + name + "_username" + "'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name =?", (name + "_username", )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_wont_parameterize_literals(self, tmpdir): + input_code = """\ + import sqlite3 + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='" + str(1234) + "'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_wont_parameterize_literals_if(self, tmpdir): + input_code = """\ + import sqlite3 + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 123) + "'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_will_ignore_escaped_quote(self, tmpdir): + 
input_code = """\ + import sqlite3 + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS WHERE name ='Jenny\'s username'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_already_has_parameters(self, tmpdir): + input_code = """\ + import sqlite3 + + def foo(self, cursor, name, phone): + + a = "SELECT * from USERS " + b = "WHERE name = '" + name + c = "' AND phone = ?" + return cursor.execute(a + b + c, (phone,)) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_wont_change_class_attribute(self, tmpdir): + input_code = """\ + import sqlite3 + + + class A(): + + query = "SELECT * from USERS WHERE name ='" + + def foo(self, name, cursor): + return cursor.execute(query + name + "'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_wont_change_module_variable(self, tmpdir): + input_code = """\ + import sqlite3 + + query = "SELECT * from USERS WHERE name ='" + + def foo(name, cursor): + return cursor.execute(query + name + "'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 From ac5d3b9d9612cd1d0b6a8a508bb85fae082d7826 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 23 Oct 2023 10:31:18 -0300 Subject: [PATCH 08/13] Refactored RemoveEmptyStringsConcatenation Added separated tests --- .../remove_empty_string_concatenation.py | 40 +++++++++ src/core_codemods/sql_parameterization.py | 58 ++++--------- tests/codemods/test_sql_parameterization.py | 33 ++++++++ .../test_remove_empty_string_concatenation.py | 82 +++++++++++++++++++ 4 files changed, 172 insertions(+), 41 deletions(-) create mode 100644 
src/codemodder/codemods/transformations/remove_empty_string_concatenation.py create mode 100644 tests/transformations/test_remove_empty_string_concatenation.py diff --git a/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py new file mode 100644 index 00000000..58bda6b0 --- /dev/null +++ b/src/codemodder/codemods/transformations/remove_empty_string_concatenation.py @@ -0,0 +1,40 @@ +import libcst as cst +from libcst import CSTTransformer + + +class RemoveEmptyStringConcatenation(CSTTransformer): + """ + Removes concatenation with empty strings (e.g. "hello " + "") or "hello" "" + """ + + def leave_BinaryOperation( + self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation + ) -> cst.BaseExpression: + return self.handle_node(updated_node) + + def leave_ConcatenatedString( + self, + original_node: cst.ConcatenatedString, + updated_node: cst.ConcatenatedString, + ) -> cst.BaseExpression: + return self.handle_node(updated_node) + + def handle_node( + self, updated_node: cst.BinaryOperation | cst.ConcatenatedString + ) -> cst.BaseExpression: + match updated_node.left: + # TODO f-string cases + case cst.SimpleString() if updated_node.left.raw_value == "": + match updated_node.right: + case cst.SimpleString() if updated_node.right.raw_value == "": + return cst.SimpleString(value='""') + case _: + return updated_node.right + match updated_node.right: + case cst.SimpleString() if updated_node.right.raw_value == "": + match updated_node.left: + case cst.SimpleString() if updated_node.left.raw_value == "": + return cst.SimpleString(value='""') + case _: + return updated_node.left + return updated_node diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 7ee49bb9..5c25028b 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,7 +1,7 @@ import re from 
typing import Any, Optional, Tuple import libcst as cst -from libcst import CSTTransformer, SimpleWhitespace, matchers +from libcst import SimpleWhitespace, matchers from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor from libcst.metadata import ( ClassScope, @@ -17,6 +17,9 @@ CodemodMetadata, ReviewGuidance, ) +from codemodder.codemods.transformations.remove_empty_string_concatenation import ( + RemoveEmptyStringConcatenation, +) from codemodder.codemods.utils_mixin import NameResolutionMixin parameter_token = "?" @@ -39,7 +42,7 @@ def __init__(self, sequence: list[cst.CSTNode]) -> None: class ReplaceNodes(cst.CSTTransformer): """ Replace nodes with their corresponding values in a given dict. - You can replace the entire node, some attributes of it via a dict(). Addionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, repectivelly a sequence. + You can replace the entire node or some attributes of it via a dict(). Addionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectivelly. """ def __init__( @@ -220,6 +223,10 @@ def _fix_injection( class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): + """ + Gather all the expressions that are concatenated to build the query. + """ + METADATA_DEPENDENCIES = (ParentNodeProvider,) def __init__(self, context) -> None: @@ -307,6 +314,10 @@ def _find_gparent(self, n: cst.CSTNode) -> Optional[cst.CSTNode]: class ExtractParameters(ContextAwareVisitor): + """ + Detects injections and gather the expressions that are injectable. + """ + METADATA_DEPENDENCIES = ( ScopeProvider, ParentNodeProvider, @@ -428,6 +439,10 @@ def _is_literal_end(self, node: cst.CSTNode) -> bool: class FindQueryCalls(ContextAwareVisitor): + """ + Find all the execute calls with a sql query as an argument. 
+ """ + # right now it works by looking into some sql keywords in any pieces of the query # Ideally we should infer what driver we are using sql_keywords: list[str] = ["insert", "select", "delete", "create", "alter", "drop"] @@ -467,42 +482,3 @@ def _get_function_name_node(call: cst.Call) -> Optional[cst.Name]: case cst.Attribute(): return call.func.attr return None - - -class RemoveEmptyStringConcatenation(CSTTransformer): - """ - Removes concatenation with empty strings (e.g. "hello " + "") or "hello" "" - """ - - # TODO What about empty f-strings? they are a different type of node - # may not be necessary if handled correctly - def leave_BinaryOperation( - self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation - ) -> cst.BaseExpression: - return self.handle_node(updated_node) - - def leave_ConcatenatedString( - self, - original_node: cst.ConcatenatedString, - updated_node: cst.ConcatenatedString, - ) -> cst.BaseExpression: - return self.handle_node(updated_node) - - def handle_node( - self, updated_node: cst.BinaryOperation | cst.ConcatenatedString - ) -> cst.BaseExpression: - match updated_node.left: - case cst.SimpleString() if updated_node.left.raw_value == "": - match updated_node.right: - case cst.SimpleString() if updated_node.right.raw_value == "": - return cst.SimpleString(value='""') - case _: - return updated_node.right - match updated_node.right: - case cst.SimpleString() if updated_node.right.raw_value == "": - match updated_node.left: - case cst.SimpleString() if updated_node.left.raw_value == "": - return cst.SimpleString(value='""') - case _: - return updated_node.left - return updated_node diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 69d711bd..a972c676 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -118,8 +118,41 @@ def test_multiple_escaped_quote(self, tmpdir): self.run_and_assert(tmpdir, 
dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + def test_simple_concatenated_strings(self, tmpdir): + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS" "WHERE name ='" + name + "'") + """ + expected = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS" "WHERE name =?", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + # negative tests below + def test_formatted_string_simple(self, tmpdir): + # TODO change when we add support for it + input_code = """\ + import sqlite3 + + name = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute(f"SELECT * from USERS WHERE name='{name}'") + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + def test_no_sql_keyword(self, tmpdir): input_code = """\ import sqlite3 diff --git a/tests/transformations/test_remove_empty_string_concatenation.py b/tests/transformations/test_remove_empty_string_concatenation.py new file mode 100644 index 00000000..cfc6d631 --- /dev/null +++ b/tests/transformations/test_remove_empty_string_concatenation.py @@ -0,0 +1,82 @@ +import libcst as cst +from libcst.codemod import Codemod, CodemodTest + +from codemodder.codemods.transformations.remove_empty_string_concatenation import ( + RemoveEmptyStringConcatenation, +) + + +class RemoveEmptyStringConcatenationCodemod(Codemod): + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + return tree.visit(RemoveEmptyStringConcatenation()) + + +class TestRemoveEmptyStringConcatenation(CodemodTest): + TRANSFORM = RemoveEmptyStringConcatenationCodemod + + def test_left(self): + before = """ + 
"" + "world" + """ + + after = """ + "world" + """ + + self.assertCodemod(before, after) + + def test_right(self): + before = """ + "hello" + "" + """ + + after = """ + "hello" + """ + + self.assertCodemod(before, after) + + def test_both(self): + before = """ + "" + "" + """ + + after = """ + "" + """ + + self.assertCodemod(before, after) + + def test_concatenated_string_right(self): + before = """ + "hello" "" + """ + + after = """ + "hello" + """ + + self.assertCodemod(before, after) + + def test_concatenated_string_left(self): + before = """ + "world" + """ + + after = """ + "world" + """ + + self.assertCodemod(before, after) + + def test_multiple_mixed(self): + before = ( + """ + "" + '' """ + """ + r'''''' + """ + ) + + after = '""' + + self.assertCodemod(before, after) From 2347422556646cdf31470f8924c5cfd4dc8a4261 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 24 Oct 2023 08:30:49 -0300 Subject: [PATCH 09/13] Refactor, tests and documentation --- src/codemodder/codemods/utils.py | 11 ++++- .../docs/pixee_python_sql-parameterization.md | 8 ---- src/core_codemods/sql_parameterization.py | 42 +++++++------------ src/scripts/generate_docs.py | 4 ++ tests/codemods/test_sql_parameterization.py | 2 + .../test_remove_empty_string_concatenation.py | 26 ++++++++++++ 6 files changed, 56 insertions(+), 37 deletions(-) diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index f547cc09..2e9d5e1d 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import Optional, Union from libcst import matchers import libcst as cst @@ -44,3 +44,12 @@ def get_call_name(call: cst.Call) -> str: return call.func.attr.value # It's a simple Name return call.func.value + + +def get_function_name_node(call: cst.Call) -> Optional[cst.Name]: + match call.func: + case cst.Name(): + return call.func + case 
cst.Attribute(): + return call.func.attr + return None diff --git a/src/core_codemods/docs/pixee_python_sql-parameterization.md b/src/core_codemods/docs/pixee_python_sql-parameterization.md index 846d0147..b3c8603f 100644 --- a/src/core_codemods/docs/pixee_python_sql-parameterization.md +++ b/src/core_codemods/docs/pixee_python_sql-parameterization.md @@ -13,11 +13,3 @@ cursor = connection.cursor() - cursor.execute("SELECT * from USERS WHERE name ='" + name + "'") + cursor.execute("SELECT * from USERS WHERE name =?", (name, )) ``` - -If you have feedback on this codemod, [please let us know](mailto:feedback@pixee.ai)! - -## F.A.Q. - -### Why is this codemod marked as Merge With Cursory Review - -Python has a wealth of database drivers that all use the same interface. Different drivers may require different string tokens used for parameterization, and Python's dynamic typing makes it quite hard, and sometimes impossible, to detect which driver is being used just by looking at the code. diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 5c25028b..c4bd7f40 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -20,6 +20,7 @@ from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) +from codemodder.codemods.utils import get_function_name_node from codemodder.codemods.utils_mixin import NameResolutionMixin parameter_token = "?" @@ -29,20 +30,22 @@ literal = literal_number | literal_string -class Append: +class SequenceExtension: def __init__(self, sequence: list[cst.CSTNode]) -> None: self.sequence = sequence -class Prepend: - def __init__(self, sequence: list[cst.CSTNode]) -> None: - self.sequence = sequence +class Append(SequenceExtension): + pass + + +class Prepend(SequenceExtension): + pass class ReplaceNodes(cst.CSTTransformer): """ - Replace nodes with their corresponding values in a given dict. 
- You can replace the entire node or some attributes of it via a dict(). Addionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectivelly. + Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively. """ def __init__( @@ -80,8 +83,9 @@ def on_leave(self, original_node, updated_node): class SQLQueryParameterization(BaseCodemod, Codemod): + SUMMARY = "Parameterize SQL queries." METADATA = CodemodMetadata( - DESCRIPTION=("Parameterize SQL queries."), + DESCRIPTION=SUMMARY, NAME="sql-parameterization", REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, REFERENCES=[ @@ -95,19 +99,13 @@ class SQLQueryParameterization(BaseCodemod, Codemod): }, ], ) - SUMMARY = "Parameterize SQL queries." - CHANGE_DESCRIPTION = "" + CHANGE_DESCRIPTION = "Parameterized SQL query execution." 
METADATA_DEPENDENCIES = ( PositionProvider, ScopeProvider, ParentNodeProvider, ) - METADATA_DEPENDENCIES = ( - ScopeProvider, - ParentNodeProvider, - PositionProvider, - ) def __init__(self, context: CodemodContext, *codemod_args) -> None: self.changed_nodes: dict[ @@ -175,9 +173,6 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: return result - # Replace node entirelly -> CSTNode - # update some attribute -> dict[str,any] - # apend or prepend an argument -> Append([]) / Preprend([]) def _fix_injection( self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode ): @@ -352,7 +347,7 @@ def leave_Module(self, original_node: cst.Module): end = leaves.pop() if any(map(self._is_injectable, middle)): self.injection_patterns.append((start, middle, end)) - # end may contain the start of anothe literal, put it back + # end may contain the start of another literal, put it back # should not be a single quote # TODO think of a better solution here @@ -458,7 +453,7 @@ def _has_keyword(self, string: str) -> bool: return False def leave_Call(self, original_node: cst.Call) -> None: - maybe_call_name = _get_function_name_node(original_node) + maybe_call_name = get_function_name_node(original_node) if maybe_call_name and maybe_call_name.value == "execute": # TODO don't parameterize if there are parameters already # may be temporary until I figure out if named parameter will work on most drivers @@ -473,12 +468,3 @@ def leave_Call(self, original_node: cst.Call) -> None: expr.value ): self.calls[original_node] = query_visitor.leaves - - -def _get_function_name_node(call: cst.Call) -> Optional[cst.Name]: - match call.func: - case cst.Name(): - return call.func - case cst.Attribute(): - return call.func.attr - return None diff --git a/src/scripts/generate_docs.py b/src/scripts/generate_docs.py index 7861c946..3d4d4777 100644 --- a/src/scripts/generate_docs.py +++ b/src/scripts/generate_docs.py @@ -118,6 +118,10 @@ class DocMetadata: importance="Low", 
guidance_explained="We believe this replacement is safe and should not result in any issues.", ), + "sql-parameterization": DocMetadata( + importance="High", + guidance_explained="Python has a wealth of database drivers that all use the same `dbapi2` interface detailed in [PEP249](https://peps.python.org/pep-0249/). Different drivers may require different string tokens used for parameterization, and Python's dynamic typing makes it quite hard, and sometimes impossible, to detect which driver is being used just by looking at the code.", + ), } diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index a972c676..425b16c7 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -235,6 +235,7 @@ def foo(self, cursor, name, phone): assert len(self.file_context.codemod_changes) == 0 def test_wont_change_class_attribute(self, tmpdir): + # query may be accessed from outside the module by importing A input_code = """\ import sqlite3 @@ -250,6 +251,7 @@ def foo(self, name, cursor): assert len(self.file_context.codemod_changes) == 0 def test_wont_change_module_variable(self, tmpdir): + # query may be accessed from outside the module by importing it input_code = """\ import sqlite3 diff --git a/tests/transformations/test_remove_empty_string_concatenation.py b/tests/transformations/test_remove_empty_string_concatenation.py index cfc6d631..22de7a5b 100644 --- a/tests/transformations/test_remove_empty_string_concatenation.py +++ b/tests/transformations/test_remove_empty_string_concatenation.py @@ -47,6 +47,19 @@ def test_both(self): self.assertCodemod(before, after) + def test_multiple(self): + before = """ + "" + "whatever" + "" + "hello" + "" + "world" + """ + + after = """ + "whatever" + "hello" + "world" + """ + + self.assertCodemod(before, after) + def test_concatenated_string_right(self): before = """ "hello" "" @@ -69,6 +82,19 @@ def test_concatenated_string_left(self): 
self.assertCodemod(before, after) + def test_concatenated_string_multiple(self): + before = """ + "" "whatever" "" + "hello" "" "world" + """ + + after = """ + "whatever" + "hello" "world" + """ + + self.assertCodemod(before, after) + def test_multiple_mixed(self): before = ( """ From f4ded4e32c77301240caac5ff0ba68b2a9d549a5 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 24 Oct 2023 08:47:30 -0300 Subject: [PATCH 10/13] Refactored ReplaceNodes into utils --- src/codemodder/codemods/utils.py | 46 +++++++++++++++++-- src/core_codemods/sql_parameterization.py | 56 +---------------------- 2 files changed, 44 insertions(+), 58 deletions(-) diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index 2e9d5e1d..a4a9c0ab 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -1,27 +1,65 @@ from pathlib import Path -from typing import Optional, Union +from typing import Optional, Any from libcst import matchers import libcst as cst +class SequenceExtension: + def __init__(self, sequence: list[cst.CSTNode]) -> None: + self.sequence = sequence + + +class Append(SequenceExtension): + pass + + +class Prepend(SequenceExtension): + pass + + class ReplaceNodes(cst.CSTTransformer): """ - Replace nodes with their corresponding values in a given dict. + Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node, RemovalSentinel, or FlattenSentinel to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively. 
""" def __init__( self, replacements: dict[ cst.CSTNode, - Union[cst.CSTNode, cst.FlattenSentinel[cst.CSTNode], cst.RemovalSentinel], + dict[ + str, + cst.CSTNode + | cst.FlattenSentinel + | cst.RemovalSentinel + | dict[str, Any], + ], ], ): self.replacements = replacements def on_leave(self, original_node, updated_node): if original_node in self.replacements.keys(): - return self.replacements[original_node] + replacement = self.replacements[original_node] + match replacement: + case dict(): + changes_dict = {} + for key, value in replacement.items(): + match value: + case Prepend(): + changes_dict[key] = value.sequence + [ + *getattr(updated_node, key) + ] + + case Append(): + changes_dict[key] = [ + *getattr(updated_node, key) + ] + value.sequence + case _: + changes_dict[key] = value + return updated_node.with_changes(**changes_dict) + case cst.CSTNode() | cst.RemovalSentinel() | cst.FlattenSentinel(): + return replacement return updated_node diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index c4bd7f40..061fcccb 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,5 +1,5 @@ import re -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import libcst as cst from libcst import SimpleWhitespace, matchers from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor @@ -20,7 +20,7 @@ from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) -from codemodder.codemods.utils import get_function_name_node +from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node from codemodder.codemods.utils_mixin import NameResolutionMixin parameter_token = "?" 
@@ -30,58 +30,6 @@ literal = literal_number | literal_string -class SequenceExtension: - def __init__(self, sequence: list[cst.CSTNode]) -> None: - self.sequence = sequence - - -class Append(SequenceExtension): - pass - - -class Prepend(SequenceExtension): - pass - - -class ReplaceNodes(cst.CSTTransformer): - """ - Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively. - """ - - def __init__( - self, - replacements: dict[ - cst.CSTNode, - dict[str, cst.CSTNode | dict[str, Any]], - ], - ): - self.replacements = replacements - - def on_leave(self, original_node, updated_node): - if original_node in self.replacements.keys(): - replacement = self.replacements[original_node] - match replacement: - case dict(): - changes_dict = {} - for key, value in replacement.items(): - match value: - case Prepend(): - changes_dict[key] = value.sequence + [ - *getattr(updated_node, key) - ] - - case Append(): - changes_dict[key] = [ - *getattr(updated_node, key) - ] + value.sequence - case _: - changes_dict[key] = value - return updated_node.with_changes(**changes_dict) - case cst.CSTNode(): - return replacement - return updated_node - - class SQLQueryParameterization(BaseCodemod, Codemod): SUMMARY = "Parameterize SQL queries." 
METADATA = CodemodMetadata( From 22a86492270596c94d4ffc98bcb2e5aa294e2428 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:15:20 -0300 Subject: [PATCH 11/13] Added line includes excludes to SQLQueryParameterization --- src/core_codemods/sql_parameterization.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 061fcccb..ec66f6b4 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -17,11 +17,14 @@ CodemodMetadata, ReviewGuidance, ) +from codemodder.codemods.base_visitor import UtilsMixin from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.context import CodemodExecutionContext +from codemodder.file_context import FileContext parameter_token = "?" @@ -30,7 +33,7 @@ literal = literal_number | literal_string -class SQLQueryParameterization(BaseCodemod, Codemod): +class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod): SUMMARY = "Parameterize SQL queries." 
METADATA = CodemodMetadata( DESCRIPTION=SUMMARY, @@ -55,12 +58,19 @@ class SQLQueryParameterization(BaseCodemod, Codemod): ParentNodeProvider, ) - def __init__(self, context: CodemodContext, *codemod_args) -> None: + def __init__( + self, + context: CodemodContext, + execution_context: CodemodExecutionContext, + file_context: FileContext, + *codemod_args, + ) -> None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel ] = {} + BaseCodemod.__init__(self, execution_context, file_context, *codemod_args) + UtilsMixin.__init__(self, context, {}) Codemod.__init__(self, context) - BaseCodemod.__init__(self, *codemod_args) def _build_param_element(self, middle, index: int) -> cst.BaseExpression: # TODO maybe a parameterized string would be better here @@ -83,6 +93,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: result = tree for call, query in find_queries.calls.items(): + # filter by line includes/excludes + call_pos = self.node_position(call) + if not self.filter_by_path_includes_or_excludes(call_pos): + break + ep = ExtractParameters(self.context, query) tree.visit(ep) @@ -342,6 +357,7 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool: return True def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: + # TODO limited for now, won't include cases like "name = 'username_" + name + "_tail'" match node: case cst.SimpleString(): prefix = node.prefix.lower() From ce44d621215ab718d643834367d312a0a4e74e24 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:14:56 -0300 Subject: [PATCH 12/13] Removed sample database file, adjusted tests --- integration_tests/test_sql_parameterization.py | 10 +++++----- tests/samples/my_db.db | Bin 8192 -> 0 bytes tests/samples/sql_injection.py | 4 +++- 3 files changed, 8 insertions(+), 6 deletions(-) delete mode 100644 tests/samples/my_db.db diff --git 
a/integration_tests/test_sql_parameterization.py b/integration_tests/test_sql_parameterization.py index 6e8cd9da..f27063fa 100644 --- a/integration_tests/test_sql_parameterization.py +++ b/integration_tests/test_sql_parameterization.py @@ -11,9 +11,9 @@ class TestSQLQueryParameterization(BaseIntegrationTest): original_code, expected_new_code = original_and_expected_from_code_path( code_path, [ - (7, """ b = " WHERE name =?"\n"""), - (8, """ c = " AND phone = ?"\n"""), - (9, """ r = cursor.execute(a + b + c, (name, phone, ))\n"""), + (9, """ b = " WHERE name =?"\n"""), + (10, """ c = " AND phone = ?"\n"""), + (11, """ r = cursor.execute(a + b + c, (name, phone, ))\n"""), ], ) @@ -21,7 +21,7 @@ class TestSQLQueryParameterization(BaseIntegrationTest): expected_diff =( """--- \n""" """+++ \n""" - """@@ -5,9 +5,9 @@\n""" + """@@ -7,9 +7,9 @@\n""" """ \n""" """ def foo(cursor: sqlite3.Cursor, name: str, phone: str):\n""" """ a = "SELECT * FROM Users"\n""" @@ -36,6 +36,6 @@ class TestSQLQueryParameterization(BaseIntegrationTest): """ \n""") # fmt: on - expected_line_change = "10" + expected_line_change = "12" change_description = SQLQueryParameterization.CHANGE_DESCRIPTION num_changed_files = 1 diff --git a/tests/samples/my_db.db b/tests/samples/my_db.db deleted file mode 100644 index 9a18d91146cc9f9a981c9c7312704910f3d639e0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8192 zcmeI#y$ZrG5C`x}6*pqxA_@hlQW%hfVb z;wsZ!EK|1|1YT2|?fK4 diff --git a/tests/samples/sql_injection.py b/tests/samples/sql_injection.py index 483aa8ee..3ef768ad 100644 --- a/tests/samples/sql_injection.py +++ b/tests/samples/sql_injection.py @@ -1,6 +1,8 @@ import sqlite3 -connection = sqlite3.connect("tests/samples/my_db.db") +connection = sqlite3.connect(":memory:") +connection.cursor().execute("CREATE TABLE Users (name, phone)") +connection.cursor().execute("INSERT INTO Users VALUES ('Jenny','867-5309')") def foo(cursor: sqlite3.Cursor, name: str, phone: str): 
From efdbd1424250ead8ee1722bfd07cc10c5dda6307 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Wed, 25 Oct 2023 07:15:45 -0300 Subject: [PATCH 13/13] Fixed issue with Change reporting --- src/core_codemods/sql_parameterization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index ec66f6b4..c6ae7c62 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -128,9 +128,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.changed_nodes = {} line_number = self.get_metadata(PositionProvider, call).start.line self.file_context.codemod_changes.append( - Change( - str(line_number), SQLQueryParameterization.CHANGE_DESCRIPTION - ).to_json() + Change(line_number, SQLQueryParameterization.CHANGE_DESCRIPTION) ) result = result.visit(RemoveEmptyStringConcatenation())