diff --git a/src/codemodder/codemods/api/helpers.py b/src/codemodder/codemods/api/helpers.py index 93073810..590725c1 100644 --- a/src/codemodder/codemods/api/helpers.py +++ b/src/codemodder/codemods/api/helpers.py @@ -67,7 +67,7 @@ def replace_args(self, original_node, args_info): for arg in original_node.args: arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) if arg_name is not None: - new = self.make_new_arg(arg_name, replacement_val, arg) + new = self.make_new_arg(replacement_val, arg_name, arg) del args_info[idx] else: new = arg @@ -75,12 +75,19 @@ def replace_args(self, original_node, args_info): for arg_name, replacement_val, add_if_missing in args_info: if add_if_missing: - new = self.make_new_arg(arg_name, replacement_val) + new = self.make_new_arg(replacement_val, arg_name) new_args.append(new) return new_args - def make_new_arg(self, name, value, existing_arg=None): + def make_new_arg(self, value, name=None, existing_arg=None): + if name is None: + # Make a positional argument + return cst.Arg( + value=cst.parse_expression(value), + ) + + # make a keyword argument equal = ( existing_arg.equal if existing_arg diff --git a/src/core_codemods/docs/pixee_python_upgrade-sslcontext-tls.md b/src/core_codemods/docs/pixee_python_upgrade-sslcontext-tls.md index 11212b8d..7e94f9f2 100644 --- a/src/core_codemods/docs/pixee_python_upgrade-sslcontext-tls.md +++ b/src/core_codemods/docs/pixee_python_upgrade-sslcontext-tls.md @@ -1,12 +1,15 @@ This codemod replaces the use of all unsafe and/or deprecated SSL/TLS versions in the `ssl.SSLContext` constructor. It uses `PROTOCOL_TLS_CLIENT` instead, -which ensures a safe default TLS version. +which ensures a safe default TLS version. It also sets the `protocol` parameter +to `PROTOCOL_TLS_CLIENT` in calls without it, which is now deprecated. Our change involves modifying the argument to `ssl.SSLContext()` to use `PROTOCOL_TLS_CLIENT`. ```diff import ssl +- context = ssl.SSLContext() ++ context = ssl.SSLContext(protocol=PROTOCOL_TLS_CLIENT) - context = ssl.SSLContext(protocol=PROTOCOL_SSLv3) + context = ssl.SSLContext(protocol=PROTOCOL_TLS_CLIENT) ``` diff --git a/src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml b/src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml deleted file mode 100644 index f31d6d66..00000000 --- a/src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml +++ /dev/null @@ -1,21 +0,0 @@ -rules: - - id: upgrade-sslcontext-tls - message: Upgrade weak SSL/TLS protocol version in SSLContext - severity: WARNING - languages: - - python - patterns: - - pattern-either: - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv2,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv2,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv3,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv3,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1_1,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1_1,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLS,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLS,...) - - pattern-inside: | - import ssl - ... diff --git a/src/core_codemods/upgrade_sslcontext_tls.py b/src/core_codemods/upgrade_sslcontext_tls.py index 935500d3..12187838 100644 --- a/src/core_codemods/upgrade_sslcontext_tls.py +++ b/src/core_codemods/upgrade_sslcontext_tls.py @@ -1,36 +1,24 @@ -import libcst as cst -from libcst.codemod import CodemodContext -from codemodder.codemods.base_visitor import BaseTransformer -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, - ReviewGuidance, -) -from codemodder.change import Change -from codemodder.file_context import FileContext +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg -class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer): - METADATA = CodemodMetadata( - DESCRIPTION="Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones.", - NAME="upgrade-sslcontext-tls", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, - REFERENCES=[ - { - "url": "https://docs.python.org/3/library/ssl.html#security-considerations", - "description": "", - }, - {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, - { - "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", - "description": "", - }, - ], - ) +class UpgradeSSLContextTLS(SemgrepCodemod): + NAME = "upgrade-sslcontext-tls" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW SUMMARY = "Upgrade TLS Version In SSLContext" + DESCRIPTION = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones." CHANGE_DESCRIPTION = "Upgrade to use a safe version of TLS in SSLContext" - YAML_FILES = [ - "upgrade_sslcontext_tls.yaml", + REFERENCES = [ + { + "url": "https://docs.python.org/3/library/ssl.html#security-considerations", + "description": "", + }, + {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, + { + "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", + "description": "", + }, ] # TODO: in the majority of cases, using PROTOCOL_TLS_CLIENT will be the @@ -38,44 +26,45 @@ class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer): # PROTOCOL_TLS_SERVER instead. We currently don't have a good way to handle # this. Eventually, when the platform supports parameters, we want to # revisit this to provide PROTOCOL_TLS_SERVER as an alternative fix. - SAFE_TLS_PROTOCOL_VERSION = "PROTOCOL_TLS_CLIENT" - PROTOCOL_ARG_INDEX = 0 - PROTOCOL_KWARG_NAME = "protocol" + SAFE_TLS_PROTOCOL_VERSION = "ssl.PROTOCOL_TLS_CLIENT" - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - SemgrepCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, self._results) + @classmethod + def rule(cls): + return """ + rules: + - patterns: + - pattern-inside: | + import ssl + ... + - pattern-either: + - pattern: ssl.SSLContext() + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv2,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv2,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv3,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv3,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1_1,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1_1,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLS,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLS,...) + """ - # TODO: apply unused import remover + def on_result_found(self, original_node, updated_node): + self.remove_unused_import(original_node) + self.add_needed_import("ssl") - def update_arg(self, arg: cst.Arg) -> cst.Arg: - new_name = cst.Name(self.SAFE_TLS_PROTOCOL_VERSION) - # TODO: are there other cases to handle here? - new_value = ( - arg.value.with_changes(attr=new_name) - if isinstance(arg.value, cst.Attribute) - else new_name - ) - return arg.with_changes(value=new_value) - - def leave_Call(self, original_node: cst.Call, updated_node: cst.Arg): - pos_to_match = self.get_metadata(self.METADATA_DEPENDENCIES[0], original_node) - if self.filter_by_result( - pos_to_match - ) and self.filter_by_path_includes_or_excludes(pos_to_match): - line_number = pos_to_match.start.line - self.file_context.codemod_changes.append( - Change(line_number, self.CHANGE_DESCRIPTION) - ) - - return updated_node.with_changes( - args=[ - self.update_arg(arg) - if idx == self.PROTOCOL_ARG_INDEX - or (arg.keyword and arg.keyword.value == self.PROTOCOL_KWARG_NAME) - else arg - for idx, arg in enumerate(original_node.args) - ] + if len((args := original_node.args)) == 1 and args[0].keyword is None: + new_args = [self.make_new_arg(self.SAFE_TLS_PROTOCOL_VERSION)] + else: + new_args = self.replace_args( + original_node, + [ + NewArg( + name="protocol", + value=self.SAFE_TLS_PROTOCOL_VERSION, + add_if_missing=True, + ) + ], ) - - return updated_node + return self.update_arg_target(updated_node, new_args) diff --git a/tests/codemods/test_upgrade_sslcontext_tls.py b/tests/codemods/test_upgrade_sslcontext_tls.py index eec1c9a5..caf5cfe4 100644 --- a/tests/codemods/test_upgrade_sslcontext_tls.py +++ b/tests/codemods/test_upgrade_sslcontext_tls.py @@ -89,8 +89,9 @@ def test_upgrade_protocol_with_kwarg_import_alias(self, tmpdir, protocol): var = "hello" """ expected_output = """import ssl as whatever +import ssl -context = whatever.SSLContext(protocol=whatever.PROTOCOL_TLS_CLIENT) +context = whatever.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) var = "hello" """ self.run_and_assert(tmpdir, input_code, expected_output) @@ -114,3 +115,23 @@ def test_upgrade_protocol_in_expression_do_not_modify(self, tmpdir): expected_output = input_code self.run_and_assert(tmpdir, input_code, expected_output) + + def test_import_no_protocol(self, tmpdir): + input_code = """import ssl +context = ssl.SSLContext() +""" + expected_output = """import ssl +context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_from_import_no_protocol(self, tmpdir): + input_code = """from ssl import SSLContext +SSLContext() +""" + expected_output = """from ssl import SSLContext +import ssl + +SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) +""" + self.run_and_assert(tmpdir, input_code, expected_output)