diff --git a/src/core_codemods/semgrep/sandbox_url_creation.yaml b/src/core_codemods/semgrep/sandbox_url_creation.yaml index 781d0527..aa793bc2 100644 --- a/src/core_codemods/semgrep/sandbox_url_creation.yaml +++ b/src/core_codemods/semgrep/sandbox_url_creation.yaml @@ -7,6 +7,7 @@ rules: pattern-either: - patterns: - pattern: requests.get(...) + - pattern-not: requests.get("...") - pattern-inside: | import requests ... diff --git a/src/core_codemods/url_sandbox.py b/src/core_codemods/url_sandbox.py index 3a0eb21d..85213384 100644 --- a/src/core_codemods/url_sandbox.py +++ b/src/core_codemods/url_sandbox.py @@ -1,10 +1,9 @@ from typing import List, Optional, Union + import libcst as cst from libcst import CSTNode, matchers from libcst.codemod import Codemod, CodemodContext from libcst.metadata import PositionProvider, ScopeProvider -from codemodder.codemods.utils import ReplaceNodes -from codemodder.file_context import FileContext from libcst.codemod.visitors import AddImportsVisitor, ImportItem from codemodder.change import Change @@ -17,7 +16,10 @@ from codemodder.codemods.transformations.remove_unused_imports import ( RemoveUnusedImportsCodemod, ) +from codemodder.codemods.utils import ReplaceNodes from codemodder.dependency import Security +from codemodder.file_context import FileContext + replacement_import = "safe_requests" @@ -122,18 +124,34 @@ def leave_Call(self, original_node: cst.Call): case cst.SimpleString(): return - # case get(...) - if matchers.matches(original_node, matchers.Call(func=matchers.Name())): - # find if get(...) comes from an from requests import get - maybe_node = self.find_single_assignment(original_node) - if maybe_node and matchers.matches(maybe_node, matchers.ImportFrom()): + match original_node: + # case get(...) + case cst.Call(func=cst.Name()): + # find if get(...) comes from an from requests import get + match self.find_single_assignment(original_node): + case cst.ImportFrom() as node: + self.nodes_to_change.update( + { + node: cst.ImportFrom( + module=cst.Attribute( + value=cst.Name(Security.name), + attr=cst.Name(replacement_import), + ), + names=node.names, + ) + } + ) + self.changes_in_file.append( + Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) + ) + + # case req.get(...) + case _: self.nodes_to_change.update( { - maybe_node: cst.ImportFrom( - module=cst.parse_expression( - f"{Security.name}.{replacement_import}" - ), - names=maybe_node.names, + original_node: cst.Call( + func=cst.parse_expression(replacement_import + ".get"), + args=original_node.args, ) } ) @@ -141,20 +159,6 @@ def leave_Call(self, original_node: cst.Call): Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) ) - # case req.get(...) - else: - self.nodes_to_change.update( - { - original_node: cst.Call( - func=cst.parse_expression(replacement_import + ".get"), - args=original_node.args, - ) - } - ) - self.changes_in_file.append( - Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) - ) - def _find_assignments(self, node: CSTNode): """ Given a MetadataWrapper and a CSTNode representing an access, find all the possible assignments that it refers. diff --git a/tests/codemods/test_url_sandbox.py b/tests/codemods/test_url_sandbox.py index 10297a82..9978dd9a 100644 --- a/tests/codemods/test_url_sandbox.py +++ b/tests/codemods/test_url_sandbox.py @@ -192,7 +192,7 @@ def test_ignore_hardcoded(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_variable(self, tmpdir): + def test_ignore_hardcoded_from_global_variable(self, tmpdir): expected = input_code = """ import requests @@ -201,3 +201,49 @@ def test_ignore_hardcoded_from_variable(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) + + def test_ignore_hardcoded_from_local_variable(self, tmpdir): + expected = input_code = """ + import requests + + def foo(): + url = "www.google.com" + requests.get(url) + """ + + self.run_and_assert(tmpdir, input_code, expected) + + def test_ignore_hardcoded_from_local_variable_transitive(self, tmpdir): + expected = input_code = """ + import requests + + def foo(): + url = "www.google.com" + new_url = url + requests.get(new_url) + """ + + self.run_and_assert(tmpdir, input_code, expected) + + def test_ignore_hardcoded_from_local_variable_transitive_reassigned(self, tmpdir): + input_code = """ + import requests + + def foo(): + url = "www.google.com" + new_url = url + new_url = input() + requests.get(new_url) + """ + + expected = """ + from security import safe_requests + + def foo(): + url = "www.google.com" + new_url = url + new_url = input() + safe_requests.get(new_url) + """ + + self.run_and_assert(tmpdir, input_code, expected)