Skip to content

Commit

Permalink
Implement codemod for replacing xml with defusedxml
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Oct 17, 2023
1 parent 21ae8b2 commit 6df53cc
Show file tree
Hide file tree
Showing 7 changed files with 396 additions and 101 deletions.
2 changes: 1 addition & 1 deletion src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class CsvListAction(argparse.Action):
argparse Action to convert "a,b,c" into ["a", "b", "c"]
"""

def __call__(self, parser, namespace, values: str, option_string=None):
def __call__(self, parser, namespace, values, option_string=None):
# Conversion to dict removes duplicates while preserving order
items = list(dict.fromkeys(values.split(",")).keys())
self.validate_items(items)
Expand Down
109 changes: 109 additions & 0 deletions src/codemodder/codemods/imported_call_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import abc
from typing import Generic, Mapping, Sequence, Set, TypeVar, Union

import libcst as cst
from libcst import matchers
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider

from codemodder.change import Change
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext


# It seems to me like we actually want two separate bounds instead of a Union but this is what mypy wants
FunctionMatchType = TypeVar("FunctionMatchType", bound=Union[Mapping, Set])


class ImportedCallModifier(
Generic[FunctionMatchType],
VisitorBasedCodemodCommand,
NameResolutionMixin,
metaclass=abc.ABCMeta,
):
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(
self,
codemod_context: CodemodContext,
file_context: FileContext,
matching_functions: FunctionMatchType,
change_description: str,
):
super().__init__(codemod_context)
self.line_exclude = file_context.line_exclude
self.line_include = file_context.line_include
self.matching_functions: FunctionMatchType = matching_functions
self.change_description = change_description
self.changes_in_file: list[Mapping] = []

def updated_args(self, original_args: Sequence[cst.Arg]):
return original_args

@abc.abstractmethod
def update_attribute(
self,
true_name: str,
original_node: cst.Call,
updated_node: cst.Call,
new_args: Sequence[cst.Arg],
):
"""Callback to modify tree when the detected call is of the form a.call()"""

@abc.abstractmethod
def update_simple_name(
self,
true_name: str,
original_node: cst.Call,
updated_node: cst.Call,
new_args: Sequence[cst.Arg],
):
"""Callback to modify tree when the detected call is of the form call()"""

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
pos_to_match = self.node_position(original_node)
line_number = pos_to_match.start.line
if self.filter_by_path_includes_or_excludes(pos_to_match):
true_name = self.find_base_name(original_node.func)
if (
self._is_direct_call_from_imported_module(original_node)
and true_name
and true_name in self.matching_functions
):
self.changes_in_file.append(
Change(str(line_number), self.change_description).to_json()
)

new_args = self.updated_args(updated_node.args)

# has a prefix, e.g. a.call() -> a.new_call()
if matchers.matches(original_node.func, matchers.Attribute()):
return self.update_attribute(
true_name, original_node, updated_node, new_args
)

# it is a simple name, e.g. call() -> module.new_call()
return self.update_simple_name(
true_name, original_node, updated_node, new_args
)

return updated_node

def filter_by_path_includes_or_excludes(self, pos_to_match):
"""
Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags.
"""
# excludes takes precedence if defined
if self.line_exclude:
return not any(match_line(pos_to_match, line) for line in self.line_exclude)
if self.line_include:
return any(match_line(pos_to_match, line) for line in self.line_include)
return True

def node_position(self, node):
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
return self.get_metadata(PositionProvider, node)


def match_line(pos, line):
return pos.start.line == line and pos.end.line == line
145 changes: 53 additions & 92 deletions src/core_codemods/https_connection.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,55 @@
from typing import Sequence
from libcst import matchers
from typing import Sequence, Set

import libcst as cst
from libcst.codemod import Codemod, CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import (
PositionProvider,
)
from libcst.metadata import PositionProvider

from codemodder.codemods.base_codemod import (
BaseCodemod,
CodemodMetadata,
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext
import libcst as cst
from libcst.codemod import (
Codemod,
CodemodContext,
VisitorBasedCodemodCommand,
)
from codemodder.codemods.imported_call_modifier import ImportedCallModifier


class HTTPSConnectionModifier(ImportedCallModifier[Set[str]]):
def updated_args(self, original_args):
"""
Last argument _proxy_config does not match new method
We convert it to keyword
"""
new_args = list(original_args)
if self.count_positional_args(new_args) == 10:
new_args[9] = new_args[9].with_changes(
keyword=cst.parse_expression("_proxy_config")
)
return new_args

def update_attribute(self, true_name, original_node, updated_node, new_args):
del true_name, original_node
return updated_node.with_changes(
args=new_args,
func=updated_node.func.with_changes(
attr=cst.Name(value="HTTPSConnectionPool")
),
)

def update_simple_name(self, true_name, original_node, updated_node, new_args):
del true_name
AddImportsVisitor.add_needed_import(self.context, "urllib3")
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.parse_expression("urllib3.HTTPSConnectionPool"),
)

def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int:
for idx, arg in enumerate(arglist):
if arg.keyword:
return idx
return len(arglist)


class HTTPSConnection(BaseCodemod, Codemod):
Expand All @@ -43,7 +75,7 @@ class HTTPSConnection(BaseCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider,)

matching_functions = {
matching_functions: set[str] = {
"urllib3.HTTPConnectionPool",
"urllib3.connectionpool.HTTPConnectionPool",
}
Expand All @@ -53,83 +85,12 @@ def __init__(self, codemod_context: CodemodContext, *codemod_args):
BaseCodemod.__init__(self, *codemod_args)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
visitor = ConnectionPollVisitor(self.context, self.file_context)
visitor = HTTPSConnectionModifier(
self.context,
self.file_context,
self.matching_functions,
self.CHANGE_DESCRIPTION,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
return result_tree


class ConnectionPollVisitor(VisitorBasedCodemodCommand, NameResolutionMixin):
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
super().__init__(codemod_context)
self.line_exclude = file_context.line_exclude
self.line_include = file_context.line_include
self.changes_in_file: list[Change] = []

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
pos_to_match = self.node_position(original_node)
line_number = pos_to_match.start.line
if self.filter_by_path_includes_or_excludes(pos_to_match):
true_name = self.find_base_name(original_node.func)
if (
self._is_direct_call_from_imported_module(original_node)
and true_name in HTTPSConnection.matching_functions
):
self.changes_in_file.append(
Change(
str(line_number), HTTPSConnection.CHANGE_DESCRIPTION
).to_json()
)
# last argument _proxy_config does not match new method
# we convert it to keyword
new_args = list(original_node.args)
if count_positional_args(original_node.args) == 10:
new_args[9] = new_args[9].with_changes(
keyword=cst.parse_expression("_proxy_config")
)
# has a prefix, e.g. a.call() -> a.new_call()
if matchers.matches(original_node.func, matchers.Attribute()):
return updated_node.with_changes(
args=new_args,
func=updated_node.func.with_changes(
attr=cst.parse_expression("HTTPSConnectionPool")
),
)
# it is a simple name, e.g. call() -> module.new_call()
AddImportsVisitor.add_needed_import(self.context, "urllib3")
RemoveImportsVisitor.remove_unused_import_by_node(
self.context, original_node
)
return updated_node.with_changes(
args=new_args,
func=cst.parse_expression("urllib3.HTTPSConnectionPool"),
)
return updated_node

def filter_by_path_includes_or_excludes(self, pos_to_match):
"""
Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags.
"""
# excludes takes precedence if defined
if self.line_exclude:
return not any(match_line(pos_to_match, line) for line in self.line_exclude)
if self.line_include:
return any(match_line(pos_to_match, line) for line in self.line_include)
return True

def node_position(self, node):
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
return self.get_metadata(PositionProvider, node)


def match_line(pos, line):
return pos.start.line == line and pos.end.line == line


def count_positional_args(arglist: Sequence[cst.Arg]) -> int:
for i, arg in enumerate(arglist):
if arg.keyword:
return i
return len(arglist)
95 changes: 95 additions & 0 deletions src/core_codemods/use_defused_xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from functools import cached_property
from typing import Mapping

import libcst as cst
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.imported_call_modifier import ImportedCallModifier


class DefusedXmlModifier(ImportedCallModifier[Mapping[str, str]]):
def update_attribute(self, true_name, _, updated_node, new_args):
import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)

def update_simple_name(self, true_name, _, updated_node, new_args):
import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)


ETREE_METHODS = ["parse", "fromstring", "iterparse", "XMLParser"]
SAX_METHODS = ["parse", "make_parser", "parseString"]
DOM_METHODS = ["parse", "parseString"]
# TODO: add expat methods?


class UseDefusedXml(BaseCodemod):
NAME = "use-defusedxml"
SUMMARY = "Use `defusedxml` for Parsing XML"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW
DESCRIPTION = "Replace builtin xml method with safe defusedxml method"
REFERENCES = [
{
"url": "https://docs.python.org/3/library/xml.html#xml-vulnerabilities",
"description": "",
},
{
"url": "https://docs.python.org/3/library/xml.html#the-defusedxml-package",
"description": "",
},
{
"url": "https://pypi.org/project/defusedxml/",
"description": "",
},
{
"url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html",
"description": "",
},
]

@cached_property
def matching_functions(self) -> dict[str, str]:
"""Build a mapping of functions to their defusedxml imports"""
_matching_functions: dict[str, str] = {}
for module, defusedxml, methods in [
("xml.etree.ElementTree", "defusedxml.ElementTree", ETREE_METHODS),
("xml.etree.cElementTree", "defusedxml.ElementTree", ETREE_METHODS),
("xml.sax", "defusedxml.sax", SAX_METHODS),
("xml.dom.minidom", "defusedxml.minidom", DOM_METHODS),
("xml.dom.pulldom", "defusedxml.pulldom", DOM_METHODS),
]:
_matching_functions.update(
{f"{module}.{method}": defusedxml for method in methods}
)
return _matching_functions

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
visitor = DefusedXmlModifier(
self.context,
self.file_context,
self.matching_functions,
self.CHANGE_DESCRIPTION,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
if visitor.changes_in_file:
RemoveImportsVisitor.remove_unused_import_by_node(self.context, tree)
# TODO: add dependency

return result_tree
Loading

0 comments on commit 6df53cc

Please sign in to comment.