Skip to content

Commit

Permalink
handle no args and use api
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Oct 27, 2023
1 parent 3d9991d commit ad97dd1
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 121 deletions.
13 changes: 10 additions & 3 deletions src/codemodder/codemods/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,27 @@ def replace_args(self, original_node, args_info):
for arg in original_node.args:
arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info)
if arg_name is not None:
new = self.make_new_arg(arg_name, replacement_val, arg)
new = self.make_new_arg(replacement_val, arg_name, arg)
del args_info[idx]
else:
new = arg
new_args.append(new)

for arg_name, replacement_val, add_if_missing in args_info:
if add_if_missing:
new = self.make_new_arg(arg_name, replacement_val)
new = self.make_new_arg(replacement_val, arg_name)
new_args.append(new)

return new_args

def make_new_arg(self, name, value, existing_arg=None):
def make_new_arg(self, value, name=None, existing_arg=None):
if name is None:
# Make a positional argument
return cst.Arg(
value=cst.parse_expression(value),
)

# make a keyword argument
equal = (
existing_arg.equal
if existing_arg
Expand Down
22 changes: 0 additions & 22 deletions src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml

This file was deleted.

149 changes: 57 additions & 92 deletions src/core_codemods/upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,72 @@
import libcst as cst
from libcst.codemod import CodemodContext
from codemodder.codemods.base_visitor import BaseTransformer
from codemodder.codemods.base_codemod import (
SemgrepCodemod,
CodemodMetadata,
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.file_context import FileContext
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api.helpers import NewArg


class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer):
METADATA = CodemodMetadata(
DESCRIPTION="Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones.",
NAME="upgrade-sslcontext-tls",
REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW,
REFERENCES=[
{
"url": "https://docs.python.org/3/library/ssl.html#security-considerations",
"description": "",
},
{"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""},
{
"url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1",
"description": "",
},
],
)
class UpgradeSSLContextTLS(SemgrepCodemod):
NAME = "upgrade-sslcontext-tls"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
SUMMARY = "Upgrade TLS Version In SSLContext"
DESCRIPTION = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones."
CHANGE_DESCRIPTION = "Upgrade to use a safe version of TLS in SSLContext"
YAML_FILES = [
"upgrade_sslcontext_tls.yaml",
REFERENCES = [
{
"url": "https://docs.python.org/3/library/ssl.html#security-considerations",
"description": "",
},
{"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""},
{
"url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1",
"description": "",
},
]

# TODO: in the majority of cases, using PROTOCOL_TLS_CLIENT will be the
# right fix. However in some cases it will be appropriate to use
# PROTOCOL_TLS_SERVER instead. We currently don't have a good way to handle
# this. Eventually, when the platform supports parameters, we want to
# revisit this to provide PROTOCOL_TLS_SERVER as an alternative fix.
SAFE_TLS_PROTOCOL_VERSION = "PROTOCOL_TLS_CLIENT"
PROTOCOL_ARG_INDEX = 0
PROTOCOL_KWARG_NAME = "protocol"
SAFE_TLS_PROTOCOL_VERSION = "ssl.PROTOCOL_TLS_CLIENT"
# PROTOCOL_ARG_INDEX = 0
# PROTOCOL_KWARG_NAME = "protocol"

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
SemgrepCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, self._results)
@classmethod
def rule(cls):
return """
rules:
- patterns:
- pattern-inside: |
import ssl
...
- pattern-either:
- pattern: ssl.SSLContext()
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv2,...)
- pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv2,...)
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv3,...)
- pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv3,...)
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1,...)
- pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1,...)
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1_1,...)
- pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1_1,...)
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLS,...)
- pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLS,...)
"""

# TODO: apply unused import remover
def on_result_found(self, original_node, updated_node):
self.remove_unused_import(original_node)
self.add_needed_import("ssl")

def update_arg(self, arg: cst.Arg) -> cst.Arg:
new_name = cst.Name(self.SAFE_TLS_PROTOCOL_VERSION)
# TODO: are there other cases to handle here?
new_value = (
arg.value.with_changes(attr=new_name)
if isinstance(arg.value, cst.Attribute)
else new_name
)
return arg.with_changes(value=new_value)

def leave_Call(self, original_node: cst.Call, updated_node: cst.Arg):
pos_to_match = self.get_metadata(self.METADATA_DEPENDENCIES[0], original_node)
if self.filter_by_result(
pos_to_match
) and self.filter_by_path_includes_or_excludes(pos_to_match):
line_number = pos_to_match.start.line
self.file_context.codemod_changes.append(
Change(line_number, self.CHANGE_DESCRIPTION)
)

if not updated_node.args:
return updated_node.with_changes(
args=[
self.make_new_arg(
self.PROTOCOL_KWARG_NAME,
f"ssl.{self.SAFE_TLS_PROTOCOL_VERSION}",
)
]
)

return updated_node.with_changes(
args=[
self.update_arg(arg)
if idx == self.PROTOCOL_ARG_INDEX
or (arg.keyword and arg.keyword.value == self.PROTOCOL_KWARG_NAME)
else arg
for idx, arg in enumerate(original_node.args)
]
)

return updated_node

# dedupe with api
def make_new_arg(self, name, value, existing_arg=None):
equal = (
existing_arg.equal
if existing_arg
else cst.AssignEqual(
whitespace_before=cst.SimpleWhitespace(""),
whitespace_after=cst.SimpleWhitespace(""),
if len((args := original_node.args)) == 1 and args[0].keyword is None:
new_args = [self.make_new_arg(self.SAFE_TLS_PROTOCOL_VERSION)]
else:
new_args = self.replace_args(
original_node,
[
NewArg(
name="protocol",
value=self.SAFE_TLS_PROTOCOL_VERSION,
add_if_missing=True,
)
],
)
)
return cst.Arg(
keyword=cst.parse_expression(name),
value=cst.parse_expression(value),
equal=equal,
)
return self.update_arg_target(updated_node, new_args)
9 changes: 5 additions & 4 deletions tests/codemods/test_upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def test_upgrade_protocol_with_kwarg_import_alias(self, tmpdir, protocol):
var = "hello"
"""
expected_output = """import ssl as whatever
import ssl
context = whatever.SSLContext(protocol=whatever.PROTOCOL_TLS_CLIENT)
context = whatever.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
var = "hello"
"""
self.run_and_assert(tmpdir, input_code, expected_output)
Expand All @@ -116,21 +117,21 @@ def test_upgrade_protocol_in_expression_do_not_modify(self, tmpdir):
self.run_and_assert(tmpdir, input_code, expected_output)

def test_import_no_protocol(self, tmpdir):
input_code = f"""import ssl
input_code = """import ssl
context = ssl.SSLContext()
"""
expected_output = """import ssl
context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.skip()
def test_from_import_no_protocol(self, tmpdir):
input_code = f"""from ssl import SSLContext
input_code = """from ssl import SSLContext
SSLContext()
"""
expected_output = """from ssl import SSLContext
import ssl
SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

0 comments on commit ad97dd1

Please sign in to comment.