Skip to content

Commit

Permalink
Changed solution to be more reliable and specific
Browse files Browse the repository at this point in the history
Added alias detection for harden-pyyaml, secure-tempfile, and upgrade-sslcontext-minimum-version. Enabled tests.
  • Loading branch information
andrecsilva committed Oct 27, 2023
1 parent 3db7b2d commit a199aea
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 74 deletions.
102 changes: 45 additions & 57 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Assignment,
BaseAssignment,
ImportAssignment,
Scope,
ScopeProvider,
)
from libcst.metadata.scope_provider import GlobalScope
Expand All @@ -15,58 +14,6 @@
class NameResolutionMixin(MetadataDependent):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,)

def _iterate_scope_ancestors(self, scope: Scope):
    """
    Yield ``scope`` first, then each successive parent scope.

    Iteration stops right after a ``GlobalScope`` has been yielded, so the
    starting scope is always part of the output.
    """
    current = scope
    while True:
        yield current
        if isinstance(current, GlobalScope):
            return
        current = current.parent

def find_import_alias_in_nodes_scope_from_name(self, name, node):
    """
    Return the name under which ``name`` is imported in the scope of ``node``.

    Delegates the lookup to ``find_import_in_nodes_scope_from_name`` and
    returns ``None`` when that lookup finds no import.
    """
    found = self.find_import_in_nodes_scope_from_name(name, node)
    if not found:
        return None
    # The lookup yields an (import statement, alias) pair; only the alias matters here.
    _, alias = found
    return self._get_assigned_name_for_import(alias)

def _get_assigned_name_for_import(self, alias: cst.ImportAlias) -> str:
    """
    Return the name an import alias is assigned to, as a string.

    With an ``as`` clause the result is the value of the ``as`` target when it
    is a ``cst.Name``. Without one, the result is derived from
    ``alias.name.value``. The empty string is a sanity fallback that is not
    expected to occur.
    """
    if not alias.asname:
        # NOTE(review): for a dotted (Attribute) name, ``.value`` is the part
        # left of the final dot -- confirm this is the intended result.
        return get_full_name_for_node(alias.name.value) or ""
    target = alias.asname.name
    if isinstance(target, cst.Name):
        return target.value
    return ""

def find_import_in_nodes_scope_from_name(
    self, name: str, node: cst.CSTNode
) -> Optional[Tuple[cst.ImportFrom | cst.Import, cst.ImportAlias]]:
    """
    Look up the import statement and alias that import ``name``.

    The scope of ``node`` is searched first, then each enclosing scope in
    turn. Inside a scope, import assignments are visited in the reverse of the
    order in which ``scope.assignments`` iterates, and every import statement
    is examined only once. An alias matches when the full name of its imported
    name equals ``name``. Star imports are skipped.

    Returns the ``(import statement, alias)`` pair of the first match, or
    ``None`` when nothing matches.
    """
    start_scope = self.get_metadata(ScopeProvider, node)
    for scope in self._iterate_scope_ancestors(start_scope):
        # dict.fromkeys de-duplicates the statements while keeping visit order.
        import_nodes = dict.fromkeys(
            assignment.node
            for assignment in reversed(list(scope.assignments))
            if isinstance(assignment, ImportAssignment)
        )
        for imp in import_nodes:
            if isinstance(imp, cst.Import):
                aliases = imp.names
            elif isinstance(imp, cst.ImportFrom) and not isinstance(
                imp.names, cst.ImportStar
            ):
                aliases = imp.names
            else:
                continue
            for alias in aliases:
                if get_full_name_for_node(alias.name) == name:
                    return (imp, alias)
    return None

def find_base_name(self, node):
"""
Given a node, solve its name to its basest form. For now it can only solve names that are imported. For example, in what follows, the base name for exec.capitalize() is sys.executable.capitalize.
Expand Down Expand Up @@ -116,6 +63,44 @@ def _is_direct_call_from_imported_module(
return (import_node, alias)
return None

def get_imported_prefix(
    self, node
) -> Optional[tuple[Union[cst.Import, cst.ImportFrom], cst.ImportAlias]]:
    """
    Given a node representing an access, find out whether any part of its
    prefix is imported.

    Each ``cst.Name`` / ``cst.Attribute`` produced by
    ``iterate_left_expressions(node)`` is resolved with
    ``find_single_assignment``. For the first one that resolves to an import
    assignment with a matching alias, returns an (import, import alias) pair;
    otherwise returns ``None``.
    """
    for expr in iterate_left_expressions(node):
        if not isinstance(expr, (cst.Name, cst.Attribute)):
            continue
        assignment = self.find_single_assignment(expr)
        if not (assignment and isinstance(assignment, ImportAssignment)):
            continue
        import_node = assignment.node
        for alias in import_node.names:
            # The assigned name may be either the alias or the plain imported name.
            candidates = (alias.evaluated_alias, alias.evaluated_name)
            if assignment.name in candidates:
                return (import_node, alias)
    return None

def get_aliased_prefix_name(self, node: cst.CSTNode, name: str):
    """
    Return the alias of ``name`` when ``name`` is imported with an ``as``
    clause and that import is used as a prefix of ``node``.

    Only plain ``import ... as ...`` statements (``cst.Import``) are
    considered. Returns ``None`` in every other case.
    """
    found = self.get_imported_prefix(node)
    if not found:
        return None
    import_node, alias = found
    if not isinstance(import_node, cst.Import):
        return None
    if get_full_name_for_node(alias.name) != name:
        return None
    if not alias.asname:
        return None
    return alias.asname.name.value

def find_assignments(
self,
node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator],
Expand Down Expand Up @@ -168,10 +153,13 @@ def find_single_assignment(

# Generator: yields `node`, then recurses into the expression on its left
# (`value` of an Attribute or Subscript, `func` of a Call).
# NOTE(review): this span is a diff rendering (hunk "-168,10 +153,13"). The two
# `if matchers.matches(...)` branches appear to be the removed lines and the
# `match` statement their replacement -- confirm against the committed file.
def iterate_left_expressions(node: cst.BaseExpression):
yield node
# Pre-change form: handles Attribute and Call only.
if matchers.matches(node, matchers.Attribute()):
yield from iterate_left_expressions(node.value)
if matchers.matches(node, matchers.Call()):
yield from iterate_left_expressions(node.func)
# Post-change form: same traversal, additionally descending into Subscript.
match node:
case cst.Attribute():
yield from iterate_left_expressions(node.value)
case cst.Call():
yield from iterate_left_expressions(node.func)
case cst.Subscript():
yield from iterate_left_expressions(node.value)


def get_leftmost_expression(node: cst.BaseExpression) -> cst.BaseExpression:
Expand Down
12 changes: 6 additions & 6 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class HardenPyyaml(SemgrepCodemod, NameResolutionMixin):
}
]

_module_name = "yaml"

@classmethod
def rule(cls):
return """
Expand Down Expand Up @@ -46,12 +48,10 @@ def rule(cls):
"""

def on_result_found(self, original_node, updated_node):
maybe_name = self.find_import_alias_in_nodes_scope_from_name(
"yaml", original_node
)
if not maybe_name:
self.add_needed_import("yaml")
maybe_name = maybe_name or "yaml"
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
if maybe_name and maybe_name == self._module_name:
self.add_needed_import(self._module_name)
new_args = [
*updated_node.args[:1],
self.parse_expression(f"{maybe_name}.SafeLoader"),
Expand Down
12 changes: 9 additions & 3 deletions src/core_codemods/tempfile_mktemp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class TempfileMktemp(SemgrepCodemod):
class TempfileMktemp(SemgrepCodemod, NameResolutionMixin):
NAME = "secure-tempfile"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Upgrade and Secure Temp File Creation"
Expand All @@ -14,6 +15,8 @@ class TempfileMktemp(SemgrepCodemod):
}
]

_module_name = "tempfile"

@classmethod
def rule(cls):
return """
Expand All @@ -26,6 +29,9 @@ def rule(cls):
"""

# Handler for a matched `tempfile` call: passes the target "mkstemp" to
# `update_call_target` and returns its result.
# NOTE(review): this span is a diff rendering (hunk "-26,6 +29,9"), so removed
# and added lines are interleaved -- confirm against the committed file.
def on_result_found(self, original_node, updated_node):
# Alias from `import tempfile as <alias>` when it prefixes this call, else None.
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
# Request a plain import of the module only when no alias is reused.
if maybe_name and maybe_name == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
# NOTE(review): the next two lines appear to be the removed pre-change code,
# which always used the literal "tempfile" prefix.
self.add_needed_import("tempfile")
return self.update_call_target(updated_node, "tempfile", "mkstemp")
# Post-change return: the call target uses the resolved prefix.
return self.update_call_target(updated_node, maybe_name, "mkstemp")
14 changes: 11 additions & 3 deletions src/core_codemods/upgrade_sslcontext_minimum_version.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class UpgradeSSLContextMinimumVersion(SemgrepCodemod):
class UpgradeSSLContextMinimumVersion(SemgrepCodemod, NameResolutionMixin):
NAME = "upgrade-sslcontext-minimum-version"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Upgrade SSLContext Minimum Version"
Expand All @@ -19,6 +20,8 @@ class UpgradeSSLContextMinimumVersion(SemgrepCodemod):
},
]

_module_name = "ssl"

@classmethod
def rule(cls):
return """
Expand All @@ -45,6 +48,11 @@ def rule(cls):
"""

# Handler for a matched assignment: passes a `TLSVersion.TLSv1_2` expression to
# `update_assign_rhs` and returns its result.
# NOTE(review): this span is a diff rendering (hunk "-45,6 +48,11"), so removed
# and added lines are interleaved -- confirm against the committed file.
def on_result_found(self, original_node, updated_node):
# Resolve the alias from `original_node.value` rather than from the node itself.
maybe_name = self.get_aliased_prefix_name(
original_node.value, self._module_name
)
maybe_name = maybe_name or self._module_name
# Request a plain import of the module only when no alias is reused.
if maybe_name and maybe_name == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
# NOTE(review): the next two lines appear to be the removed pre-change code,
# which always used the literal "ssl" prefix.
self.add_needed_import("ssl")
return self.update_assign_rhs(updated_node, "ssl.TLSVersion.TLSv1_2")
# Post-change return: the expression uses the resolved prefix.
return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2")
3 changes: 1 addition & 2 deletions tests/codemods/test_tempfile_mktemp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pytest
import pytest # pylint: disable=unused-import
from core_codemods.tempfile_mktemp import TempfileMktemp
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest

Expand Down Expand Up @@ -50,7 +50,6 @@ def test_from_import(self, tmpdir):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.skip()
def test_import_alias(self, tmpdir):
input_code = """import tempfile as _tempfile
Expand Down
5 changes: 2 additions & 3 deletions tests/codemods/test_upgrade_sslcontext_minimum_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
]


class TestUpgradeSSLContextMininumVersion(BaseSemgrepCodemodTest):
class TestUpgradeSSLContextMinimumVersion(BaseSemgrepCodemodTest):
codemod = UpgradeSSLContextMinimumVersion

@pytest.mark.parametrize("version", INSECURE_VERSIONS)
Expand Down Expand Up @@ -64,10 +64,9 @@ def test_import_with_alias(self, tmpdir):
context.minimum_version = whatever.TLSVersion.SSLv3
"""
expected_output = """import ssl as whatever
import ssl
context = whatever.SSLContext()
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.minimum_version = whatever.TLSVersion.TLSv1_2
"""
self.run_and_assert(tmpdir, input_code, expected_output)

Expand Down

0 comments on commit a199aea

Please sign in to comment.