diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index 2e9d5e1d..a4a9c0ab 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -1,27 +1,65 @@ from pathlib import Path -from typing import Optional, Union +from typing import Optional, Any from libcst import matchers import libcst as cst +class SequenceExtension: + def __init__(self, sequence: list[cst.CSTNode]) -> None: + self.sequence = sequence + + +class Append(SequenceExtension): + pass + + +class Prepend(SequenceExtension): + pass + + class ReplaceNodes(cst.CSTTransformer): """ - Replace nodes with their corresponding values in a given dict. + Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node, RemovalSentinel, or FlattenSentinel to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively. """ def __init__( self, replacements: dict[ cst.CSTNode, - Union[cst.CSTNode, cst.FlattenSentinel[cst.CSTNode], cst.RemovalSentinel], + dict[ + str, + cst.CSTNode + | cst.FlattenSentinel + | cst.RemovalSentinel + | dict[str, Any], + ], ], ): self.replacements = replacements def on_leave(self, original_node, updated_node): if original_node in self.replacements.keys(): - return self.replacements[original_node] + replacement = self.replacements[original_node] + match replacement: + case dict(): + changes_dict = {} + for key, value in replacement.items(): + match value: + case Prepend(): + changes_dict[key] = value.sequence + [ + *getattr(updated_node, key) + ] + + case Append(): + changes_dict[key] = [ + *getattr(updated_node, key) + ] + value.sequence + case _: + changes_dict[key] = value + return updated_node.with_changes(**changes_dict) + case cst.CSTNode() | cst.RemovalSentinel() | cst.FlattenSentinel(): + return replacement return updated_node diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index c4bd7f40..061fcccb 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,5 +1,5 @@ import re -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import libcst as cst from libcst import SimpleWhitespace, matchers from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor @@ -20,7 +20,7 @@ from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) -from codemodder.codemods.utils import get_function_name_node +from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node from codemodder.codemods.utils_mixin import NameResolutionMixin parameter_token = "?" @@ -30,58 +30,6 @@ literal = literal_number | literal_string -class SequenceExtension: - def __init__(self, sequence: list[cst.CSTNode]) -> None: - self.sequence = sequence - - -class Append(SequenceExtension): - pass - - -class Prepend(SequenceExtension): - pass - - -class ReplaceNodes(cst.CSTTransformer): - """ - Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively. - """ - - def __init__( - self, - replacements: dict[ - cst.CSTNode, - dict[str, cst.CSTNode | dict[str, Any]], - ], - ): - self.replacements = replacements - - def on_leave(self, original_node, updated_node): - if original_node in self.replacements.keys(): - replacement = self.replacements[original_node] - match replacement: - case dict(): - changes_dict = {} - for key, value in replacement.items(): - match value: - case Prepend(): - changes_dict[key] = value.sequence + [ - *getattr(updated_node, key) - ] - - case Append(): - changes_dict[key] = [ - *getattr(updated_node, key) - ] + value.sequence - case _: - changes_dict[key] = value - return updated_node.with_changes(**changes_dict) - case cst.CSTNode(): - return replacement - return updated_node - - class SQLQueryParameterization(BaseCodemod, Codemod): SUMMARY = "Parameterize SQL queries." METADATA = CodemodMetadata(