Skip to content

Commit

Permalink
Preserve kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Nov 15, 2023
1 parent 4048243 commit cce0b02
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 3 additions & 2 deletions integration_tests/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ class TestHardenPyyaml(BaseIntegrationTest):
codemod = HardenPyyaml
code_path = "tests/samples/unsafe_yaml.py"
original_code, expected_new_code = original_and_expected_from_code_path(
code_path, [(3, "deserialized_data = yaml.load(data, yaml.SafeLoader)\n")]
code_path,
[(3, "deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)\n")],
)
expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import yaml\n \n data = b"!!python/object/apply:subprocess.Popen \\\\n- ls"\n-deserialized_data = yaml.load(data, Loader=yaml.Loader)\n+deserialized_data = yaml.load(data, yaml.SafeLoader)\n'
expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import yaml\n \n data = b"!!python/object/apply:subprocess.Popen \\\\n- ls"\n-deserialized_data = yaml.load(data, Loader=yaml.Loader)\n+deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)\n'
expected_line_change = "4"
change_description = HardenPyyaml.CHANGE_DESCRIPTION
# expected exception because the yaml.SafeLoader protects against unsafe code
Expand Down
4 changes: 3 additions & 1 deletion src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def on_result_found(self, original_node, updated_node):
self.add_needed_import(self._module_name)
new_args = [
*updated_node.args[:1],
self.parse_expression(f"{maybe_name}.SafeLoader"),
updated_node.args[1].with_changes(
value=self.parse_expression(f"{maybe_name}.SafeLoader")
),
]
return self.update_arg_target(updated_node, new_args)
4 changes: 2 additions & 2 deletions tests/codemods/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader):

expected = """import yaml
data = b'!!python/object/apply:subprocess.Popen \\n- ls'
deserialized_data = yaml.load(data, yaml.SafeLoader)
deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)

Expand All @@ -57,7 +57,7 @@ def test_import_alias(self, tmpdir):
from yaml import Loader
data = b'!!python/object/apply:subprocess.Popen \\n- ls'
deserialized_data = yam.load(data, yam.SafeLoader)
deserialized_data = yam.load(data, Loader=yam.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)

Expand Down

0 comments on commit cce0b02

Please sign in to comment.