-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement codemod for replacing xml with defusedxml
- Loading branch information
Showing
6 changed files
with
394 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import abc | ||
from typing import Generic, Mapping, Sequence, Set, TypeVar | ||
|
||
import libcst as cst | ||
from libcst import matchers | ||
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand | ||
from libcst.metadata import PositionProvider | ||
|
||
from codemodder.change import Change | ||
from codemodder.codemods.utils_mixin import NameResolutionMixin | ||
from codemodder.file_context import FileContext | ||
|
||
|
||
FunctionMatchType = TypeVar("FunctionMatchType", Mapping, Set) | ||
|
||
|
||
class ImportedCallModifier( | ||
VisitorBasedCodemodCommand, | ||
NameResolutionMixin, | ||
Generic[FunctionMatchType], | ||
metaclass=abc.ABCMeta, | ||
): | ||
METADATA_DEPENDENCIES = (PositionProvider,) | ||
|
||
def __init__( | ||
self, | ||
codemod_context: CodemodContext, | ||
file_context: FileContext, | ||
matching_functions: FunctionMatchType, | ||
change_description: str, | ||
): | ||
super().__init__(codemod_context) | ||
self.line_exclude = file_context.line_exclude | ||
self.line_include = file_context.line_include | ||
self.matching_functions = matching_functions | ||
self.change_description = change_description | ||
self.changes_in_file = [] | ||
|
||
def updated_args(self, original_args: Sequence[cst.Arg]): | ||
return original_args | ||
|
||
@abc.abstractmethod | ||
def update_attribute( | ||
self, | ||
true_name: str, | ||
original_node: cst.Call, | ||
updated_node: cst.Call, | ||
new_args: Sequence[cst.Arg], | ||
): | ||
"""Callback to modify tree when the detected call is of the form a.call()""" | ||
|
||
@abc.abstractmethod | ||
def update_simple_name( | ||
self, | ||
true_name: str, | ||
original_node: cst.Call, | ||
updated_node: cst.Call, | ||
new_args: Sequence[cst.Arg], | ||
): | ||
"""Callback to modify tree when the detected call is of the form call()""" | ||
|
||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): | ||
pos_to_match = self.node_position(original_node) | ||
line_number = pos_to_match.start.line | ||
if self.filter_by_path_includes_or_excludes(pos_to_match): | ||
true_name = self.find_base_name(original_node.func) | ||
if ( | ||
self._is_direct_call_from_imported_module(original_node) | ||
and true_name | ||
and true_name in self.matching_functions | ||
): | ||
self.changes_in_file.append( | ||
Change(str(line_number), self.change_description).to_json() | ||
) | ||
|
||
new_args = self.updated_args(updated_node.args) | ||
|
||
# has a prefix, e.g. a.call() -> a.new_call() | ||
if matchers.matches(original_node.func, matchers.Attribute()): | ||
return self.update_attribute( | ||
true_name, original_node, updated_node, new_args | ||
) | ||
|
||
# it is a simple name, e.g. call() -> module.new_call() | ||
return self.update_simple_name( | ||
true_name, original_node, updated_node, new_args | ||
) | ||
|
||
return updated_node | ||
|
||
def filter_by_path_includes_or_excludes(self, pos_to_match): | ||
""" | ||
Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags. | ||
""" | ||
# excludes takes precedence if defined | ||
if self.line_exclude: | ||
return not any(match_line(pos_to_match, line) for line in self.line_exclude) | ||
if self.line_include: | ||
return any(match_line(pos_to_match, line) for line in self.line_include) | ||
return True | ||
|
||
def node_position(self, node): | ||
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 | ||
return self.get_metadata(PositionProvider, node) | ||
|
||
|
||
def match_line(pos, line): | ||
return pos.start.line == line and pos.end.line == line |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from functools import cached_property | ||
from typing import Mapping | ||
|
||
import libcst as cst | ||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | ||
|
||
from codemodder.codemods.base_codemod import ReviewGuidance | ||
from codemodder.codemods.api import BaseCodemod | ||
from codemodder.codemods.imported_call_modifier import ImportedCallModifier | ||
|
||
|
||
class DefusedXmlModifier(ImportedCallModifier[Mapping[str, str]]): | ||
def update_attribute(self, true_name, _, updated_node, new_args): | ||
import_name = self.matching_functions[true_name] | ||
AddImportsVisitor.add_needed_import(self.context, import_name) | ||
return updated_node.with_changes( | ||
args=new_args, | ||
func=cst.Attribute( | ||
value=cst.parse_expression(import_name), | ||
attr=cst.Name(value=true_name.split(".")[-1]), | ||
), | ||
) | ||
|
||
def update_simple_name(self, true_name, _, updated_node, new_args): | ||
import_name = self.matching_functions[true_name] | ||
AddImportsVisitor.add_needed_import(self.context, import_name) | ||
return updated_node.with_changes( | ||
args=new_args, | ||
func=cst.Attribute( | ||
value=cst.parse_expression(import_name), | ||
attr=cst.Name(value=true_name.split(".")[-1]), | ||
), | ||
) | ||
|
||
|
||
ETREE_METHODS = ["parse", "fromstring", "iterparse", "XMLParser"] | ||
SAX_METHODS = ["parse", "make_parser", "parseString"] | ||
DOM_METHODS = ["parse", "parseString"] | ||
# TODO: add expat methods? | ||
|
||
|
||
class UseDefusedXml(BaseCodemod): | ||
NAME = "use-defusedxml" | ||
SUMMARY = "Use `defusedxml` for Parsing XML" | ||
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW | ||
DESCRIPTION = "Replace builtin xml method with safe defusedxml method" | ||
REFERENCES = [ | ||
{ | ||
"url": "https://docs.python.org/3/library/xml.html#xml-vulnerabilities", | ||
"description": "", | ||
}, | ||
{ | ||
"url": "https://docs.python.org/3/library/xml.html#the-defusedxml-package", | ||
"description": "", | ||
}, | ||
{ | ||
"url": "https://pypi.org/project/defusedxml/", | ||
"description": "", | ||
}, | ||
{ | ||
"url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", | ||
"description": "", | ||
}, | ||
] | ||
|
||
@cached_property | ||
def matching_functions(self) -> dict[str, str]: | ||
"""Build a mapping of functions to their defusedxml imports""" | ||
_matching_functions: dict[str, str] = {} | ||
for module, defusedxml, methods in [ | ||
("xml.etree.ElementTree", "defusedxml.ElementTree", ETREE_METHODS), | ||
("xml.etree.cElementTree", "defusedxml.ElementTree", ETREE_METHODS), | ||
("xml.sax", "defusedxml.sax", SAX_METHODS), | ||
("xml.dom.minidom", "defusedxml.minidom", DOM_METHODS), | ||
("xml.dom.pulldom", "defusedxml.pulldom", DOM_METHODS), | ||
]: | ||
_matching_functions.update( | ||
{f"{module}.{method}": defusedxml for method in methods} | ||
) | ||
return _matching_functions | ||
|
||
def transform_module_impl(self, tree: cst.Module) -> cst.Module: | ||
visitor = DefusedXmlModifier( | ||
self.context, | ||
self.file_context, | ||
self.matching_functions, | ||
self.CHANGE_DESCRIPTION, | ||
) | ||
result_tree = visitor.transform_module(tree) | ||
self.file_context.codemod_changes.extend(visitor.changes_in_file) | ||
if visitor.changes_in_file: | ||
RemoveImportsVisitor.remove_unused_import_by_node(self.context, tree) | ||
# TODO: add dependency | ||
|
||
return result_tree |
Oops, something went wrong.