Skip to content

Commit

Permalink
Only scan semgrep files with initial findings
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Nov 3, 2023
1 parent d794259 commit 2efaf98
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 44 deletions.
62 changes: 33 additions & 29 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@ def update_code(file_path, new_code):
def find_semgrep_results(
context: CodemodExecutionContext,
codemods: list[CodemodExecutorWrapper],
) -> set[str]:
) -> ResultSet:
"""Run semgrep once with all configuration files from all codemods and return a set of applicable rule IDs"""
yaml_files = list(
itertools.chain.from_iterable(
[codemod.yaml_files for codemod in codemods if codemod.yaml_files]
)
)
if not yaml_files:
return set()
return ResultSet()

results = run_semgrep(context, yaml_files)
return set(results.keys())
return run_semgrep(context, yaml_files)


def apply_codemod_to_file(
Expand Down Expand Up @@ -163,6 +162,29 @@ def analyze_files(
execution_context.process_results(codemod.id, analysis_results)


def log_report(context, argv, elapsed_ms, files_to_analyze):
log_section("report")
logger.info("scanned: %s files", len(files_to_analyze))
all_failures = context.get_failed_files()
logger.info(
"failed: %s files (%s unique)",
len(all_failures),
len(set(all_failures)),
)
all_changes = context.get_changed_files()
logger.info(
"changed: %s files (%s unique)",
len(all_changes),
len(set(all_changes)),
)
logger.info("report file: %s", argv.output)
logger.info("total elapsed: %s ms", elapsed_ms)
logger.info(" semgrep: %s ms", context.timer.get_time_ms("semgrep"))
logger.info(" parse: %s ms", context.timer.get_time_ms("parse"))
logger.info(" transform: %s ms", context.timer.get_time_ms("transform"))
logger.info(" write: %s ms", context.timer.get_time_ms("write"))


def run(original_args) -> int:
start = datetime.datetime.now()

Expand Down Expand Up @@ -215,25 +237,27 @@ def run(original_args) -> int:
return 0

full_names = [str(path) for path in files_to_analyze]
logger.debug("matched files:")
log_list(logging.DEBUG, "matched files", full_names)

semgrep_results: set[str] = find_semgrep_results(context, codemods_to_run)
semgrep_results: ResultSet = find_semgrep_results(context, codemods_to_run)
semgrep_finding_ids = semgrep_results.all_rule_ids()

log_section("scanning")
# run codemods one at a time making sure to respect the given sequence
for codemod in codemods_to_run:
# Unfortunately the IDs from semgrep are not fully specified
# TODO: eventually we need to be able to use fully specified IDs here
if codemod.is_semgrep and codemod.name not in semgrep_results:
if codemod.is_semgrep and codemod.name not in semgrep_finding_ids:
logger.debug(
"no results from semgrep for %s, skipping analysis",
codemod.id,
)
continue

logger.info("running codemod %s", codemod.id)
results = codemod.apply(context)
semgrep_files = semgrep_results.files_for_rule(codemod.name)
# Non-semgrep codemods ignore the semgrep results
results = codemod.apply(context, semgrep_files)
analyze_files(
context,
files_to_analyze,
Expand All @@ -250,27 +274,7 @@ def run(original_args) -> int:
elapsed_ms = int(elapsed.total_seconds() * 1000)
report_default(elapsed_ms, argv, original_args, results)

log_section("report")
logger.info("scanned: %s files", len(files_to_analyze))
all_failures = context.get_failed_files()
logger.info(
"failed: %s files (%s unique)",
len(all_failures),
len(set(all_failures)),
)
all_changes = context.get_changed_files()
logger.info(
"changed: %s files (%s unique)",
len(all_changes),
len(set(all_changes)),
)
logger.info("report file: %s", argv.output)
logger.info("total elapsed: %s ms", elapsed_ms)
logger.info("semgrep: %s ms", context.timer.get_time_ms("semgrep"))
logger.info("parse: %s ms", context.timer.get_time_ms("parse"))
logger.info("transform: %s ms", context.timer.get_time_ms("transform"))
logger.info("write: %s ms", context.timer.get_time_ms("write"))

log_report(context, argv, elapsed_ms, files_to_analyze)
return 0


Expand Down
3 changes: 2 additions & 1 deletion src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ def apply_rule(cls, context, *args, **kwargs) -> ResultSet:
Apply semgrep to gather rule results
"""
yaml_files = kwargs.get("yaml_files") or args[0]
files_to_analyze = kwargs.get("files_to_analyze") or args[1]
with context.timer.measure("semgrep"):
return semgrep_run(context, yaml_files)
return semgrep_run(context, yaml_files, files_to_analyze)

@property
def should_transform(self):
Expand Down
9 changes: 7 additions & 2 deletions src/codemodder/executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib.abc import Traversable
from pathlib import Path

from wrapt import CallableObjectProxy

Expand All @@ -22,13 +23,17 @@ def __init__(
self.docs_module = docs_module
self.semgrep_config_module = semgrep_config_module

def apply(self, context):
def apply(self, context, files: list[Path]):
"""
Wraps the codemod's apply method to inject additional arguments.
Not all codemods will need these arguments.
"""
return self.apply_rule(context, yaml_files=self.yaml_files)
return self.apply_rule(
context,
yaml_files=self.yaml_files,
files_to_analyze=files,
)

@property
def name(self):
Expand Down
17 changes: 12 additions & 5 deletions src/codemodder/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@ class Result(ABCDataclass):
locations: list[Location]


class ResultSet(dict[str, list[Result]]):
class ResultSet(dict[Path, list[Result]]):
def add_result(self, result: Result):
self.setdefault(result.rule_id, []).append(result)
for location in result.locations:
self.setdefault(location.file, []).append(result)

def results_for_rule_and_file(self, rule_id: str, file: Path) -> list[Result]:
return [result for result in self.get(file, []) if result.rule_id == rule_id]

def files_for_rule(self, rule_id: str) -> list[Path]:
return [
result
for result in self.get(rule_id, [])
if result.locations[0].file == file
file
for file, results in self.items()
if any(result.rule_id == rule_id for result in results)
]

def all_rule_ids(self) -> set[str]:
return {result.rule_id for results in self.values() for result in results}
5 changes: 3 additions & 2 deletions src/codemodder/semgrep.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import subprocess
import itertools
from tempfile import NamedTemporaryFile
from typing import Iterable
from typing import Iterable, Optional
from pathlib import Path
from codemodder.context import CodemodExecutionContext
from codemodder.sarifs import SarifResultSet
Expand All @@ -11,6 +11,7 @@
def run(
execution_context: CodemodExecutionContext,
yaml_files: Iterable[Path],
files_to_analyze: Optional[Iterable[Path]] = None,
) -> SarifResultSet:
"""
Runs Semgrep and outputs a dict with the results organized by rule_id.
Expand All @@ -34,7 +35,7 @@ def run(
map(lambda f: ["--config", str(f)], yaml_files)
)
)
command.append(str(execution_context.directory))
command.extend(map(str, files_to_analyze or [execution_context.directory]))
logger.debug("semgrep command: `%s`", " ".join(command))
subprocess.run(
command,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from codemodder.codemodder import run, find_semgrep_results
from codemodder.semgrep import run as semgrep_run
from codemodder.registry import load_registered_codemods
from codemodder.result import ResultSet


class TestRun:
Expand Down Expand Up @@ -187,5 +188,5 @@ def test_find_semgrep_results_no_yaml(self, mocker):
codemod_include=["use-defusedxml"]
)
result = find_semgrep_results(mocker.MagicMock(), codemods)
assert result == set()
assert result == ResultSet()
assert run_semgrep.call_count == 0
8 changes: 4 additions & 4 deletions tests/test_sarif_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def test_results_by_rule_id(self):
sarif_file = Path("tests") / "samples" / "semgrep.sarif"

results = SarifResultSet.from_sarif(sarif_file)
expected_rule = "secure-random"
assert list(results.keys()) == [expected_rule]

expected_path = Path("tests/samples/insecure_random.py")
assert expected_path == results[expected_rule][0].locations[0].file
assert list(results.keys()) == [expected_path]

expected_rule = "secure-random"
assert expected_rule == results[expected_path][0].rule_id

def test_codeql_sarif_input(self, tmpdir):
completed_process = subprocess.run(
Expand Down

0 comments on commit 2efaf98

Please sign in to comment.