Skip to content

Commit

Permalink
Only scan semgrep files with initial findings
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Nov 6, 2023
1 parent bd3e1d4 commit b22c1ed
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 49 deletions.
62 changes: 33 additions & 29 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@ def update_code(file_path, new_code):
def find_semgrep_results(
context: CodemodExecutionContext,
codemods: list[CodemodExecutorWrapper],
) -> set[str]:
) -> ResultSet:
"""Run semgrep once with all configuration files from all codemods and return a set of applicable rule IDs"""
yaml_files = list(
itertools.chain.from_iterable(
[codemod.yaml_files for codemod in codemods if codemod.yaml_files]
)
)
if not yaml_files:
return set()
return ResultSet()

results = run_semgrep(context, yaml_files)
return set(results.keys())
return run_semgrep(context, yaml_files)


def create_diff(original_tree: cst.Module, new_tree: cst.Module) -> str:
Expand Down Expand Up @@ -173,6 +172,29 @@ def analyze_files(
execution_context.process_results(codemod.id, analysis_results)


def log_report(context, argv, elapsed_ms, files_to_analyze):
    """Emit the end-of-run summary to the module logger.

    Logs the number of scanned files, the failed and changed file counts
    (total and de-duplicated), the report output location taken from
    ``argv.output``, the total elapsed time, and the per-phase timings
    read from ``context.timer``.
    """
    log_section("report")
    logger.info("scanned: %s files", len(files_to_analyze))

    failed = context.get_failed_files()
    failed_total = len(failed)
    failed_unique = len(set(failed))
    logger.info("failed: %s files (%s unique)", failed_total, failed_unique)

    changed = context.get_changed_files()
    changed_total = len(changed)
    changed_unique = len(set(changed))
    logger.info("changed: %s files (%s unique)", changed_total, changed_unique)

    logger.info("report file: %s", argv.output)
    logger.info("total elapsed: %s ms", elapsed_ms)

    # Per-phase breakdown, in the same order the phases are reported.
    phase_messages = (
        (" semgrep: %s ms", "semgrep"),
        (" parse: %s ms", "parse"),
        (" transform: %s ms", "transform"),
        (" write: %s ms", "write"),
    )
    for message, phase in phase_messages:
        logger.info(message, context.timer.get_time_ms(phase))


def run(original_args) -> int:
start = datetime.datetime.now()

Expand Down Expand Up @@ -225,25 +247,27 @@ def run(original_args) -> int:
return 0

full_names = [str(path) for path in files_to_analyze]
logger.debug("matched files:")
log_list(logging.DEBUG, "matched files", full_names)

semgrep_results: set[str] = find_semgrep_results(context, codemods_to_run)
semgrep_results: ResultSet = find_semgrep_results(context, codemods_to_run)
semgrep_finding_ids = semgrep_results.all_rule_ids()

log_section("scanning")
# run codemods one at a time making sure to respect the given sequence
for codemod in codemods_to_run:
# Unfortunately the IDs from semgrep are not fully specified
# TODO: eventually we need to be able to use fully specified IDs here
if codemod.is_semgrep and codemod.name not in semgrep_results:
if codemod.is_semgrep and codemod.name not in semgrep_finding_ids:
logger.debug(
"no results from semgrep for %s, skipping analysis",
codemod.id,
)
continue

logger.info("running codemod %s", codemod.id)
results = codemod.apply(context)
semgrep_files = semgrep_results.files_for_rule(codemod.name)
# Non-semgrep codemods ignore the semgrep results
results = codemod.apply(context, semgrep_files)
analyze_files(
context,
files_to_analyze,
Expand All @@ -260,27 +284,7 @@ def run(original_args) -> int:
elapsed_ms = int(elapsed.total_seconds() * 1000)
report_default(elapsed_ms, argv, original_args, results)

log_section("report")
logger.info("scanned: %s files", len(files_to_analyze))
all_failures = context.get_failed_files()
logger.info(
"failed: %s files (%s unique)",
len(all_failures),
len(set(all_failures)),
)
all_changes = context.get_changed_files()
logger.info(
"changed: %s files (%s unique)",
len(all_changes),
len(set(all_changes)),
)
logger.info("report file: %s", argv.output)
logger.info("total elapsed: %s ms", elapsed_ms)
logger.info("semgrep: %s ms", context.timer.get_time_ms("semgrep"))
logger.info("parse: %s ms", context.timer.get_time_ms("parse"))
logger.info("transform: %s ms", context.timer.get_time_ms("transform"))
logger.info("write: %s ms", context.timer.get_time_ms("write"))

log_report(context, argv, elapsed_ms, files_to_analyze)
return 0


Expand Down
3 changes: 2 additions & 1 deletion src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ def apply_rule(cls, context, *args, **kwargs) -> ResultSet:
Apply semgrep to gather rule results
"""
yaml_files = kwargs.get("yaml_files") or args[0]
files_to_analyze = kwargs.get("files_to_analyze") or args[1]
with context.timer.measure("semgrep"):
return semgrep_run(context, yaml_files)
return semgrep_run(context, yaml_files, files_to_analyze)

@property
def should_transform(self):
Expand Down
9 changes: 7 additions & 2 deletions src/codemodder/executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib.abc import Traversable
from pathlib import Path

from wrapt import CallableObjectProxy

Expand All @@ -22,13 +23,17 @@ def __init__(
self.docs_module = docs_module
self.semgrep_config_module = semgrep_config_module

def apply(self, context):
def apply(self, context, files: list[Path]):
"""
Wraps the codemod's apply method to inject additional arguments.
Not all codemods will need these arguments.
"""
return self.apply_rule(context, yaml_files=self.yaml_files)
return self.apply_rule(
context,
yaml_files=self.yaml_files,
files_to_analyze=files,
)

@property
def name(self):
Expand Down
17 changes: 10 additions & 7 deletions src/codemodder/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ class Result(ABCDataclass):
locations: list[Location]


class ResultSet(dict[str, list[Result]]):
class ResultSet(dict[str, dict[Path, list[Result]]]):
def add_result(self, result: Result):
self.setdefault(result.rule_id, []).append(result)
for loc in result.locations:
self.setdefault(result.rule_id, {}).setdefault(loc.file, []).append(result)

def results_for_rule_and_file(self, rule_id: str, file: Path) -> list[Result]:
return [
result
for result in self.get(rule_id, [])
if result.locations[0].file == file
]
return self.get(rule_id, {}).get(file, [])

def files_for_rule(self, rule_id: str) -> list[Path]:
return list(self.get(rule_id, {}).keys())

def all_rule_ids(self) -> list[str]:
return list(self.keys())
5 changes: 3 additions & 2 deletions src/codemodder/semgrep.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import subprocess
import itertools
from tempfile import NamedTemporaryFile
from typing import Iterable
from typing import Iterable, Optional
from pathlib import Path
from codemodder.context import CodemodExecutionContext
from codemodder.sarifs import SarifResultSet
Expand All @@ -11,6 +11,7 @@
def run(
execution_context: CodemodExecutionContext,
yaml_files: Iterable[Path],
files_to_analyze: Optional[Iterable[Path]] = None,
) -> SarifResultSet:
"""
Runs Semgrep and outputs a dict with the results organized by rule_id.
Expand All @@ -34,7 +35,7 @@ def run(
map(lambda f: ["--config", str(f)], yaml_files)
)
)
command.append(str(execution_context.directory))
command.extend(map(str, files_to_analyze or [execution_context.directory]))
logger.debug("semgrep command: `%s`", " ".join(command))
subprocess.run(
command,
Expand Down
7 changes: 1 addition & 6 deletions tests/codemods/base_codemod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
)
input_tree = cst.parse_module(input_code)
all_results = self.results_by_id_filepath(input_code, file_path)
results = [
result
for entry in all_results.values()
for result in entry
if result.rule_id == self.codemod.name()
]
results = all_results.results_for_rule_and_file(self.codemod.name(), file_path)
self.file_context = FileContext(
root,
file_path,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from codemodder.codemodder import create_diff, run, find_semgrep_results
from codemodder.semgrep import run as semgrep_run
from codemodder.registry import load_registered_codemods
from codemodder.result import ResultSet


class TestRun:
Expand Down Expand Up @@ -190,7 +191,7 @@ def test_find_semgrep_results_no_yaml(self, mocker):
codemod_include=["use-defusedxml"]
)
result = find_semgrep_results(mocker.MagicMock(), codemods)
assert result == set()
assert result == ResultSet()
assert run_semgrep.call_count == 0

def test_diff_newline_edge_case(self):
Expand Down
7 changes: 6 additions & 1 deletion tests/test_sarif_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ def test_results_by_rule_id(self):
assert list(results.keys()) == [expected_rule]

expected_path = Path("tests/samples/insecure_random.py")
assert expected_path == results[expected_rule][0].locations[0].file
assert list(results[expected_rule].keys()) == [expected_path]

assert results[expected_rule][expected_path][0].rule_id == expected_rule
assert (
results[expected_rule][expected_path][0].locations[0].file == expected_path
)

def test_codeql_sarif_input(self, tmpdir):
completed_process = subprocess.run(
Expand Down

0 comments on commit b22c1ed

Please sign in to comment.