Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement codemod for replacing unsafe xml methods with defusedxml #76

Merged
merged 2 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class CsvListAction(argparse.Action):
argparse Action to convert "a,b,c" into ["a", "b", "c"]
"""

def __call__(self, parser, namespace, values: str, option_string=None):
def __call__(self, parser, namespace, values, option_string=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was mypy, right? Isn't it weird that we fix a typing problem by ... removing typing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I (wrongly) introduced this on a previous PR when mypy wasn't working for me. The underlying code itself is not typed correctly and I thought adding this would fix it (but it does not).

# Conversion to dict removes duplicates while preserving order
items = list(dict.fromkeys(values.split(",")).keys())
self.validate_items(items)
Expand Down
109 changes: 109 additions & 0 deletions src/codemodder/codemods/imported_call_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import abc
from typing import Generic, Mapping, Sequence, Set, TypeVar, Union

import libcst as cst
from libcst import matchers
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider

from codemodder.change import Change
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext


# It seems to me like we actually want two separate bounds instead of a Union but this is what mypy wants
FunctionMatchType = TypeVar("FunctionMatchType", bound=Union[Mapping, Set])


class ImportedCallModifier(
Generic[FunctionMatchType],
VisitorBasedCodemodCommand,
NameResolutionMixin,
metaclass=abc.ABCMeta,
):
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(
    self,
    codemod_context: CodemodContext,
    file_context: FileContext,
    matching_functions: FunctionMatchType,
    change_description: str,
):
    """
    Set up the modifier for one file.

    :param codemod_context: libcst context forwarded to the base command.
    :param file_context: supplies the ``line_exclude`` / ``line_include``
        lists used to filter calls by line.
    :param matching_functions: container of fully resolved function names;
        a call is rewritten only when its resolved name is ``in`` it.
    :param change_description: text stored with every recorded change.
    """
    super().__init__(codemod_context)
    self.line_exclude = file_context.line_exclude
    self.line_include = file_context.line_include
    self.matching_functions: FunctionMatchType = matching_functions
    self.change_description = change_description
    # Filled by leave_Call with the JSON form of each Change it records.
    self.changes_in_file: list[Mapping] = []

def updated_args(self, original_args: Sequence[cst.Arg]):
    """
    Hook for adjusting the arguments of a matched call.

    The default returns ``original_args`` unchanged; subclasses may
    override it to return a modified sequence, which leave_Call then
    passes on as ``new_args``.
    """
    return original_args

@abc.abstractmethod
def update_attribute(
    self,
    true_name: str,
    original_node: cst.Call,
    updated_node: cst.Call,
    new_args: Sequence[cst.Arg],
):
    """
    Callback to modify tree when the detected call is of the form a.call()

    :param true_name: resolved name of the called function, as returned
        by ``find_base_name``.
    :param original_node: the call node as it appeared before traversal.
    :param updated_node: the call node with already-visited children applied.
    :param new_args: result of ``updated_args(updated_node.args)``.
    :return: the node that leave_Call returns in place of the call.
    """

@abc.abstractmethod
def update_simple_name(
    self,
    true_name: str,
    original_node: cst.Call,
    updated_node: cst.Call,
    new_args: Sequence[cst.Arg],
):
    """
    Callback to modify tree when the detected call is of the form call()

    :param true_name: resolved name of the called function, as returned
        by ``find_base_name``.
    :param original_node: the call node as it appeared before traversal.
    :param updated_node: the call node with already-visited children applied.
    :param new_args: result of ``updated_args(updated_node.args)``.
    :return: the node that leave_Call returns in place of the call.
    """

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
pos_to_match = self.node_position(original_node)
line_number = pos_to_match.start.line
if self.filter_by_path_includes_or_excludes(pos_to_match):
drdavella marked this conversation as resolved.
Show resolved Hide resolved
true_name = self.find_base_name(original_node.func)
if (
self._is_direct_call_from_imported_module(original_node)
and true_name
and true_name in self.matching_functions
):
self.changes_in_file.append(
Change(str(line_number), self.change_description).to_json()
)

new_args = self.updated_args(updated_node.args)

# has a prefix, e.g. a.call() -> a.new_call()
if matchers.matches(original_node.func, matchers.Attribute()):
return self.update_attribute(
true_name, original_node, updated_node, new_args
)

# it is a simple name, e.g. call() -> module.new_call()
return self.update_simple_name(
true_name, original_node, updated_node, new_args
)
Comment on lines +85 to +88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a small clarification on this bit of code here.
Within the context of the https codemod this comment and logic were accurate. But this may not be general enough.
Consider:

from lib import call
call()

We have the option to simply change the call...

from lib import call
import lib2
lib2.call2()

or changing the call and import:

from lib2 import call2
call2()

The second has the advantage of not leaving a possibly useless import (not much of a problem since we have a codemod to amend that). Then again, the reason why I've favored the first one for the https codemod was because the change became more explicit.

The second transformation is a bit harder to do but within the realm of possibility. Just tell me if you want me to support that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andrecsilva that makes sense. It might be good to revisit at some point but right now I like that the updated call becomes completely explicit.


return updated_node

def filter_by_path_includes_or_excludes(self, pos_to_match):
"""
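Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-excludes flag. When no excludes are set, returns True only if it matches one of the lines specified in the path-includes flag. Returns True when neither flag is set.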
"""
# excludes takes precedence if defined
if self.line_exclude:
return not any(match_line(pos_to_match, line) for line in self.line_exclude)

Check warning on line 98 in src/codemodder/codemods/imported_call_modifier.py

View check run for this annotation

Codecov / codecov/patch

src/codemodder/codemods/imported_call_modifier.py#L98

Added line #L98 was not covered by tests
if self.line_include:
return any(match_line(pos_to_match, line) for line in self.line_include)

Check warning on line 100 in src/codemodder/codemods/imported_call_modifier.py

View check run for this annotation

Codecov / codecov/patch

src/codemodder/codemods/imported_call_modifier.py#L100

Added line #L100 was not covered by tests
return True

def node_position(self, node):
    """Look up the ``PositionProvider`` metadata recorded for *node*."""
    # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
    position = self.get_metadata(PositionProvider, node)
    return position


def match_line(pos, line):
    """Tell whether the range *pos* both starts and ends on *line*."""
    if pos.start.line != line:
        return False
    return pos.end.line == line

Check warning on line 109 in src/codemodder/codemods/imported_call_modifier.py

View check run for this annotation

Codecov / codecov/patch

src/codemodder/codemods/imported_call_modifier.py#L109

Added line #L109 was not covered by tests
189 changes: 68 additions & 121 deletions src/core_codemods/https_connection.py
Original file line number Diff line number Diff line change
@@ -1,135 +1,82 @@
from typing import Sequence
from libcst import matchers
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import (
PositionProvider,
)
from codemodder.codemods.base_codemod import (
BaseCodemod,
CodemodMetadata,
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext
from typing import Sequence, Set

import libcst as cst
from libcst.codemod import (
Codemod,
CodemodContext,
VisitorBasedCodemodCommand,
)
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import PositionProvider

from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.imported_call_modifier import ImportedCallModifier

class HTTPSConnection(BaseCodemod, Codemod):
METADATA = CodemodMetadata(
DESCRIPTION="Enforce HTTPS connection for `urllib3`.",
NAME="https-connection",
REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW,
REFERENCES=[
{
"url": "https://owasp.org/www-community/vulnerabilities/Insecure_Transport",
"description": "",
},
{
"url": "https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool",
"description": "",
},
],
)
CHANGE_DESCRIPTION = METADATA.DESCRIPTION
SUMMARY = (
"Changes HTTPConnectionPool to HTTPSConnectionPool to Enforce Secure Connection"
)

class HTTPSConnectionModifier(ImportedCallModifier[Set[str]]):
def updated_args(self, original_args):
    """
    Return the argument list to use for the rewritten call.

    Last argument _proxy_config does not match new method

    We convert it to keyword

    :param original_args: arguments of the matched call.
    :return: a new list; identical to the input unless the call has
        exactly ten positional arguments, in which case the tenth is
        turned into the ``_proxy_config`` keyword argument.
    """
    new_args = list(original_args)
    if self.count_positional_args(new_args) == 10:
        # Build the keyword node directly, as update_attribute does for
        # its name, rather than parsing it from a string.
        new_args[9] = new_args[9].with_changes(
            keyword=cst.Name(value="_proxy_config")
        )
    return new_args

def update_attribute(self, true_name, original_node, updated_node, new_args):
    """Rewrite ``prefix.call(...)`` so the called attribute is ``HTTPSConnectionPool``."""
    del true_name, original_node
    secure_attr = cst.Name(value="HTTPSConnectionPool")
    # Keep whatever prefix the call used; only the attribute name changes.
    secure_func = updated_node.func.with_changes(attr=secure_attr)
    return updated_node.with_changes(args=new_args, func=secure_func)

def update_simple_name(self, true_name, original_node, updated_node, new_args):
    """Rewrite a bare ``call(...)`` into ``urllib3.HTTPSConnectionPool(...)``."""
    del true_name
    context = self.context
    # The rewritten call is module-qualified, so ``urllib3`` itself must be
    # imported and the import the original call relied on may become unused.
    AddImportsVisitor.add_needed_import(context, "urllib3")
    RemoveImportsVisitor.remove_unused_import_by_node(context, original_node)
    qualified_func = cst.parse_expression("urllib3.HTTPSConnectionPool")
    return updated_node.with_changes(args=new_args, func=qualified_func)

def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int:
for idx, arg in enumerate(arglist):
if arg.keyword:
return idx

Check warning on line 47 in src/core_codemods/https_connection.py

View check run for this annotation

Codecov / codecov/patch

src/core_codemods/https_connection.py#L47

Added line #L47 was not covered by tests
return len(arglist)


class HTTPSConnection(BaseCodemod):
SUMMARY = "Enforce HTTPS Connection for `urllib3`"
NAME = "https-connection"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
REFERENCES = [
{
"url": "https://owasp.org/www-community/vulnerabilities/Insecure_Transport",
"description": "",
},
{
"url": "https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool",
"description": "",
},
]

METADATA_DEPENDENCIES = (PositionProvider,)

matching_functions = {
matching_functions: set[str] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could have this be a keys-only dict. It would feel more consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not aware of any such construct; do you just mean a dict where the values are None?

"urllib3.HTTPConnectionPool",
"urllib3.connectionpool.HTTPConnectionPool",
}

def __init__(self, codemod_context: CodemodContext, *codemod_args):
Codemod.__init__(self, codemod_context)
BaseCodemod.__init__(self, *codemod_args)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
visitor = ConnectionPollVisitor(self.context, self.file_context)
visitor = HTTPSConnectionModifier(
self.context,
self.file_context,
self.matching_functions,
self.CHANGE_DESCRIPTION,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
return result_tree


class ConnectionPollVisitor(VisitorBasedCodemodCommand, NameResolutionMixin):
    """
    Rewrite direct ``HTTPConnectionPool`` calls to ``HTTPSConnectionPool``.

    A call is rewritten when its resolved name is listed in
    ``HTTPSConnection.matching_functions`` and its position passes the
    line include/exclude filter. Every rewrite is recorded in
    ``changes_in_file``.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
        """Capture the line filters of *file_context* and start with no changes."""
        super().__init__(codemod_context)
        self.line_exclude = file_context.line_exclude
        self.line_include = file_context.line_include
        # Entries are the output of Change.to_json(), not Change instances.
        self.changes_in_file: list = []

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
        """
        Replace a matching call and record the change.

        :param original_node: the call as parsed; used for position lookup
            and name resolution.
        :param updated_node: the call with already-visited children applied;
            used as the base of the returned node.
        :return: the rewritten call, or ``updated_node`` untouched when the
            call is filtered out or does not match.
        """
        pos_to_match = self.node_position(original_node)
        if not self.filter_by_path_includes_or_excludes(pos_to_match):
            return updated_node

        true_name = self.find_base_name(original_node.func)
        if not (
            self._is_direct_call_from_imported_module(original_node)
            and true_name in HTTPSConnection.matching_functions
        ):
            return updated_node

        line_number = pos_to_match.start.line
        self.changes_in_file.append(
            Change(str(line_number), HTTPSConnection.CHANGE_DESCRIPTION).to_json()
        )

        # Start from updated_node: taking the arguments from original_node
        # would throw away rewrites already applied to nested nodes.
        new_args = list(updated_node.args)
        # last argument _proxy_config does not match new method
        # we convert it to keyword
        if count_positional_args(new_args) == 10:
            new_args[9] = new_args[9].with_changes(
                keyword=cst.Name(value="_proxy_config")
            )

        # has a prefix, e.g. a.call() -> a.new_call()
        if matchers.matches(original_node.func, matchers.Attribute()):
            return updated_node.with_changes(
                args=new_args,
                func=updated_node.func.with_changes(
                    attr=cst.Name(value="HTTPSConnectionPool")
                ),
            )

        # it is a simple name, e.g. call() -> module.new_call()
        AddImportsVisitor.add_needed_import(self.context, "urllib3")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
        return updated_node.with_changes(
            args=new_args,
            func=cst.parse_expression("urllib3.HTTPSConnectionPool"),
        )

    def filter_by_path_includes_or_excludes(self, pos_to_match):
        """
        Decide whether the node at ``pos_to_match`` should be processed.

        Returns False when the position matches one of the excluded lines.
        When no excludes are set, returns True only if it matches one of
        the included lines. Returns True when neither list is set.
        """
        # excludes takes precedence if defined
        if self.line_exclude:
            return not any(match_line(pos_to_match, line) for line in self.line_exclude)
        if self.line_include:
            return any(match_line(pos_to_match, line) for line in self.line_include)
        return True

    def node_position(self, node):
        """Look up the ``PositionProvider`` metadata recorded for *node*."""
        # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
        return self.get_metadata(PositionProvider, node)


def match_line(pos, line):
    """Report whether the range *pos* is confined to the single line *line*."""
    return pos.end.line == line if pos.start.line == line else False


def count_positional_args(arglist: Sequence[cst.Arg]) -> int:
    """Count the arguments that precede the first keyword argument.

    Returns the index of the first argument whose ``keyword`` is set, or
    the length of *arglist* when no argument has one.
    """
    first_keyword = next(
        (position for position, arg in enumerate(arglist) if arg.keyword),
        None,
    )
    return len(arglist) if first_keyword is None else first_keyword
Loading