Skip to content

Commit

Permalink
support empty call
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Oct 27, 2023
1 parent b5039db commit 3d9991d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
This codemod replaces the use of all unsafe and/or deprecated SSL/TLS versions
in the `ssl.SSLContext` constructor. It uses `PROTOCOL_TLS_CLIENT` instead,
which ensures a safe default TLS version.
which ensures a safe default TLS version. It also sets the `protocol` parameter
to `PROTOCOL_TLS_CLIENT` in calls without it, which is now deprecated.

Our change involves modifying the argument to `ssl.SSLContext()` to
use `PROTOCOL_TLS_CLIENT`.

```diff
import ssl
- context = ssl.SSLContext()
+ context = ssl.SSLContext(protocol=PROTOCOL_TLS_CLIENT)
- context = ssl.SSLContext(protocol=PROTOCOL_SSLv3)
+ context = ssl.SSLContext(protocol=PROTOCOL_TLS_CLIENT)
```
Expand Down
1 change: 1 addition & 0 deletions src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ rules:
- python
patterns:
- pattern-either:
- pattern: ssl.SSLContext()
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv2,...)
- pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv2,...)
- pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv3,...)
Expand Down
26 changes: 26 additions & 0 deletions src/core_codemods/upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Arg):
Change(line_number, self.CHANGE_DESCRIPTION)
)

if not updated_node.args:
return updated_node.with_changes(
args=[
self.make_new_arg(
self.PROTOCOL_KWARG_NAME,
f"ssl.{self.SAFE_TLS_PROTOCOL_VERSION}",
)
]
)

return updated_node.with_changes(
args=[
self.update_arg(arg)
Expand All @@ -79,3 +89,19 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Arg):
)

return updated_node

# dedupe with api
def make_new_arg(self, name, value, existing_arg=None):
equal = (
existing_arg.equal
if existing_arg
else cst.AssignEqual(
whitespace_before=cst.SimpleWhitespace(""),
whitespace_after=cst.SimpleWhitespace(""),
)
)
return cst.Arg(
keyword=cst.parse_expression(name),
value=cst.parse_expression(value),
equal=equal,
)
20 changes: 20 additions & 0 deletions tests/codemods/test_upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,23 @@ def test_upgrade_protocol_in_expression_do_not_modify(self, tmpdir):
expected_output = input_code

self.run_and_assert(tmpdir, input_code, expected_output)

def test_import_no_protocol(self, tmpdir):
input_code = f"""import ssl
context = ssl.SSLContext()
"""
expected_output = """import ssl
context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.skip()
def test_from_import_no_protocol(self, tmpdir):
input_code = f"""from ssl import SSLContext
SSLContext()
"""
expected_output = """from ssl import SSLContext
import ssl
SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

0 comments on commit 3d9991d

Please sign in to comment.