diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index af6554fe1..7f82c42a9 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,5 +1,5 @@ import re -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import libcst as cst from libcst import CSTTransformer, SimpleWhitespace, matchers from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor @@ -17,7 +17,6 @@ CodemodMetadata, ReviewGuidance, ) -from codemodder.codemods.utils import ReplaceNodes from codemodder.codemods.utils_mixin import NameResolutionMixin parameter_token = "?" @@ -27,6 +26,57 @@ literal = literal_number | literal_string +class Append: + def __init__(self, sequence: list[cst.CSTNode]) -> None: + self.sequence = sequence + + +class Prepend: + def __init__(self, sequence: list[cst.CSTNode]) -> None: + self.sequence = sequence + + +class ReplaceNodes(cst.CSTTransformer): + """ + Replace nodes with their corresponding values in a given dict. + You can replace the entire node, some attributes of it via a dict(). Addionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, repectivelly a sequence. + """ + + def __init__( + self, + replacements: dict[ + cst.CSTNode, + dict[str, cst.CSTNode | dict[str, Any]], + ], + ): + self.replacements = replacements + + def on_leave(self, original_node, updated_node): + if original_node in self.replacements.keys(): + replacement = self.replacements[original_node] + match replacement: + case dict(): + changes_dict = {} + for key, value in replacement.items(): + match value: + case Prepend(): + changes_dict[key] = value.sequence + [ + *getattr(updated_node, key) + ] + + case Append(): + changes_dict[key] = [ + *getattr(updated_node, key) + ] + value.sequence + print(changes_dict) + case _: + changes_dict[key] = value + return updated_node.with_changes(**changes_dict) + case cst.CSTNode(): + return replacement + return updated_node + + class SQLQueryParameterization(BaseCodemod, Codemod): METADATA = CodemodMetadata( DESCRIPTION=("Parameterize SQL queries."), @@ -96,7 +146,8 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # TODO research if named parameters are widely supported # it could solve for the case of existing parameters tuple_arg = cst.Arg(cst.Tuple(elements=params_elements)) - self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) + # self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg]) + self.changed_nodes[call] = {"args": Append([tuple_arg])} # made changes if self.changed_nodes: @@ -112,6 +163,9 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: return result + # Replace node entirelly -> CSTNode + # update some attribute -> dict[str,any] + # apend or prepend an argument -> Append([]) / Preprend([]) def _fix_injection( self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode ): diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py new file mode 100644 index 000000000..9eddd8a16 --- /dev/null +++ b/tests/codemods/test_sql_parameterization.py @@ -0,0 +1,30 @@ +from core_codemods.sql_parameterization import SQLQueryParameterization +from tests.codemods.base_codemod_test import BaseCodemodTest +from textwrap import dedent + + +class TestSQLQueryParameterization(BaseCodemodTest): + codemod = SQLQueryParameterization + + def test_name(self): + assert self.codemod.name() == "sql-parameterization" + + def test_simple(self, tmpdir): + input_code = """\ + import sqlite3 + from a import name + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS where name ='" + name + "'") + """ + expected = """\ + import sqlite3 + from a import name + + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + cursor.execute("SELECT * from USERS where name =?", (name, )) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1