Skip to content

Commit

Permalink
Detect variables with hardcoded assignments
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Nov 17, 2023
1 parent 71331a9 commit 508fb7d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/core_codemods/semgrep/sandbox_url_creation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ rules:
pattern-either:
- patterns:
- pattern: requests.get(...)
- pattern-not: requests.get("...")
- pattern-inside: |
import requests
...
56 changes: 30 additions & 26 deletions src/core_codemods/url_sandbox.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List, Optional, Union

import libcst as cst
from libcst import CSTNode, matchers
from libcst.codemod import Codemod, CodemodContext
from libcst.metadata import PositionProvider, ScopeProvider
from codemodder.codemods.utils import ReplaceNodes
from codemodder.file_context import FileContext

from libcst.codemod.visitors import AddImportsVisitor, ImportItem
from codemodder.change import Change
Expand All @@ -17,7 +16,10 @@
from codemodder.codemods.transformations.remove_unused_imports import (
RemoveUnusedImportsCodemod,
)
from codemodder.codemods.utils import ReplaceNodes
from codemodder.dependency import Security
from codemodder.file_context import FileContext


replacement_import = "safe_requests"

Expand Down Expand Up @@ -122,39 +124,41 @@ def leave_Call(self, original_node: cst.Call):
case cst.SimpleString():
return

# case get(...)
if matchers.matches(original_node, matchers.Call(func=matchers.Name())):
# check whether get(...) comes from a `from requests import get` statement
maybe_node = self.find_single_assignment(original_node)
if maybe_node and matchers.matches(maybe_node, matchers.ImportFrom()):
match original_node:
# case get(...)
case cst.Call(func=cst.Name()):
# check whether get(...) comes from a `from requests import get` statement
match self.find_single_assignment(original_node):
case cst.ImportFrom() as node:
self.nodes_to_change.update(
{
node: cst.ImportFrom(
module=cst.Attribute(
value=cst.Name(Security.name),
attr=cst.Name(replacement_import),
),
names=node.names,
)
}
)
self.changes_in_file.append(
Change(line_number, UrlSandbox.CHANGE_DESCRIPTION)
)

# case req.get(...)
case _:
self.nodes_to_change.update(
{
maybe_node: cst.ImportFrom(
module=cst.parse_expression(
f"{Security.name}.{replacement_import}"
),
names=maybe_node.names,
original_node: cst.Call(
func=cst.parse_expression(replacement_import + ".get"),
args=original_node.args,
)
}
)
self.changes_in_file.append(
Change(line_number, UrlSandbox.CHANGE_DESCRIPTION)
)

# case req.get(...)
else:
self.nodes_to_change.update(
{
original_node: cst.Call(
func=cst.parse_expression(replacement_import + ".get"),
args=original_node.args,
)
}
)
self.changes_in_file.append(
Change(line_number, UrlSandbox.CHANGE_DESCRIPTION)
)

def _find_assignments(self, node: CSTNode):
"""
Given a MetadataWrapper and a CSTNode representing an access, find all the possible assignments that it refers to.
Expand Down
48 changes: 47 additions & 1 deletion tests/codemods/test_url_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_ignore_hardcoded(self, tmpdir):

self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_variable(self, tmpdir):
def test_ignore_hardcoded_from_global_variable(self, tmpdir):
expected = input_code = """
import requests
Expand All @@ -201,3 +201,49 @@ def test_ignore_hardcoded_from_variable(self, tmpdir):
"""

self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_local_variable(self, tmpdir):
# The URL handed to requests.get() is a local variable bound to a string
# literal inside foo(), so the expected output is identical to the input:
# the codemod must leave this call untouched.
expected = input_code = """
import requests
def foo():
url = "www.google.com"
requests.get(url)
"""

# NOTE(review): run_and_assert is defined outside this view; presumably it
# runs the codemod on input_code and compares the result with expected.
self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_local_variable_transitive(self, tmpdir):
# new_url is assigned from url, which is itself bound to a string literal.
# The expected output equals the input, i.e. the hardcoded value must be
# recognised through the intermediate assignment and the call left as-is.
expected = input_code = """
import requests
def foo():
url = "www.google.com"
new_url = url
requests.get(new_url)
"""

# NOTE(review): run_and_assert is defined outside this view; presumably it
# runs the codemod on input_code and compares the result with expected.
self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_local_variable_transitive_reassigned(self, tmpdir):
# new_url is first assigned from the hardcoded url, then reassigned from
# input(), so it no longer has a single hardcoded assignment.
input_code = """
import requests
def foo():
url = "www.google.com"
new_url = url
new_url = input()
requests.get(new_url)
"""

# Expected rewrite: `import requests` is replaced by
# `from security import safe_requests`, and the call becomes
# safe_requests.get(new_url); the assignments are unchanged.
expected = """
from security import safe_requests
def foo():
url = "www.google.com"
new_url = url
new_url = input()
safe_requests.get(new_url)
"""

# NOTE(review): run_and_assert is defined outside this view; presumably it
# runs the codemod on input_code and compares the result with expected.
self.run_and_assert(tmpdir, input_code, expected)

0 comments on commit 508fb7d

Please sign in to comment.