Skip to content

Commit

Permalink
Fixed alias case for harden-pyyaml
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Oct 26, 2023
1 parent 9f53a04 commit 3db7b2d
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 7 deletions.
60 changes: 59 additions & 1 deletion src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,71 @@
import libcst as cst
from libcst import MetadataDependent, matchers
from libcst.helpers import get_full_name_for_node
from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider
from libcst.metadata import (
Assignment,
BaseAssignment,
ImportAssignment,
Scope,
ScopeProvider,
)
from libcst.metadata.scope_provider import GlobalScope


class NameResolutionMixin(MetadataDependent):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,)

def _iterate_scope_ancestors(self, scope: Scope):
"""
Iterate over all the ancestors of scope. Includes self.
"""
yield scope
while not isinstance(scope, GlobalScope):
scope = scope.parent
yield scope

def find_import_alias_in_nodes_scope_from_name(self, name, node):
maybe_tuple = self.find_import_in_nodes_scope_from_name(name, node)
if maybe_tuple:
return self._get_assigned_name_for_import(maybe_tuple[1])
return None

def _get_assigned_name_for_import(self, alias: cst.ImportAlias) -> str:
# "" shouldn't happen in any way, sanity check
if alias.asname:
match name := alias.asname.name:
case cst.Name():
return name.value
case _:
return ""
return get_full_name_for_node(alias.name.value) or ""

def find_import_in_nodes_scope_from_name(
self, name: str, node: cst.CSTNode
) -> Optional[Tuple[cst.ImportFrom | cst.Import, cst.ImportAlias]]:
"""
Given a node, find the earliest import node and alias for the given name in its scope.
"""
for scope in self._iterate_scope_ancestors(
self.get_metadata(ScopeProvider, node)
):
all_import_assignments = filter(
lambda x: isinstance(x, ImportAssignment),
reversed(list(scope.assignments)),
)
unique_import_nodes = {ia.node: None for ia in all_import_assignments}
for imp in unique_import_nodes.keys():
match imp:
case cst.Import():
for ia in imp.names:
if get_full_name_for_node(ia.name) == name:
return (imp, ia)
case cst.ImportFrom():
if not isinstance(imp.names, cst.ImportStar):
for ia in imp.names:
if get_full_name_for_node(ia.name) == name:
return (imp, ia)
return None

def find_base_name(self, node):
"""
Given a node, solve its name to its basest form. For now it can only solve names that are imported. For example, in what follows, the base name for exec.capitalize() is sys.executable.capitalize.
Expand Down
16 changes: 13 additions & 3 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class HardenPyyaml(SemgrepCodemod):
class HardenPyyaml(SemgrepCodemod, NameResolutionMixin):
NAME = "harden-pyyaml"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Use SafeLoader in `yaml.load()` Calls"
Expand Down Expand Up @@ -44,6 +45,15 @@ def rule(cls):
"""

def on_result_found(self, _, updated_node):
new_args = [*updated_node.args[:1], self.parse_expression("yaml.SafeLoader")]
def on_result_found(self, original_node, updated_node):
maybe_name = self.find_import_alias_in_nodes_scope_from_name(
"yaml", original_node
)
if not maybe_name:
self.add_needed_import("yaml")
maybe_name = maybe_name or "yaml"
new_args = [
*updated_node.args[:1],
self.parse_expression(f"{maybe_name}.SafeLoader"),
]
return self.update_arg_target(updated_node, new_args)
7 changes: 4 additions & 3 deletions tests/codemods/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,17 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader):
"""
self.run_and_assert(tmpdir, input_code, expected)

@pytest.mark.skip()
def test_import_alias(self, tmpdir):
input_code = """import yaml as yam
from yaml import Loader
data = b'!!python/object/apply:subprocess.Popen \\n- ls'
deserialized_data = yam.load(data, Loader=Loader)
"""
expected = """import yaml
expected = """import yaml as yam
from yaml import Loader
data = b'!!python/object/apply:subprocess.Popen \\n- ls'
deserialized_data = yaml.load(data, yaml.SafeLoader)
deserialized_data = yam.load(data, yam.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)

0 comments on commit 3db7b2d

Please sign in to comment.