diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 13e90ed25..10bbeda72 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -35,7 +35,7 @@ def update_code(file_path, new_code): def find_semgrep_results( context: CodemodExecutionContext, codemods: list[CodemodExecutorWrapper], -) -> set[str]: +) -> ResultSet: """Run semgrep once with all configuration files from all codemods and return a set of applicable rule IDs""" yaml_files = list( itertools.chain.from_iterable( @@ -43,10 +43,9 @@ def find_semgrep_results( ) ) if not yaml_files: - return set() + return ResultSet() - results = run_semgrep(context, yaml_files) - return set(results.keys()) + return run_semgrep(context, yaml_files) def apply_codemod_to_file( @@ -163,6 +162,29 @@ def analyze_files( execution_context.process_results(codemod.id, analysis_results) +def log_report(context, argv, elapsed_ms, files_to_analyze): + log_section("report") + logger.info("scanned: %s files", len(files_to_analyze)) + all_failures = context.get_failed_files() + logger.info( + "failed: %s files (%s unique)", + len(all_failures), + len(set(all_failures)), + ) + all_changes = context.get_changed_files() + logger.info( + "changed: %s files (%s unique)", + len(all_changes), + len(set(all_changes)), + ) + logger.info("report file: %s", argv.output) + logger.info("total elapsed: %s ms", elapsed_ms) + logger.info(" semgrep: %s ms", context.timer.get_time_ms("semgrep")) + logger.info(" parse: %s ms", context.timer.get_time_ms("parse")) + logger.info(" transform: %s ms", context.timer.get_time_ms("transform")) + logger.info(" write: %s ms", context.timer.get_time_ms("write")) + + def run(original_args) -> int: start = datetime.datetime.now() @@ -215,17 +237,17 @@ def run(original_args) -> int: return 0 full_names = [str(path) for path in files_to_analyze] - logger.debug("matched files:") log_list(logging.DEBUG, "matched files", full_names) - semgrep_results: set[str] = find_semgrep_results(context, codemods_to_run) + semgrep_results: ResultSet = find_semgrep_results(context, codemods_to_run) + semgrep_finding_ids = semgrep_results.all_rule_ids() log_section("scanning") # run codemods one at a time making sure to respect the given sequence for codemod in codemods_to_run: # Unfortunately the IDs from semgrep are not fully specified # TODO: eventually we need to be able to use fully specified IDs here - if codemod.is_semgrep and codemod.name not in semgrep_results: + if codemod.is_semgrep and codemod.name not in semgrep_finding_ids: logger.debug( "no results from semgrep for %s, skipping analysis", codemod.id, @@ -233,7 +255,9 @@ def run(original_args) -> int: continue logger.info("running codemod %s", codemod.id) - results = codemod.apply(context) + semgrep_files = semgrep_results.files_for_rule(codemod.name) + # Non-semgrep codemods ignore the semgrep results + results = codemod.apply(context, semgrep_files) analyze_files( context, files_to_analyze, @@ -250,27 +274,7 @@ def run(original_args) -> int: elapsed_ms = int(elapsed.total_seconds() * 1000) report_default(elapsed_ms, argv, original_args, results) - log_section("report") - logger.info("scanned: %s files", len(files_to_analyze)) - all_failures = context.get_failed_files() - logger.info( - "failed: %s files (%s unique)", - len(all_failures), - len(set(all_failures)), - ) - all_changes = context.get_changed_files() - logger.info( - "changed: %s files (%s unique)", - len(all_changes), - len(set(all_changes)), - ) - logger.info("report file: %s", argv.output) - logger.info("total elapsed: %s ms", elapsed_ms) - logger.info("semgrep: %s ms", context.timer.get_time_ms("semgrep")) - logger.info("parse: %s ms", context.timer.get_time_ms("parse")) - logger.info("transform: %s ms", context.timer.get_time_ms("transform")) - logger.info("write: %s ms", context.timer.get_time_ms("write")) - + log_report(context, argv, elapsed_ms, files_to_analyze) return 0 diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index 86f7430bc..b83f00c42 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -112,8 +112,9 @@ def apply_rule(cls, context, *args, **kwargs) -> ResultSet: Apply semgrep to gather rule results """ yaml_files = kwargs.get("yaml_files") or args[0] + files_to_analyze = kwargs.get("files_to_analyze") or args[1] with context.timer.measure("semgrep"): - return semgrep_run(context, yaml_files) + return semgrep_run(context, yaml_files, files_to_analyze) @property def should_transform(self): diff --git a/src/codemodder/executor.py b/src/codemodder/executor.py index 47910b9f3..b5420b305 100644 --- a/src/codemodder/executor.py +++ b/src/codemodder/executor.py @@ -1,4 +1,5 @@ from importlib.abc import Traversable +from pathlib import Path from wrapt import CallableObjectProxy @@ -22,13 +23,17 @@ def __init__( self.docs_module = docs_module self.semgrep_config_module = semgrep_config_module - def apply(self, context): + def apply(self, context, files: list[Path]): """ Wraps the codemod's apply method to inject additional arguments. Not all codemods will need these arguments. """ - return self.apply_rule(context, yaml_files=self.yaml_files) + return self.apply_rule( + context, + yaml_files=self.yaml_files, + files_to_analyze=files, + ) @property def name(self): diff --git a/src/codemodder/result.py b/src/codemodder/result.py index 4566d7be5..c7105a10f 100644 --- a/src/codemodder/result.py +++ b/src/codemodder/result.py @@ -34,13 +34,20 @@ class Result(ABCDataclass): locations: list[Location] -class ResultSet(dict[str, list[Result]]): +class ResultSet(dict[Path, list[Result]]): def add_result(self, result: Result): - self.setdefault(result.rule_id, []).append(result) + for location in result.locations: + self.setdefault(location.file, []).append(result) def results_for_rule_and_file(self, rule_id: str, file: Path) -> list[Result]: + return [result for result in self.get(file, []) if result.rule_id == rule_id] + + def files_for_rule(self, rule_id: str) -> list[Path]: return [ - result - for result in self.get(rule_id, []) - if result.locations[0].file == file + file + for file, results in self.items() + if any(result.rule_id == rule_id for result in results) ] + + def all_rule_ids(self) -> set[str]: + return {result.rule_id for results in self.values() for result in results} diff --git a/src/codemodder/semgrep.py b/src/codemodder/semgrep.py index 8138cb9e8..e4fc32a30 100644 --- a/src/codemodder/semgrep.py +++ b/src/codemodder/semgrep.py @@ -1,7 +1,7 @@ import subprocess import itertools from tempfile import NamedTemporaryFile -from typing import Iterable +from typing import Iterable, Optional from pathlib import Path from codemodder.context import CodemodExecutionContext from codemodder.sarifs import SarifResultSet @@ -11,6 +11,7 @@ def run( execution_context: CodemodExecutionContext, yaml_files: Iterable[Path], + files_to_analyze: Optional[Iterable[Path]] = None, ) -> SarifResultSet: """ Runs Semgrep and outputs a dict with the results organized by rule_id. @@ -34,7 +35,7 @@ def run( map(lambda f: ["--config", str(f)], yaml_files) ) ) - command.append(str(execution_context.directory)) + command.extend(map(str, files_to_analyze or [execution_context.directory])) logger.debug("semgrep command: `%s`", " ".join(command)) subprocess.run( command, diff --git a/tests/test_codemodder.py b/tests/test_codemodder.py index 775ea84c6..5dede479a 100644 --- a/tests/test_codemodder.py +++ b/tests/test_codemodder.py @@ -3,6 +3,7 @@ from codemodder.codemodder import run, find_semgrep_results from codemodder.semgrep import run as semgrep_run from codemodder.registry import load_registered_codemods +from codemodder.result import ResultSet class TestRun: @@ -187,5 +188,5 @@ def test_find_semgrep_results_no_yaml(self, mocker): codemod_include=["use-defusedxml"] ) result = find_semgrep_results(mocker.MagicMock(), codemods) - assert result == set() + assert result == ResultSet() assert run_semgrep.call_count == 0 diff --git a/tests/test_sarif_processing.py b/tests/test_sarif_processing.py index 6f837c9ad..19dc5f59a 100644 --- a/tests/test_sarif_processing.py +++ b/tests/test_sarif_processing.py @@ -32,11 +32,11 @@ def test_results_by_rule_id(self): sarif_file = Path("tests") / "samples" / "semgrep.sarif" results = SarifResultSet.from_sarif(sarif_file) - expected_rule = "secure-random" - assert list(results.keys()) == [expected_rule] - expected_path = Path("tests/samples/insecure_random.py") - assert expected_path == results[expected_rule][0].locations[0].file + assert list(results.keys()) == [expected_path] + + expected_rule = "secure-random" + assert expected_rule == results[expected_path][0].rule_id def test_codeql_sarif_input(self, tmpdir): completed_process = subprocess.run(