diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 501d6c1a..0646c935 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -6,7 +6,6 @@ Assignment, BaseAssignment, ImportAssignment, - Scope, ScopeProvider, ) from libcst.metadata.scope_provider import GlobalScope @@ -15,58 +14,6 @@ class NameResolutionMixin(MetadataDependent): METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,) - def _iterate_scope_ancestors(self, scope: Scope): - """ - Iterate over all the ancestors of scope. Includes self. - """ - yield scope - while not isinstance(scope, GlobalScope): - scope = scope.parent - yield scope - - def find_import_alias_in_nodes_scope_from_name(self, name, node): - maybe_tuple = self.find_import_in_nodes_scope_from_name(name, node) - if maybe_tuple: - return self._get_assigned_name_for_import(maybe_tuple[1]) - return None - - def _get_assigned_name_for_import(self, alias: cst.ImportAlias) -> str: - # "" shouldn't happen in any way, sanity check - if alias.asname: - match name := alias.asname.name: - case cst.Name(): - return name.value - case _: - return "" - return get_full_name_for_node(alias.name.value) or "" - - def find_import_in_nodes_scope_from_name( - self, name: str, node: cst.CSTNode - ) -> Optional[Tuple[cst.ImportFrom | cst.Import, cst.ImportAlias]]: - """ - Given a node, find the earliest import node and alias for the given name in its scope. - """ - for scope in self._iterate_scope_ancestors( - self.get_metadata(ScopeProvider, node) - ): - all_import_assignments = filter( - lambda x: isinstance(x, ImportAssignment), - reversed(list(scope.assignments)), - ) - unique_import_nodes = {ia.node: None for ia in all_import_assignments} - for imp in unique_import_nodes.keys(): - match imp: - case cst.Import(): - for ia in imp.names: - if get_full_name_for_node(ia.name) == name: - return (imp, ia) - case cst.ImportFrom(): - if not isinstance(imp.names, cst.ImportStar): - for ia in imp.names: - if get_full_name_for_node(ia.name) == name: - return (imp, ia) - return None - def find_base_name(self, node): """ Given a node, solve its name to its basest form. For now it can only solve names that are imported. For example, in what follows, the base name for exec.capitalize() is sys.executable.capitalize. @@ -116,6 +63,44 @@ def _is_direct_call_from_imported_module( return (import_node, alias) return None + def get_imported_prefix( + self, node + ) -> Optional[tuple[Union[cst.Import, cst.ImportFrom], cst.ImportAlias]]: + """ + Given a node representing an access, finds if any part of its prefix is imported. + Returns a import and import alias pair. + """ + for nodo in iterate_left_expressions(node): + match nodo: + case cst.Name() | cst.Attribute(): + maybe_assignment = self.find_single_assignment(nodo) + if maybe_assignment and isinstance( + maybe_assignment, ImportAssignment + ): + import_node = maybe_assignment.node + for alias in import_node.names: + if maybe_assignment.name in ( + alias.evaluated_alias, + alias.evaluated_name, + ): + return (import_node, alias) + return None + + def get_aliased_prefix_name(self, node: cst.CSTNode, name: str): + """ + Returns the alias of name if name is imported and used as a prefix for this node. + """ + maybe_import = self.get_imported_prefix(node) + maybe_name = None + if maybe_import: + imp, ia = maybe_import + match imp: + case cst.Import(): + imp_name = get_full_name_for_node(ia.name) + if imp_name == name and ia.asname: + maybe_name = ia.asname.name.value + return maybe_name + def find_assignments( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], @@ -168,10 +153,13 @@ def find_single_assignment( def iterate_left_expressions(node: cst.BaseExpression): yield node - if matchers.matches(node, matchers.Attribute()): - yield from iterate_left_expressions(node.value) - if matchers.matches(node, matchers.Call()): - yield from iterate_left_expressions(node.func) + match node: + case cst.Attribute(): + yield from iterate_left_expressions(node.value) + case cst.Call(): + yield from iterate_left_expressions(node.func) + case cst.Subscript(): + yield from iterate_left_expressions(node.value) def get_leftmost_expression(node: cst.BaseExpression) -> cst.BaseExpression: diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 6005a16f..792a4d31 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -15,6 +15,8 @@ class HardenPyyaml(SemgrepCodemod, NameResolutionMixin): } ] + _module_name = "yaml" + @classmethod def rule(cls): return """ @@ -46,12 +48,10 @@ def rule(cls): """ def on_result_found(self, original_node, updated_node): - maybe_name = self.find_import_alias_in_nodes_scope_from_name( - "yaml", original_node - ) - if not maybe_name: - self.add_needed_import("yaml") - maybe_name = maybe_name or "yaml" + maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) + maybe_name = maybe_name or self._module_name + if maybe_name and maybe_name == self._module_name: + self.add_needed_import(self._module_name) new_args = [ *updated_node.args[:1], self.parse_expression(f"{maybe_name}.SafeLoader"), diff --git a/src/core_codemods/tempfile_mktemp.py b/src/core_codemods/tempfile_mktemp.py index cb8ea15d..d45637e5 100644 --- a/src/core_codemods/tempfile_mktemp.py +++ b/src/core_codemods/tempfile_mktemp.py @@ -1,8 +1,9 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.utils_mixin import NameResolutionMixin -class TempfileMktemp(SemgrepCodemod): +class TempfileMktemp(SemgrepCodemod, NameResolutionMixin): NAME = "secure-tempfile" REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW SUMMARY = "Upgrade and Secure Temp File Creation" @@ -14,6 +15,8 @@ class TempfileMktemp(SemgrepCodemod): } ] + _module_name = "tempfile" + @classmethod def rule(cls): return """ @@ -26,6 +29,9 @@ def rule(cls): """ def on_result_found(self, original_node, updated_node): + maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) + maybe_name = maybe_name or self._module_name + if maybe_name and maybe_name == self._module_name: + self.add_needed_import(self._module_name) self.remove_unused_import(original_node) - self.add_needed_import("tempfile") - return self.update_call_target(updated_node, "tempfile", "mkstemp") + return self.update_call_target(updated_node, maybe_name, "mkstemp") diff --git a/src/core_codemods/upgrade_sslcontext_minimum_version.py b/src/core_codemods/upgrade_sslcontext_minimum_version.py index 29142bb6..ed8c5e7c 100644 --- a/src/core_codemods/upgrade_sslcontext_minimum_version.py +++ b/src/core_codemods/upgrade_sslcontext_minimum_version.py @@ -1,8 +1,9 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.utils_mixin import NameResolutionMixin -class UpgradeSSLContextMinimumVersion(SemgrepCodemod): +class UpgradeSSLContextMinimumVersion(SemgrepCodemod, NameResolutionMixin): NAME = "upgrade-sslcontext-minimum-version" REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW SUMMARY = "Upgrade SSLContext Minimum Version" @@ -19,6 +20,8 @@ class UpgradeSSLContextMinimumVersion(SemgrepCodemod): }, ] + _module_name = "ssl" + @classmethod def rule(cls): return """ @@ -45,6 +48,11 @@ def rule(cls): """ def on_result_found(self, original_node, updated_node): + maybe_name = self.get_aliased_prefix_name( + original_node.value, self._module_name + ) + maybe_name = maybe_name or self._module_name + if maybe_name and maybe_name == self._module_name: + self.add_needed_import(self._module_name) self.remove_unused_import(original_node) - self.add_needed_import("ssl") - return self.update_assign_rhs(updated_node, "ssl.TLSVersion.TLSv1_2") + return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2") diff --git a/tests/codemods/test_tempfile_mktemp.py b/tests/codemods/test_tempfile_mktemp.py index 06f8ce37..6dcca028 100644 --- a/tests/codemods/test_tempfile_mktemp.py +++ b/tests/codemods/test_tempfile_mktemp.py @@ -1,4 +1,4 @@ -import pytest +import pytest # pylint: disable=unused-import from core_codemods.tempfile_mktemp import TempfileMktemp from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest @@ -50,7 +50,6 @@ def test_from_import(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected_output) - @pytest.mark.skip() def test_import_alias(self, tmpdir): input_code = """import tempfile as _tempfile diff --git a/tests/codemods/test_upgrade_sslcontext_minimum_version.py b/tests/codemods/test_upgrade_sslcontext_minimum_version.py index cb79c644..ddb5cd49 100644 --- a/tests/codemods/test_upgrade_sslcontext_minimum_version.py +++ b/tests/codemods/test_upgrade_sslcontext_minimum_version.py @@ -13,7 +13,7 @@ ] -class TestUpgradeSSLContextMininumVersion(BaseSemgrepCodemodTest): +class TestUpgradeSSLContextMinimumVersion(BaseSemgrepCodemodTest): codemod = UpgradeSSLContextMinimumVersion @pytest.mark.parametrize("version", INSECURE_VERSIONS) @@ -64,10 +64,9 @@ def test_import_with_alias(self, tmpdir): context.minimum_version = whatever.TLSVersion.SSLv3 """ expected_output = """import ssl as whatever -import ssl context = whatever.SSLContext() -context.minimum_version = ssl.TLSVersion.TLSv1_2 +context.minimum_version = whatever.TLSVersion.TLSv1_2 """ self.run_and_assert(tmpdir, input_code, expected_output)