From 0dbf62264bdf56866e3c22873fe3dee3e00facd1 Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Thu, 19 Oct 2023 09:40:23 -0400 Subject: [PATCH] Create changeset for updated dependencies --- integration_tests/test_process_sandbox.py | 5 +- integration_tests/test_url_sandbox.py | 6 +- pyproject.toml | 2 +- src/codemodder/change.py | 12 +++ src/codemodder/codemodder.py | 18 ++-- src/codemodder/codemods/base_codemod.py | 3 + src/codemodder/context.py | 53 +++++++---- src/codemodder/dependency_manager.py | 93 +++++++++++++++---- src/codemodder/file_context.py | 6 +- src/core_codemods/process_creation_sandbox.py | 2 +- src/core_codemods/url_sandbox.py | 2 +- src/core_codemods/use_defused_xml.py | 2 +- tests/codemods/base_codemod_test.py | 3 + .../codemods/test_process_creation_sandbox.py | 5 + tests/codemods/test_url_sandbox.py | 6 ++ tests/codemods/test_use_defused_xml.py | 7 ++ tests/conftest.py | 4 +- tests/test_dependency_manager.py | 87 ++++++++--------- 18 files changed, 212 insertions(+), 104 deletions(-) diff --git a/integration_tests/test_process_sandbox.py b/integration_tests/test_process_sandbox.py index 46dbf1e00..707bdf891 100644 --- a/integration_tests/test_process_sandbox.py +++ b/integration_tests/test_process_sandbox.py @@ -21,10 +21,9 @@ class TestProcessSandbox(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,10 +1,11 @@\n import subprocess\n+from security import safe_command\n \n-subprocess.run("echo \'hi\'", shell=True)\n-subprocess.run(["ls", "-l"])\n+safe_command.run(subprocess.run, "echo \'hi\'", shell=True)\n+safe_command.run(subprocess.run, ["ls", "-l"])\n \n-subprocess.call("echo \'hi\'", shell=True)\n-subprocess.call(["ls", "-l"])\n+safe_command.call(subprocess.call, "echo \'hi\'", shell=True)\n+safe_command.call(subprocess.call, ["ls", "-l"])\n \n subprocess.check_output(["ls", "-l"])\n \n' expected_line_change = "3" num_changes = 4 + num_changed_files = 2 change_description = ProcessSandbox.CHANGE_DESCRIPTION requirements_path = "tests/samples/requirements.txt" original_requirements = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\n" - expected_new_reqs = ( - "requests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\nsecurity==1.0.1" - ) + expected_new_reqs = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\nsecurity==1.0.1" diff --git a/integration_tests/test_url_sandbox.py b/integration_tests/test_url_sandbox.py index ad56c8031..b9014038b 100644 --- a/integration_tests/test_url_sandbox.py +++ b/integration_tests/test_url_sandbox.py @@ -19,10 +19,8 @@ class TestUrlSandbox(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n-import requests\n+from security import safe_requests\n \n-requests.get("https://www.google.com")\n+safe_requests.get("https://www.google.com")\n var = "hello"\n' expected_line_change = "3" change_description = UrlSandbox.CHANGE_DESCRIPTION - num_changed_files = 1 + num_changed_files = 2 requirements_path = "tests/samples/requirements.txt" original_requirements = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\n" - expected_new_reqs = ( - "requests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\nsecurity==1.0.1" - ) + expected_new_reqs = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\nsecurity==1.0.1" diff --git a/pyproject.toml b/pyproject.toml index 99ec0a032..d9c8af5ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,9 +9,9 @@ requires-python = ">=3.10.0" readme = "README.md" license = {file = "LICENSE"} dependencies = [ - "dependency-manager @ git+https://github.com/pixee/python-dependency-manager#egg=dependency-manager", "isort~=5.12.0", "libcst~=1.1.0", + "packaging~=23.0.0", "pylint~=3.0.0", "python-json-logger~=2.0.0", "PyYAML~=6.0.0", diff --git a/src/codemodder/change.py b/src/codemodder/change.py index f2cda9257..4429f5d53 100644 --- a/src/codemodder/change.py +++ b/src/codemodder/change.py @@ -15,3 +15,15 @@ def to_json(self): "properties": self.properties, "packageActions": self.packageActions, } + + +@dataclass +class ChangeSet: + """A set of changes made to a file at `path`""" + + path: str + diff: str + changes: list[Change] + + def to_json(self): + return {"path": self.path, "diff": self.diff, "changes": self.changes} diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index f06f5d508..69603d3a5 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -4,7 +4,6 @@ import os import sys from pathlib import Path -from textwrap import indent import libcst as cst from libcst.codemod import CodemodContext @@ -13,9 +12,9 @@ from codemodder import registry, __VERSION__ from codemodder.logging import configure_logger, logger, log_section, log_list from codemodder.cli import parse_args +from codemodder.change import ChangeSet from codemodder.code_directory import file_line_patterns, match_files -from codemodder.context import CodemodExecutionContext, ChangeSet -from codemodder.dependency_manager import write_dependencies +from codemodder.context import CodemodExecutionContext from codemodder.executor import CodemodExecutorWrapper from codemodder.report.codetf_reporter import report_default @@ -76,6 +75,7 @@ def analyze_files( sarif, cli_args, ): + # TODO: parallelize this loop for idx, file_path in enumerate(files_to_analyze): logger.debug("scanning file %s", file_path) if idx and idx % 100 == 0: @@ -93,6 +93,7 @@ def analyze_files( line_include = file_line_patterns(file_path, cli_args.path_include) sarif_for_file = sarif.get(str(file_path)) or {} + # NOTE: file context will become more important if/when we parallelize this loop file_context = FileContext( file_path, line_exclude, @@ -107,13 +108,7 @@ def analyze_files( source_tree, ) - if failures := execution_context.get_failures(codemod.id): - log_list(logging.INFO, "failed", failures) - if changes := execution_context.get_results(codemod.id): - logger.info("changed:") - for change in changes: - logger.info(" - %s", change.path) - logger.debug(" diff:\n%s", indent(change.diff, " " * 6)) + execution_context.add_dependencies(codemod.id, file_context.dependencies) def run(original_args) -> int: @@ -179,10 +174,11 @@ def run(original_args) -> int: results, argv, ) + context.process_dependencies(codemod.id) + context.log_changes(codemod.id) results = context.compile_results(codemods_to_run) - write_dependencies(context) elapsed = datetime.datetime.now() - start elapsed_ms = int(elapsed.total_seconds() * 1000) report_default(elapsed_ms, argv, original_args, results) diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index d2a4fe4ae..53579fd01 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -99,6 +99,9 @@ def line_exclude(self): def line_include(self): return self.file_context.line_include + def add_dependency(self, dependency: str): + self.file_context.add_dependency(dependency) + class SemgrepCodemod(BaseCodemod): YAML_FILES: ClassVar[List[str]] = NotImplemented diff --git a/src/codemodder/context.py b/src/codemodder/context.py index bd8d982b2..9b88e8a7c 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -1,28 +1,19 @@ +import logging from pathlib import Path -from dataclasses import dataclass import itertools +from textwrap import indent -from codemodder.change import Change +from codemodder.change import ChangeSet +from codemodder.dependency_manager import DependencyManager from codemodder.executor import CodemodExecutorWrapper +from codemodder.logging import logger, log_list from codemodder.registry import CodemodRegistry -@dataclass -class ChangeSet: - """A set of changes made to a file at `path`""" - - path: str - diff: str - changes: list[Change] - - def to_json(self): - return {"path": self.path, "diff": self.diff, "changes": self.changes} - - class CodemodExecutionContext: # pylint: disable=too-many-instance-attributes _results_by_codemod: dict[str, list[ChangeSet]] = {} _failures_by_codemod: dict[str, list[Path]] = {} - dependencies: set[str] + dependencies: dict[str, set[str]] = {} directory: Path dry_run: bool = False verbose: bool = False @@ -38,9 +29,9 @@ def __init__( self.directory = directory self.dry_run = dry_run self.verbose = verbose - self.dependencies = set() self._results_by_codemod = {} self._failures_by_codemod = {} + self.dependencies = {} self.registry = registry def add_result(self, codemod_name, change_set): @@ -49,6 +40,9 @@ def add_result(self, codemod_name, change_set): def add_failure(self, codemod_name, file_path): self._failures_by_codemod.setdefault(codemod_name, []).append(file_path) + def add_dependencies(self, codemod_id: str, dependencies: set[str]): + self.dependencies.setdefault(codemod_id, set()).update(dependencies) + def get_results(self, codemod_name): return self._results_by_codemod.get(codemod_name, []) @@ -69,8 +63,22 @@ def get_failed_files(self): ) ) - def add_dependency(self, dependency: str): - self.dependencies.add(dependency) + def process_dependencies(self, codemod_id: str): + dependencies = self.dependencies.get(codemod_id) + if not dependencies: + return + + dm = DependencyManager(self.directory) + if not dm.found_dependency_file: + logger.info( + "unable to write dependencies for %s: no dependency file found", + codemod_id, + ) + return + + dm.add(list(dependencies)) + if (changeset := dm.write(self.dry_run)) is not None: + self.add_result(codemod_id, changeset) def compile_results(self, codemods: list[CodemodExecutorWrapper]): results = [] @@ -90,3 +98,12 @@ def compile_results(self, codemods: list[CodemodExecutorWrapper]): results.append(data) return results + + def log_changes(self, codemod_id: str): + if failures := self.get_failures(codemod_id): + log_list(logging.INFO, "failed", failures) + if changes := self.get_results(codemod_id): + logger.info("changed:") + for change in changes: + logger.info(" - %s", change.path) + logger.debug(" diff:\n%s", indent(change.diff, " " * 6)) diff --git a/src/codemodder/dependency_manager.py b/src/codemodder/dependency_manager.py index 0ab88cf73..90a5aaf29 100644 --- a/src/codemodder/dependency_manager.py +++ b/src/codemodder/dependency_manager.py @@ -1,25 +1,82 @@ -import sys -import io - -from dependency_manager import DependencyManagerAbstract +from functools import cached_property from pathlib import Path +from typing import Optional + +import difflib +from packaging.requirements import Requirement + +from codemodder.change import ChangeSet + + +class DependencyManager: + parent_directory: Path + _lines: list[str] + _new_requirements: list[str] + + def __init__(self, parent_directory: Path): + """One-time class initialization.""" + self.parent_directory = parent_directory + self.dependency_file_changed = False + self._lines = [] + self._new_requirements = [] + + def add(self, dependencies: list[str]): + """add any number of dependencies to the end of list of dependencies.""" + for dep_str in dependencies: + dep = Requirement(dep_str) + if dep not in self.dependencies: + self.dependencies.update({dep: None}) + self._new_requirements.append(str(dep)) + + def write(self, dry_run: bool = False) -> Optional[ChangeSet]: + """ + Write the updated dependency files if any changes were made. + """ + if not (self.dependency_file and self._new_requirements): + return None + + updated = self._lines + self._new_requirements + ["\n"] + + diff = "".join(difflib.unified_diff(self._lines, updated)) + # TODO: add a change entry for each new requirement + # TODO: make sure to set the contextual_description=True in the properties bag + + if not dry_run: + with open(self.dependency_file, "w", encoding="utf-8") as f: + f.writelines(self._lines) + f.writelines(self._new_requirements) -from codemodder.context import CodemodExecutionContext + self.dependency_file_changed = True + return ChangeSet(str(self.dependency_file), diff, changes=[]) + @property + def found_dependency_file(self) -> bool: + return self.dependency_file is not None -def write_dependencies(execution_context: CodemodExecutionContext): - class DependencyManager(DependencyManagerAbstract): - def get_parent_dir(self): - return Path(execution_context.directory) + @cached_property + def dependency_file(self) -> Optional[Path]: + try: + # For now for simplicity only return the first file + return next(Path(self.parent_directory).rglob("requirements.txt")) + except StopIteration: + pass + return None - dm = DependencyManager() - dm.add(list(execution_context.dependencies)) + @cached_property + def dependencies(self) -> dict[Requirement, None]: + """ + Extract list of dependencies from requirements.txt file. + Same order of requirements is maintained, no alphabetical sorting is done. + """ + if not self.dependency_file: + return {} - try: - # Hacky solution to prevent the dependency manager from writing to stdout - sys.stdout = io.StringIO() - dm.write(dry_run=execution_context.dry_run) - finally: - sys.stdout = sys.__stdout__ + with open(self.dependency_file, "r", encoding="utf-8") as f: + self._lines = f.readlines() - return dm + return { + Requirement(line): None + for x in self._lines + # Skip empty lines and comments + if (line := x.strip()) and not line.startswith("#") + } diff --git a/src/codemodder/file_context.py b/src/codemodder/file_context.py index c1c37e2d1..3b6641b4c 100644 --- a/src/codemodder/file_context.py +++ b/src/codemodder/file_context.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Dict, List @@ -13,6 +13,7 @@ class FileContext: line_exclude: List[int] line_include: List[int] results_by_id: Dict + dependencies: set[str] = field(default_factory=set) def __post_init__(self): if self.line_include is None: @@ -20,3 +21,6 @@ def __post_init__(self): if self.line_exclude is None: self.line_exclude = [] self.codemod_changes = [] + + def add_dependency(self, dependency: str): + self.dependencies.add(dependency) diff --git a/src/core_codemods/process_creation_sandbox.py b/src/core_codemods/process_creation_sandbox.py index 11d99e57d..ed0a4d7b7 100644 --- a/src/core_codemods/process_creation_sandbox.py +++ b/src/core_codemods/process_creation_sandbox.py @@ -40,7 +40,7 @@ def rule(cls): def on_result_found(self, original_node, updated_node): self.add_needed_import("security", "safe_command") - self.execution_context.add_dependency("security==1.0.1") + self.add_dependency("security==1.0.1") return self.update_call_target( updated_node, "safe_command", diff --git a/src/core_codemods/url_sandbox.py b/src/core_codemods/url_sandbox.py index e068200e5..35be3ea86 100644 --- a/src/core_codemods/url_sandbox.py +++ b/src/core_codemods/url_sandbox.py @@ -72,7 +72,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: find_requests_visitor.changes_in_file ) new_tree = tree.visit(ReplaceNodes(find_requests_visitor.nodes_to_change)) - self.execution_context.add_dependency("security==1.0.1") + self.add_dependency("security==1.0.1") # if it finds any request.get(...), try to remove the imports if any( ( diff --git a/src/core_codemods/use_defused_xml.py b/src/core_codemods/use_defused_xml.py index 7508961c6..f8307b929 100644 --- a/src/core_codemods/use_defused_xml.py +++ b/src/core_codemods/use_defused_xml.py @@ -91,6 +91,6 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) if visitor.changes_in_file: - self.execution_context.add_dependency("defusedxml") + self.add_dependency("defusedxml") # TODO: which version? return result_tree diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 7910d5bfe..c35cd4ccc 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -49,6 +49,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): assert output_tree.code == dedent(expected) + def assert_dependency(self, dependency: str): + assert self.file_context and self.file_context.dependencies == set([dependency]) + class BaseSemgrepCodemodTest(BaseCodemodTest): @classmethod diff --git a/tests/codemods/test_process_creation_sandbox.py b/tests/codemods/test_process_creation_sandbox.py index c53f7f2ad..48761f4f7 100644 --- a/tests/codemods/test_process_creation_sandbox.py +++ b/tests/codemods/test_process_creation_sandbox.py @@ -22,6 +22,7 @@ def test_import_subprocess(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_import_alias(self, tmpdir): input_code = """import subprocess as sub @@ -36,6 +37,7 @@ def test_import_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_from_subprocess(self, tmpdir): input_code = """from subprocess import run @@ -50,6 +52,7 @@ def test_from_subprocess(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_subprocess_nameerror(self, tmpdir): input_code = """subprocess.run("echo 'hi'", shell=True) @@ -94,6 +97,7 @@ def test_subprocess_nameerror(self, tmpdir): ) def test_other_import_untouched(self, tmpdir, input_code, expected): self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_multifunctions(self, tmpdir): # Test that subprocess methods that aren't part of the codemod are not changed. @@ -111,6 +115,7 @@ def test_multifunctions(self, tmpdir): subprocess.check_output(["ls", "-l"])""" self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_custom_run(self, tmpdir): input_code = """from app_funcs import run diff --git a/tests/codemods/test_url_sandbox.py b/tests/codemods/test_url_sandbox.py index 2264bbd33..09e7193b5 100644 --- a/tests/codemods/test_url_sandbox.py +++ b/tests/codemods/test_url_sandbox.py @@ -21,6 +21,7 @@ def test_import_requests(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_from_requests(self, tmpdir): input_code = """from requests import get @@ -34,6 +35,7 @@ def test_from_requests(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_requests_nameerror(self, tmpdir): input_code = """requests.get("www.google.com") @@ -76,6 +78,7 @@ def test_requests_nameerror(self, tmpdir): ) def test_requests_other_import_untouched(self, tmpdir, input_code, expected): self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_requests_multifunctions(self, tmpdir): # Test that `requests` import isn't removed if code uses part of the requests @@ -93,6 +96,7 @@ def test_requests_multifunctions(self, tmpdir): requests.status_codes.codes.FORBIDDEN""" self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_custom_get(self, tmpdir): input_code = """from app_funcs import get @@ -123,6 +127,7 @@ def test_from_requests_with_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") def test_requests_with_alias(self, tmpdir): input_code = """import requests as req @@ -136,3 +141,4 @@ def test_requests_with_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + self.assert_dependency("security==1.0.1") diff --git a/tests/codemods/test_use_defused_xml.py b/tests/codemods/test_use_defused_xml.py index 2e9549e57..668060d06 100644 --- a/tests/codemods/test_use_defused_xml.py +++ b/tests/codemods/test_use_defused_xml.py @@ -29,6 +29,7 @@ def test_etree_simple_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") @pytest.mark.parametrize("method", ETREE_METHODS) @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) @@ -46,6 +47,7 @@ def test_etree_attribute_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") def test_etree_elementtree_with_alias(self, tmpdir): original_code = """ @@ -61,6 +63,7 @@ def test_etree_elementtree_with_alias(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") def test_etree_parse_with_alias(self, tmpdir): original_code = """ @@ -76,6 +79,7 @@ def test_etree_parse_with_alias(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") @pytest.mark.parametrize("method", SAX_METHODS) def test_sax_simple_call(self, tmpdir, method): @@ -92,6 +96,7 @@ def test_sax_simple_call(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") @pytest.mark.parametrize("method", SAX_METHODS) def test_sax_attribute_call(self, tmpdir, method): @@ -108,6 +113,7 @@ def test_sax_attribute_call(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") @pytest.mark.parametrize("method", DOM_METHODS) @pytest.mark.parametrize("module", ["minidom", "pulldom"]) @@ -125,3 +131,4 @@ def test_dom_simple_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency("defusedxml") diff --git a/tests/conftest.py b/tests/conftest.py index d87046c36..22a7facd8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,9 +44,7 @@ def disable_write_dependencies(): """ Unit tests should not write any dependency files """ - dm_write = mock.patch( - "codemodder.dependency_manager.DependencyManagerAbstract._write" - ) + dm_write = mock.patch("codemodder.dependency_manager.DependencyManager.write") dm_write.start() yield diff --git a/tests/test_dependency_manager.py b/tests/test_dependency_manager.py index 04cb6e1f2..bf62b99c3 100644 --- a/tests/test_dependency_manager.py +++ b/tests/test_dependency_manager.py @@ -1,48 +1,51 @@ -import mock from pathlib import Path -from codemodder.codemodder import run -from codemodder.semgrep import run as semgrep_run -from codemodder.dependency_manager import write_dependencies + +import pytest + +from codemodder.dependency_manager import ( + DependencyManager, + Requirement, +) + + +@pytest.fixture(autouse=True, scope="module") +def disable_write_dependencies(): + """Override fixture from conftest.py in order to allow testing""" class TestDependencyManager: TEST_DIR = "tests/" - def test_init_once(self, mocker): - context = mocker.Mock(directory=self.TEST_DIR, dependencies=[]) - dm = write_dependencies(context) - expected_path = Path(self.TEST_DIR) - assert dm.parent_directory == expected_path - assert dm.dependency_file == expected_path / "samples" / "requirements.txt" - assert len(dm.dependencies) == 4 - - dep = next(iter(dm.dependencies)) - assert str(dep) == "requests==2.31.0" - - @mock.patch("codemodder.dependency_manager.DependencyManagerAbstract._write") - @mock.patch("codemodder.codemods.base_codemod.semgrep_run", side_effect=semgrep_run) - def test_dont_write(self, _, write_mock): - # Tests that dependency manager does not write to file if only - # codemods that don't change dependencies run. - args = [ - "tests/samples/", - "--output", - "here.txt", - "--codemod-include=secure-random", - ] - res = run(args) - assert res == 0 - write_mock.assert_not_called() - - @mock.patch("codemodder.dependency_manager.DependencyManagerAbstract._write") - @mock.patch("codemodder.codemods.base_codemod.semgrep_run", side_effect=semgrep_run) - def test_write_expected(self, _, write_mock): - args = [ - "tests/samples/", - "--output", - "here.txt", - "--codemod-include=url-sandbox", - ] - res = run(args) - assert res == 0 - write_mock.assert_called() + def test_read_dependency_file(self, tmpdir): + dependency_file = Path(tmpdir) / "requirements.txt" + dependency_file.write_text("requests\n", encoding="utf-8") + + dm = DependencyManager(Path(tmpdir)) + assert dm.dependencies == {Requirement("requests"): None} + + @pytest.mark.parametrize("dry_run", [True, False]) + def test_add_dependency_preserve_comments(self, tmpdir, dry_run): + contents = "# comment\n\nrequests\n" + dependency_file = Path(tmpdir) / "requirements.txt" + dependency_file.write_text(contents, encoding="utf-8") + + dm = DependencyManager(Path(tmpdir)) + dm.add(["defusedxml"]) + changeset = dm.write(dry_run=dry_run) + + assert dependency_file.read_text(encoding="utf-8") == ( + contents if dry_run else "# comment\n\nrequests\ndefusedxml" + ) + + assert changeset is not None + assert changeset.path == str(dependency_file) + assert changeset.diff == ( + "--- \n" + "+++ \n" + "@@ -1,3 +1,5 @@\n" + " # comment\n" + " \n" + " requests\n" + "+defusedxml+\n" + ) + assert changeset.changes == []