diff --git a/integration_tests/base_test.py b/integration_tests/base_test.py index 1e0a6b219..c894ab6ab 100644 --- a/integration_tests/base_test.py +++ b/integration_tests/base_test.py @@ -7,7 +7,7 @@ from codemodder import __VERSION__ from codemodder import registry -from tests.validations import validate_code +from tests.validations import execute_code SAMPLES_DIR = "tests/samples" @@ -51,6 +51,7 @@ class BaseIntegrationTest(DependencyTestMixin, CleanRepoMixin): num_changes = 1 _lines = [] num_changed_files = 1 + allowed_exceptions = () @classmethod def setup_class(cls): @@ -121,7 +122,7 @@ def check_code_after(self): with open(self.code_path, "r", encoding="utf-8") as f: new_code = f.read() assert new_code == self.expected_new_code - validate_code(path=self.code_path) + execute_code(path=self.code_path, allowed_exceptions=self.allowed_exceptions) def test_file_rewritten(self): """ diff --git a/integration_tests/test_harden_pyyaml.py b/integration_tests/test_harden_pyyaml.py index ab20815a7..8240b623d 100644 --- a/integration_tests/test_harden_pyyaml.py +++ b/integration_tests/test_harden_pyyaml.py @@ -1,3 +1,4 @@ +import yaml from core_codemods.harden_pyyaml import HardenPyyaml from integration_tests.base_test import ( BaseIntegrationTest, @@ -14,3 +15,5 @@ class TestHardenPyyaml(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import yaml\n \n data = b"!!python/object/apply:subprocess.Popen \\\\n- ls"\n-deserialized_data = yaml.load(data, Loader=yaml.Loader)\n+deserialized_data = yaml.load(data, yaml.SafeLoader)\n' expected_line_change = "4" change_description = HardenPyyaml.CHANGE_DESCRIPTION + # expected exception because the yaml.SafeLoader protects against unsafe code + allowed_exceptions = (yaml.constructor.ConstructorError,) diff --git a/integration_tests/test_limit_readline.py b/integration_tests/test_limit_readline.py index c343c52f1..b228b24ca 100644 --- a/integration_tests/test_limit_readline.py +++ b/integration_tests/test_limit_readline.py @@ -14,3 +14,5 @@ class TestLimitReadline(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,2 +1,2 @@\n file = open("some_file.txt")\n-file.readline()\n+file.readline(5_000_000)\n' expected_line_change = "2" change_description = LimitReadline.CHANGE_DESCRIPTION + # expected because output code points to fake file + allowed_exceptions = (FileNotFoundError,) diff --git a/integration_tests/test_lxml_safe_parsing.py b/integration_tests/test_lxml_safe_parsing.py index 8d580c8d4..3a8d65ea3 100644 --- a/integration_tests/test_lxml_safe_parsing.py +++ b/integration_tests/test_lxml_safe_parsing.py @@ -25,3 +25,4 @@ class TestLxmlSafeParsing(BaseIntegrationTest): expected_line_change = "2" num_changes = 2 change_description = LxmlSafeParsing.CHANGE_DESCRIPTION + allowed_exceptions = (OSError,) diff --git a/integration_tests/test_order_imports.py b/integration_tests/test_order_imports.py index 0e1f05113..3d174c877 100644 --- a/integration_tests/test_order_imports.py +++ b/integration_tests/test_order_imports.py @@ -11,21 +11,22 @@ class TestOrderImports(BaseIntegrationTest): original_code, expected_new_code = original_and_expected_from_code_path( code_path, [ - (1, "# comment b4\n"), - (2, "# comment b5\n"), - (3, "# comment b3\n"), - (4, "# comment b1\n"), - (5, "# comment b2\n"), - (6, "import b\n"), - (7, "import d\n"), - (8, "# comment a\n"), - (9, "from a import a1, a2\n"), - (10, "\n"), - (11, "a1\n"), - (12, "a2\n"), - (13, "b\n"), - (14, "c\n"), - (15, "d"), + (1, "# comment builtins4\n"), + (2, "# comment builtins5\n"), + (3, "# comment builtins3\n"), + (4, "# comment builtins1\n"), + (5, "# comment builtins2\n"), + (6, "import builtins\n"), + (7, "import collections\n"), + (8, "import datetime\n"), + (9, "# comment a\n"), + (10, "from abc import ABC, ABCMeta\n"), + (11, "\n"), + (12, "ABC\n"), + (13, "ABCMeta\n"), + (14, "builtins\n"), + (15, "collections\n"), + (16, ""), (17, ""), (18, ""), (19, ""), @@ -34,6 +35,6 @@ class TestOrderImports(BaseIntegrationTest): ], ) - expected_diff = "--- \n+++ \n@@ -1,19 +1,13 @@\n #!/bin/env python\n-from a import a2\n-\n+# comment b4\n+# comment b5\n+# comment b3\n # comment b1\n # comment b2\n import b\n-\n+import d\n # comment a\n-from a import a1\n-\n-# comment b3\n-import b, d\n-\n-# comment b4\n-# comment b5\n-import b\n+from a import a1, a2\n \n a1\n a2\n" + expected_diff = "--- \n+++ \n@@ -1,20 +1,14 @@\n #!/bin/env python\n-from abc import ABCMeta\n-\n+# comment builtins4\n+# comment builtins5\n+# comment builtins3\n # comment builtins1\n # comment builtins2\n import builtins\n-\n+import collections\n+import datetime\n # comment a\n-from abc import ABC\n-\n-# comment builtins3\n-import builtins, datetime\n-\n-# comment builtins4\n-# comment builtins5\n-import builtins\n-import collections\n+from abc import ABC, ABCMeta\n \n ABC\n ABCMeta\n" expected_line_change = "2" change_description = OrderImports.CHANGE_DESCRIPTION diff --git a/integration_tests/test_remove_unused_imports.py b/integration_tests/test_remove_unused_imports.py index cb25080d0..366756ece 100644 --- a/integration_tests/test_remove_unused_imports.py +++ b/integration_tests/test_remove_unused_imports.py @@ -11,10 +11,10 @@ class TestRemoveUnusedImports(BaseIntegrationTest): original_code, expected_new_code = original_and_expected_from_code_path( code_path, [ - (1, """from b import c\n"""), + (1, """from builtins import complex\n"""), ], ) - expected_diff = "--- \n+++ \n@@ -1,5 +1,5 @@\n import a\n-from b import c, d\n+from b import c\n \n a\n c\n" + expected_diff = "--- \n+++ \n@@ -1,5 +1,5 @@\n import abc\n-from builtins import complex, dict\n+from builtins import complex\n \n abc\n complex\n" expected_line_change = 2 change_description = RemoveUnusedImports.CHANGE_DESCRIPTION diff --git a/integration_tests/test_request_verify.py b/integration_tests/test_request_verify.py index 359113111..9cf400838 100644 --- a/integration_tests/test_request_verify.py +++ b/integration_tests/test_request_verify.py @@ -3,6 +3,7 @@ BaseIntegrationTest, original_and_expected_from_code_path, ) +from requests.exceptions import ConnectionError class TestRequestsVerify(BaseIntegrationTest): @@ -11,14 +12,16 @@ class TestRequestsVerify(BaseIntegrationTest): original_code, expected_new_code = original_and_expected_from_code_path( code_path, [ - (2, """requests.get("www.google.com", verify=True)\n"""), + (2, """requests.get("https://www.google.com", verify=True)\n"""), ( 3, - """requests.post("https/some-api/", json={"id": 1234, "price": 18}, verify=True)\n""", + """requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=True)\n""", ), ], ) - expected_diff = '--- \n+++ \n@@ -1,5 +1,5 @@\n import requests\n \n-requests.get("www.google.com", verify=False)\n-requests.post("https/some-api/", json={"id": 1234, "price": 18}, verify=False)\n+requests.get("www.google.com", verify=True)\n+requests.post("https/some-api/", json={"id": 1234, "price": 18}, verify=True)\n var = "hello"\n' + expected_diff = '--- \n+++ \n@@ -1,5 +1,5 @@\n import requests\n \n-requests.get("https://www.google.com", verify=False)\n-requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=False)\n+requests.get("https://www.google.com", verify=True)\n+requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=True)\n var = "hello"\n' expected_line_change = "3" num_changes = 2 change_description = RequestsVerify.CHANGE_DESCRIPTION + # expected because when executing the output code it will make a request which fails, which is OK. + allowed_exceptions = (ConnectionError,) diff --git a/integration_tests/test_url_sandbox.py b/integration_tests/test_url_sandbox.py index faaf0069c..ad56c8031 100644 --- a/integration_tests/test_url_sandbox.py +++ b/integration_tests/test_url_sandbox.py @@ -12,11 +12,11 @@ class TestUrlSandbox(BaseIntegrationTest): code_path, [ (0, """from security import safe_requests\n"""), - (2, """safe_requests.get("www.google.com")\n"""), + (2, """safe_requests.get("https://www.google.com")\n"""), ], ) - expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n-import requests\n+from security import safe_requests\n \n-requests.get("www.google.com")\n+safe_requests.get("www.google.com")\n var = "hello"\n' + expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n-import requests\n+from security import safe_requests\n \n-requests.get("https://www.google.com")\n+safe_requests.get("https://www.google.com")\n var = "hello"\n' expected_line_change = "3" change_description = UrlSandbox.CHANGE_DESCRIPTION num_changed_files = 1 diff --git a/integration_tests/test_use_walrus_if.py b/integration_tests/test_use_walrus_if.py index cfd69919d..5b091071c 100644 --- a/integration_tests/test_use_walrus_if.py +++ b/integration_tests/test_use_walrus_if.py @@ -10,22 +10,22 @@ class TestUseWalrusIf(BaseIntegrationTest): code_path = "tests/samples/use_walrus_if.py" original_code, _ = original_and_expected_from_code_path(code_path, []) expected_new_code = """ -if (x := foo()) is not None: +if (x := sum([1, 2])) is not None: print(x) -if y := bar(): +if y := max([1, 2]): print(y) -z = baz() +z = min([1, 2]) print(z) def whatever(): - if (b := biz()) == 10: + if (b := int("2")) == 10: print(b) """.lstrip() - expected_diff = "--- \n+++ \n@@ -1,9 +1,7 @@\n-x = foo()\n-if x is not None:\n+if (x := foo()) is not None:\n print(x)\n \n-y = bar()\n-if y:\n+if y := bar():\n print(y)\n \n z = baz()\n@@ -11,6 +9,5 @@\n \n \n def whatever():\n- b = biz()\n- if b == 10:\n+ if (b := biz()) == 10:\n print(b)\n" + expected_diff = '--- \n+++ \n@@ -1,9 +1,7 @@\n-x = sum([1, 2])\n-if x is not None:\n+if (x := sum([1, 2])) is not None:\n print(x)\n \n-y = max([1, 2])\n-if y:\n+if y := max([1, 2]):\n print(y)\n \n z = min([1, 2])\n@@ -11,6 +9,5 @@\n \n \n def whatever():\n- b = int("2")\n- if b == 10:\n+ if (b := int("2")) == 10:\n print(b)\n' num_changes = 3 expected_line_change = 1 diff --git a/requirements/test.txt b/requirements/test.txt index 451e801a9..9f6b5109c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,4 +6,5 @@ pytest==7.4.* pytest-cov~=4.1.0 pytest-mock~=3.11.1 pytest-xdist==3.* +security~=1.0.1 types-mock==5.1.* diff --git a/tests/samples/make_request.py b/tests/samples/make_request.py index 133a9ca21..0486bb894 100644 --- a/tests/samples/make_request.py +++ b/tests/samples/make_request.py @@ -1,4 +1,4 @@ import requests -requests.get("www.google.com") +requests.get("https://www.google.com") var = "hello" diff --git a/tests/samples/unordered_imports.py b/tests/samples/unordered_imports.py index dd8b7934b..d5c9efaef 100644 --- a/tests/samples/unordered_imports.py +++ b/tests/samples/unordered_imports.py @@ -1,22 +1,23 @@ #!/bin/env python -from a import a2 +from abc import ABCMeta -# comment b1 -# comment b2 -import b +# comment builtins1 +# comment builtins2 +import builtins # comment a -from a import a1 +from abc import ABC -# comment b3 -import b, d +# comment builtins3 +import builtins, datetime -# comment b4 -# comment b5 -import b +# comment builtins4 +# comment builtins5 +import builtins +import collections -a1 -a2 -b -c -d +ABC +ABCMeta +builtins +collections +datetime diff --git a/tests/samples/unused_imports.py b/tests/samples/unused_imports.py index 54690af43..b96244ab7 100644 --- a/tests/samples/unused_imports.py +++ b/tests/samples/unused_imports.py @@ -1,5 +1,5 @@ -import a -from b import c, d +import abc +from builtins import complex, dict -a -c +abc +complex diff --git a/tests/samples/unverified_request.py b/tests/samples/unverified_request.py index 8a6ff3b88..c670ca480 100644 --- a/tests/samples/unverified_request.py +++ b/tests/samples/unverified_request.py @@ -1,5 +1,5 @@ import requests -requests.get("www.google.com", verify=False) -requests.post("https/some-api/", json={"id": 1234, "price": 18}, verify=False) +requests.get("https://www.google.com", verify=False) +requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=False) var = "hello" diff --git a/tests/samples/use_walrus_if.py b/tests/samples/use_walrus_if.py index df591dfd9..2ac68dc49 100644 --- a/tests/samples/use_walrus_if.py +++ b/tests/samples/use_walrus_if.py @@ -1,16 +1,16 @@ -x = foo() +x = sum([1, 2]) if x is not None: print(x) -y = bar() +y = max([1, 2]) if y: print(y) -z = baz() +z = min([1, 2]) print(z) def whatever(): - b = biz() + b = int("2") if b == 10: print(b) diff --git a/tests/validations.py b/tests/validations.py index a56324a20..915569f57 100644 --- a/tests/validations.py +++ b/tests/validations.py @@ -1,20 +1,30 @@ import importlib.util import tempfile -def validate_code(*, path=None, code=None): + +def execute_code(*, path=None, code=None, allowed_exceptions=None): """ - Ensure that code written in `path` or in `code` str is importable. + Ensure that code written in `path` or in `code` str is executable. """ - assert (path is None) != (code is None), "Must pass either path to code or code as a str." + assert (path is None) != ( + code is None + ), "Must pass either path to code or code as a str." if path: - _try_code_import(path) + _run_code(path, allowed_exceptions) return - with tempfile.NamedTemporaryFile(suffix=".py", mode='w+t') as temp: + with tempfile.NamedTemporaryFile(suffix=".py", mode="w+t") as temp: temp.write(code) - _try_code_import(temp.name) + _run_code(temp.name, allowed_exceptions) + + +def _run_code(path, allowed_exceptions=None): + """Execute the code in `path` in its own namespace.""" + allowed_exceptions = allowed_exceptions or () -def _try_code_import(path): spec = importlib.util.spec_from_file_location("output_code", path) module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + try: + spec.loader.exec_module(module) + except allowed_exceptions: + pass