diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 2bd2107d2..501d6c1a6 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -2,13 +2,71 @@ import libcst as cst from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node -from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider +from libcst.metadata import ( + Assignment, + BaseAssignment, + ImportAssignment, + Scope, + ScopeProvider, +) from libcst.metadata.scope_provider import GlobalScope class NameResolutionMixin(MetadataDependent): METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,) + def _iterate_scope_ancestors(self, scope: Scope): + """ + Iterate over all the ancestors of scope. Includes self. + """ + yield scope + while not isinstance(scope, GlobalScope): + scope = scope.parent + yield scope + + def find_import_alias_in_nodes_scope_from_name(self, name, node): + maybe_tuple = self.find_import_in_nodes_scope_from_name(name, node) + if maybe_tuple: + return self._get_assigned_name_for_import(maybe_tuple[1]) + return None + + def _get_assigned_name_for_import(self, alias: cst.ImportAlias) -> str: + # "" shouldn't happen in any way, sanity check + if alias.asname: + match name := alias.asname.name: + case cst.Name(): + return name.value + case _: + return "" + return get_full_name_for_node(alias.name.value) or "" + + def find_import_in_nodes_scope_from_name( + self, name: str, node: cst.CSTNode + ) -> Optional[Tuple[cst.ImportFrom | cst.Import, cst.ImportAlias]]: + """ + Given a node, find the earliest import node and alias for the given name in its scope. + """ + for scope in self._iterate_scope_ancestors( + self.get_metadata(ScopeProvider, node) + ): + all_import_assignments = filter( + lambda x: isinstance(x, ImportAssignment), + reversed(list(scope.assignments)), + ) + unique_import_nodes = {ia.node: None for ia in all_import_assignments} + for imp in unique_import_nodes.keys(): + match imp: + case cst.Import(): + for ia in imp.names: + if get_full_name_for_node(ia.name) == name: + return (imp, ia) + case cst.ImportFrom(): + if not isinstance(imp.names, cst.ImportStar): + for ia in imp.names: + if get_full_name_for_node(ia.name) == name: + return (imp, ia) + return None + def find_base_name(self, node): """ Given a node, solve its name to its basest form. For now it can only solve names that are imported. For example, in what follows, the base name for exec.capitalize() is sys.executable.capitalize. diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 3e0991900..6005a16f0 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -1,8 +1,9 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.utils_mixin import NameResolutionMixin -class HardenPyyaml(SemgrepCodemod): +class HardenPyyaml(SemgrepCodemod, NameResolutionMixin): NAME = "harden-pyyaml" REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW SUMMARY = "Use SafeLoader in `yaml.load()` Calls" @@ -44,6 +45,15 @@ def rule(cls): """ - def on_result_found(self, _, updated_node): - new_args = [*updated_node.args[:1], self.parse_expression("yaml.SafeLoader")] + def on_result_found(self, original_node, updated_node): + maybe_name = self.find_import_alias_in_nodes_scope_from_name( + "yaml", original_node + ) + if not maybe_name: + self.add_needed_import("yaml") + maybe_name = maybe_name or "yaml" + new_args = [ + *updated_node.args[:1], + self.parse_expression(f"{maybe_name}.SafeLoader"), + ] return self.update_arg_target(updated_node, new_args) diff --git a/tests/codemods/test_harden_pyyaml.py b/tests/codemods/test_harden_pyyaml.py index 7f745f930..0e3e47f3a 100644 --- a/tests/codemods/test_harden_pyyaml.py +++ b/tests/codemods/test_harden_pyyaml.py @@ -46,7 +46,6 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): """ self.run_and_assert(tmpdir, input_code, expected) - @pytest.mark.skip() def test_import_alias(self, tmpdir): input_code = """import yaml as yam from yaml import Loader @@ -54,8 +53,10 @@ def test_import_alias(self, tmpdir): data = b'!!python/object/apply:subprocess.Popen \\n- ls' deserialized_data = yam.load(data, Loader=Loader) """ - expected = """import yaml + expected = """import yaml as yam +from yaml import Loader + data = b'!!python/object/apply:subprocess.Popen \\n- ls' -deserialized_data = yaml.load(data, yaml.SafeLoader) +deserialized_data = yam.load(data, yam.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected)