diff --git a/src/codemodder/result.py b/src/codemodder/result.py index 8dd66c55..99293601 100644 --- a/src/codemodder/result.py +++ b/src/codemodder/result.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools from abc import abstractmethod from dataclasses import dataclass, field from pathlib import Path @@ -139,7 +140,14 @@ def fuzzy_column_match(pos: CodeRange, location: Location) -> bool: class ResultSet(dict[str, dict[Path, list[Result]]]): + results_for_rule: dict[str, list[Result]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.results_for_rule = {} + def add_result(self, result: Result): + self.results_for_rule.setdefault(result.rule_id, []).append(result) for loc in result.locations: self.setdefault(result.rule_id, {}).setdefault(loc.file, []).append(result) @@ -157,6 +165,16 @@ def results_for_rule_and_file( """ return self.get(rule_id, {}).get(file.relative_to(context.directory), []) + def results_for_rules(self, rule_ids: list[str]) -> list[Result]: + """ + Returns flat list of all results that match any of the given rule IDs. + """ + return list( + itertools.chain.from_iterable( + self.results_for_rule.get(rule_id, []) for rule_id in rule_ids + ) + ) + def files_for_rule(self, rule_id: str) -> list[Path]: return list(self.get(rule_id, {}).keys()) @@ -164,16 +182,22 @@ def all_rule_ids(self) -> list[str]: return list(self.keys()) def __or__(self, other): - result = ResultSet(super().__or__(other)) + result = self.__class__() for k in self.keys() | other.keys(): - result[k] = list_dict_or(self[k], other[k]) + result[k] = list_dict_or(self.get(k, {}), other.get(k, {})) + result.results_for_rule = list_dict_or( + self.results_for_rule, other.results_for_rule + ) return result + def __ior__(self, other): + return self | other + def list_dict_or( dictionary: dict[Any, list[Any]], other: dict[Any, list[Any]] -) -> dict[Path, list[Any]]: - result_dict = other | dictionary +) -> dict[Any, list[Any]]: + result_dict = {} for k in other.keys() | dictionary.keys(): - result_dict[k] = dictionary[k] + other[k] + result_dict[k] = dictionary.get(k, []) + other.get(k, []) return result_dict diff --git a/tests/test_results.py b/tests/test_results.py index e8c7c832..7e697765 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -55,6 +55,11 @@ def test_or(self, tmpdir): in combined["python:S5659"][Path("code.py")] ) + assert combined.results_for_rules(["python:S5659"]) == [ + result1["python:S5659"][Path("code.py")][0], + result2["python:S5659"][Path("code.py")][0], + ] + def test_sonar_only_open_issues(self, tmpdir): issues = { "issues": [