diff --git a/src/codemodder/cli.py b/src/codemodder/cli.py index 271dabe9..825a8d28 100644 --- a/src/codemodder/cli.py +++ b/src/codemodder/cli.py @@ -43,7 +43,7 @@ class CsvListAction(argparse.Action): argparse Action to convert "a,b,c" into ["a", "b", "c"] """ - def __call__(self, parser, namespace, values: str, option_string=None): + def __call__(self, parser, namespace, values, option_string=None): # Conversion to dict removes duplicates while preserving order items = list(dict.fromkeys(values.split(",")).keys()) self.validate_items(items) diff --git a/src/codemodder/codemods/imported_call_modifier.py b/src/codemodder/codemods/imported_call_modifier.py new file mode 100644 index 00000000..b27c87b5 --- /dev/null +++ b/src/codemodder/codemods/imported_call_modifier.py @@ -0,0 +1,109 @@ +import abc +from typing import Generic, Mapping, Sequence, Set, TypeVar, Union + +import libcst as cst +from libcst import matchers +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.metadata import PositionProvider + +from codemodder.change import Change +from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.file_context import FileContext + + +# It seems to me like we actually want two separate bounds instead of a Union but this is what mypy wants +FunctionMatchType = TypeVar("FunctionMatchType", bound=Union[Mapping, Set]) + + +class ImportedCallModifier( + Generic[FunctionMatchType], + VisitorBasedCodemodCommand, + NameResolutionMixin, + metaclass=abc.ABCMeta, +): + METADATA_DEPENDENCIES = (PositionProvider,) + + def __init__( + self, + codemod_context: CodemodContext, + file_context: FileContext, + matching_functions: FunctionMatchType, + change_description: str, + ): + super().__init__(codemod_context) + self.line_exclude = file_context.line_exclude + self.line_include = file_context.line_include + self.matching_functions: FunctionMatchType = matching_functions + self.change_description = change_description + self.changes_in_file: list[Mapping] = [] + + def updated_args(self, original_args: Sequence[cst.Arg]): + return original_args + + @abc.abstractmethod + def update_attribute( + self, + true_name: str, + original_node: cst.Call, + updated_node: cst.Call, + new_args: Sequence[cst.Arg], + ): + """Callback to modify tree when the detected call is of the form a.call()""" + + @abc.abstractmethod + def update_simple_name( + self, + true_name: str, + original_node: cst.Call, + updated_node: cst.Call, + new_args: Sequence[cst.Arg], + ): + """Callback to modify tree when the detected call is of the form call()""" + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): + pos_to_match = self.node_position(original_node) + line_number = pos_to_match.start.line + if self.filter_by_path_includes_or_excludes(pos_to_match): + true_name = self.find_base_name(original_node.func) + if ( + self._is_direct_call_from_imported_module(original_node) + and true_name + and true_name in self.matching_functions + ): + self.changes_in_file.append( + Change(str(line_number), self.change_description).to_json() + ) + + new_args = self.updated_args(updated_node.args) + + # has a prefix, e.g. a.call() -> a.new_call() + if matchers.matches(original_node.func, matchers.Attribute()): + return self.update_attribute( + true_name, original_node, updated_node, new_args + ) + + # it is a simple name, e.g. call() -> module.new_call() + return self.update_simple_name( + true_name, original_node, updated_node, new_args + ) + + return updated_node + + def filter_by_path_includes_or_excludes(self, pos_to_match): + """ + Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags. + """ + # excludes takes precedence if defined + if self.line_exclude: + return not any(match_line(pos_to_match, line) for line in self.line_exclude) + if self.line_include: + return any(match_line(pos_to_match, line) for line in self.line_include) + return True + + def node_position(self, node): + # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 + return self.get_metadata(PositionProvider, node) + + +def match_line(pos, line): + return pos.start.line == line and pos.end.line == line diff --git a/src/core_codemods/https_connection.py b/src/core_codemods/https_connection.py index 0ca170c8..6f846102 100644 --- a/src/core_codemods/https_connection.py +++ b/src/core_codemods/https_connection.py @@ -1,23 +1,55 @@ -from typing import Sequence -from libcst import matchers +from typing import Sequence, Set + +import libcst as cst +from libcst.codemod import Codemod, CodemodContext from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor -from libcst.metadata import ( - PositionProvider, -) +from libcst.metadata import PositionProvider + from codemodder.codemods.base_codemod import ( BaseCodemod, CodemodMetadata, ReviewGuidance, ) -from codemodder.change import Change -from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.file_context import FileContext -import libcst as cst -from libcst.codemod import ( - Codemod, - CodemodContext, - VisitorBasedCodemodCommand, -) +from codemodder.codemods.imported_call_modifier import ImportedCallModifier + + +class HTTPSConnectionModifier(ImportedCallModifier[Set[str]]): + def updated_args(self, original_args): + """ + Last argument _proxy_config does not match new method + + We convert it to keyword + """ + new_args = list(original_args) + if self.count_positional_args(new_args) == 10: + new_args[9] = new_args[9].with_changes( + keyword=cst.parse_expression("_proxy_config") + ) + return new_args + + def update_attribute(self, true_name, original_node, updated_node, new_args): + del true_name, original_node + return updated_node.with_changes( + args=new_args, + func=updated_node.func.with_changes( + attr=cst.Name(value="HTTPSConnectionPool") + ), + ) + + def update_simple_name(self, true_name, original_node, updated_node, new_args): + del true_name + AddImportsVisitor.add_needed_import(self.context, "urllib3") + RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) + return updated_node.with_changes( + args=new_args, + func=cst.parse_expression("urllib3.HTTPSConnectionPool"), + ) + + def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int: + for idx, arg in enumerate(arglist): + if arg.keyword: + return idx + return len(arglist) class HTTPSConnection(BaseCodemod, Codemod): @@ -43,7 +75,7 @@ class HTTPSConnection(BaseCodemod, Codemod): METADATA_DEPENDENCIES = (PositionProvider,) - matching_functions = { + matching_functions: set[str] = { "urllib3.HTTPConnectionPool", "urllib3.connectionpool.HTTPConnectionPool", } @@ -53,83 +85,12 @@ def __init__(self, codemod_context: CodemodContext, *codemod_args): BaseCodemod.__init__(self, *codemod_args) def transform_module_impl(self, tree: cst.Module) -> cst.Module: - visitor = ConnectionPollVisitor(self.context, self.file_context) + visitor = HTTPSConnectionModifier( + self.context, + self.file_context, + self.matching_functions, + self.CHANGE_DESCRIPTION, + ) result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) return result_tree - - -class ConnectionPollVisitor(VisitorBasedCodemodCommand, NameResolutionMixin): - METADATA_DEPENDENCIES = (PositionProvider,) - - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - super().__init__(codemod_context) - self.line_exclude = file_context.line_exclude - self.line_include = file_context.line_include - self.changes_in_file: list[Change] = [] - - def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): - pos_to_match = self.node_position(original_node) - line_number = pos_to_match.start.line - if self.filter_by_path_includes_or_excludes(pos_to_match): - true_name = self.find_base_name(original_node.func) - if ( - self._is_direct_call_from_imported_module(original_node) - and true_name in HTTPSConnection.matching_functions - ): - self.changes_in_file.append( - Change( - str(line_number), HTTPSConnection.CHANGE_DESCRIPTION - ).to_json() - ) - # last argument _proxy_config does not match new method - # we convert it to keyword - new_args = list(original_node.args) - if count_positional_args(original_node.args) == 10: - new_args[9] = new_args[9].with_changes( - keyword=cst.parse_expression("_proxy_config") - ) - # has a prefix, e.g. a.call() -> a.new_call() - if matchers.matches(original_node.func, matchers.Attribute()): - return updated_node.with_changes( - args=new_args, - func=updated_node.func.with_changes( - attr=cst.parse_expression("HTTPSConnectionPool") - ), - ) - # it is a simple name, e.g. call() -> module.new_call() - AddImportsVisitor.add_needed_import(self.context, "urllib3") - RemoveImportsVisitor.remove_unused_import_by_node( - self.context, original_node - ) - return updated_node.with_changes( - args=new_args, - func=cst.parse_expression("urllib3.HTTPSConnectionPool"), - ) - return updated_node - - def filter_by_path_includes_or_excludes(self, pos_to_match): - """ - Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags. - """ - # excludes takes precedence if defined - if self.line_exclude: - return not any(match_line(pos_to_match, line) for line in self.line_exclude) - if self.line_include: - return any(match_line(pos_to_match, line) for line in self.line_include) - return True - - def node_position(self, node): - # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 - return self.get_metadata(PositionProvider, node) - - -def match_line(pos, line): - return pos.start.line == line and pos.end.line == line - - -def count_positional_args(arglist: Sequence[cst.Arg]) -> int: - for i, arg in enumerate(arglist): - if arg.keyword: - return i - return len(arglist) diff --git a/src/core_codemods/use_defused_xml.py b/src/core_codemods/use_defused_xml.py new file mode 100644 index 00000000..a6db5563 --- /dev/null +++ b/src/core_codemods/use_defused_xml.py @@ -0,0 +1,95 @@ +from functools import cached_property +from typing import Mapping + +import libcst as cst +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor + +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import BaseCodemod +from codemodder.codemods.imported_call_modifier import ImportedCallModifier + + +class DefusedXmlModifier(ImportedCallModifier[Mapping[str, str]]): + def update_attribute(self, true_name, _, updated_node, new_args): + import_name = self.matching_functions[true_name] + AddImportsVisitor.add_needed_import(self.context, import_name) + return updated_node.with_changes( + args=new_args, + func=cst.Attribute( + value=cst.parse_expression(import_name), + attr=cst.Name(value=true_name.split(".")[-1]), + ), + ) + + def update_simple_name(self, true_name, _, updated_node, new_args): + import_name = self.matching_functions[true_name] + AddImportsVisitor.add_needed_import(self.context, import_name) + return updated_node.with_changes( + args=new_args, + func=cst.Attribute( + value=cst.parse_expression(import_name), + attr=cst.Name(value=true_name.split(".")[-1]), + ), + ) + + +ETREE_METHODS = ["parse", "fromstring", "iterparse", "XMLParser"] +SAX_METHODS = ["parse", "make_parser", "parseString"] +DOM_METHODS = ["parse", "parseString"] +# TODO: add expat methods? + + +class UseDefusedXml(BaseCodemod): + NAME = "use-defusedxml" + SUMMARY = "Use `defusedxml` for Parsing XML" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW + DESCRIPTION = "Replace builtin xml method with safe defusedxml method" + REFERENCES = [ + { + "url": "https://docs.python.org/3/library/xml.html#xml-vulnerabilities", + "description": "", + }, + { + "url": "https://docs.python.org/3/library/xml.html#the-defusedxml-package", + "description": "", + }, + { + "url": "https://pypi.org/project/defusedxml/", + "description": "", + }, + { + "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", + "description": "", + }, + ] + + @cached_property + def matching_functions(self) -> dict[str, str]: + """Build a mapping of functions to their defusedxml imports""" + _matching_functions: dict[str, str] = {} + for module, defusedxml, methods in [ + ("xml.etree.ElementTree", "defusedxml.ElementTree", ETREE_METHODS), + ("xml.etree.cElementTree", "defusedxml.ElementTree", ETREE_METHODS), + ("xml.sax", "defusedxml.sax", SAX_METHODS), + ("xml.dom.minidom", "defusedxml.minidom", DOM_METHODS), + ("xml.dom.pulldom", "defusedxml.pulldom", DOM_METHODS), + ]: + _matching_functions.update( + {f"{module}.{method}": defusedxml for method in methods} + ) + return _matching_functions + + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + visitor = DefusedXmlModifier( + self.context, + self.file_context, + self.matching_functions, + self.CHANGE_DESCRIPTION, + ) + result_tree = visitor.transform_module(tree) + self.file_context.codemod_changes.extend(visitor.changes_in_file) + if visitor.changes_in_file: + RemoveImportsVisitor.remove_unused_import_by_node(self.context, tree) + # TODO: add dependency + + return result_tree diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 0fbd4edd..7910d5bf 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -1,16 +1,18 @@ # pylint: disable=no-member,not-callable,attribute-defined-outside-init +from collections import defaultdict +import os +from pathlib import Path +from textwrap import dedent +from typing import ClassVar + import libcst as cst from libcst.codemod import CodemodContext -from pathlib import Path -import os -from collections import defaultdict +import mock + from codemodder.context import CodemodExecutionContext from codemodder.file_context import FileContext from codemodder.registry import CodemodRegistry, CodemodCollection from codemodder.semgrep import run as semgrep_run -from typing import ClassVar - -import mock class BaseCodemodTest: @@ -24,7 +26,7 @@ def run_and_assert(self, tmpdir, input_code, expected): self.run_and_assert_filepath(tmpdir, tmp_file_path, input_code, expected) def run_and_assert_filepath(self, root, file_path, input_code, expected): - input_tree = cst.parse_module(input_code) + input_tree = cst.parse_module(dedent(input_code)) self.execution_context = CodemodExecutionContext( directory=root, dry_run=True, @@ -45,7 +47,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): ) output_tree = command_instance.transform_module(input_tree) - assert output_tree.code == expected + assert output_tree.code == dedent(expected) class BaseSemgrepCodemodTest(BaseCodemodTest): diff --git a/tests/codemods/test_base_codemod.py b/tests/codemods/test_base_codemod.py index 65f2b590..3d71663b 100644 --- a/tests/codemods/test_base_codemod.py +++ b/tests/codemods/test_base_codemod.py @@ -1,6 +1,7 @@ import libcst as cst from libcst.codemod import Codemod, CodemodContext import mock + from codemodder.codemods.base_codemod import ( SemgrepCodemod, CodemodMetadata, diff --git a/tests/codemods/test_use_defused_xml.py b/tests/codemods/test_use_defused_xml.py new file mode 100644 index 00000000..2e9549e5 --- /dev/null +++ b/tests/codemods/test_use_defused_xml.py @@ -0,0 +1,127 @@ +import pytest + +from core_codemods.use_defused_xml import ( + DOM_METHODS, + ETREE_METHODS, + SAX_METHODS, + UseDefusedXml, +) +from tests.codemods.base_codemod_test import BaseCodemodTest + + +class TestUseDefusedXml(BaseCodemodTest): + codemod = UseDefusedXml + + @pytest.mark.parametrize("method", ETREE_METHODS) + @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) + def test_etree_simple_call(self, tmpdir, module, method): + original_code = f""" + from xml.etree.{module} import {method}, ElementPath + + et = {method}('some.xml') + """ + + new_code = f""" + from xml.etree.{module} import ElementPath + import defusedxml.ElementTree + + et = defusedxml.ElementTree.{method}('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + + @pytest.mark.parametrize("method", ETREE_METHODS) + @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) + def test_etree_attribute_call(self, tmpdir, module, method): + original_code = f""" + from xml.etree import {module} + + et = {module}.{method}('some.xml') + """ + + new_code = f""" + import defusedxml.ElementTree + + et = defusedxml.ElementTree.{method}('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + + def test_etree_elementtree_with_alias(self, tmpdir): + original_code = """ + from xml.etree import ElementTree as ET + + et = ET.parse('some.xml') + """ + + new_code = """ + import defusedxml.ElementTree + + et = defusedxml.ElementTree.parse('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + + def test_etree_parse_with_alias(self, tmpdir): + original_code = """ + from xml.etree.ElementTree import parse as parse_xml + + et = parse_xml('some.xml') + """ + + new_code = """ + import defusedxml.ElementTree + + et = defusedxml.ElementTree.parse('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + + @pytest.mark.parametrize("method", SAX_METHODS) + def test_sax_simple_call(self, tmpdir, method): + original_code = f""" + from xml.sax import {method} + + et = {method}('some.xml') + """ + + new_code = f""" + import defusedxml.sax + + et = defusedxml.sax.{method}('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + + @pytest.mark.parametrize("method", SAX_METHODS) + def test_sax_attribute_call(self, tmpdir, method): + original_code = f""" + from xml import sax + + et = sax.{method}('some.xml') + """ + + new_code = f""" + import defusedxml.sax + + et = defusedxml.sax.{method}('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + + @pytest.mark.parametrize("method", DOM_METHODS) + @pytest.mark.parametrize("module", ["minidom", "pulldom"]) + def test_dom_simple_call(self, tmpdir, module, method): + original_code = f""" + from xml.dom.{module} import {method} + + et = {method}('some.xml') + """ + + new_code = f""" + import defusedxml.{module} + + et = defusedxml.{module}.{method}('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code)