Skip to content

Commit

Permalink
Refactored ReplaceNodes into utils
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Oct 24, 2023
1 parent 3b87239 commit b24e489
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 58 deletions.
46 changes: 42 additions & 4 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,65 @@
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Any

from libcst import matchers
import libcst as cst


class SequenceExtension:
def __init__(self, sequence: list[cst.CSTNode]) -> None:
self.sequence = sequence


class Append(SequenceExtension):
pass


class Prepend(SequenceExtension):
pass


class ReplaceNodes(cst.CSTTransformer):
"""
Replace nodes with their corresponding values in a given dict.
Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node, RemovalSentinel, or FlattenSentinel to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively.
"""

def __init__(
self,
replacements: dict[
cst.CSTNode,
Union[cst.CSTNode, cst.FlattenSentinel[cst.CSTNode], cst.RemovalSentinel],
dict[
str,
cst.CSTNode
| cst.FlattenSentinel
| cst.RemovalSentinel
| dict[str, Any],
],
],
):
self.replacements = replacements

def on_leave(self, original_node, updated_node):
if original_node in self.replacements.keys():
return self.replacements[original_node]
replacement = self.replacements[original_node]
match replacement:
case dict():
changes_dict = {}
for key, value in replacement.items():
match value:
case Prepend():
changes_dict[key] = value.sequence + [
*getattr(updated_node, key)
]

case Append():
changes_dict[key] = [
*getattr(updated_node, key)
] + value.sequence
case _:
changes_dict[key] = value
return updated_node.with_changes(**changes_dict)
case cst.CSTNode() | cst.RemovalSentinel() | cst.FlattenSentinel():
return replacement
return updated_node


Expand Down
56 changes: 2 additions & 54 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Optional, Tuple
from typing import Optional, Tuple
import libcst as cst
from libcst import SimpleWhitespace, matchers
from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor
Expand All @@ -20,7 +20,7 @@
from codemodder.codemods.transformations.remove_empty_string_concatenation import (
RemoveEmptyStringConcatenation,
)
from codemodder.codemods.utils import get_function_name_node
from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node
from codemodder.codemods.utils_mixin import NameResolutionMixin

parameter_token = "?"
Expand All @@ -30,58 +30,6 @@
literal = literal_number | literal_string


class SequenceExtension:
def __init__(self, sequence: list[cst.CSTNode]) -> None:
self.sequence = sequence


class Append(SequenceExtension):
pass


class Prepend(SequenceExtension):
pass


class ReplaceNodes(cst.CSTTransformer):
"""
Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively.
"""

def __init__(
self,
replacements: dict[
cst.CSTNode,
dict[str, cst.CSTNode | dict[str, Any]],
],
):
self.replacements = replacements

def on_leave(self, original_node, updated_node):
if original_node in self.replacements.keys():
replacement = self.replacements[original_node]
match replacement:
case dict():
changes_dict = {}
for key, value in replacement.items():
match value:
case Prepend():
changes_dict[key] = value.sequence + [
*getattr(updated_node, key)
]

case Append():
changes_dict[key] = [
*getattr(updated_node, key)
] + value.sequence
case _:
changes_dict[key] = value
return updated_node.with_changes(**changes_dict)
case cst.CSTNode():
return replacement
return updated_node


class SQLQueryParameterization(BaseCodemod, Codemod):
SUMMARY = "Parameterize SQL queries."
METADATA = CodemodMetadata(
Expand Down

0 comments on commit b24e489

Please sign in to comment.