Skip to content

Commit

Permalink
implements enabled_rules parameter in config.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
lauraschauer committed Jul 5, 2024
1 parent 30c895e commit 70d9401
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 56 deletions.
13 changes: 4 additions & 9 deletions prospector/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ def main(argv): # noqa: C901

# Whether to use the LLMService
if config.llm_service:
if (
not config.repository
and not config.llm_service.use_llm_repository_url
):
if not config.repository and not config.llm_service.use_llm_repository_url:
logger.error(
"Repository URL was neither specified nor allowed to obtain with LLM support. One must be set."
)
Expand All @@ -71,7 +68,7 @@ def main(argv): # noqa: C901
)
return

# Create the LLMService singleton for later use
# Create the LLMService Singleton for later use
try:
LLMService(config.llm_service)
except Exception as e:
Expand All @@ -83,9 +80,7 @@ def main(argv): # noqa: C901
return

config.pub_date = (
config.pub_date + "T00:00:00Z"
if config.pub_date is not None
else ""
config.pub_date + "T00:00:00Z" if config.pub_date is not None else ""
)

logger.debug("Using the following configuration:")
Expand All @@ -109,7 +104,7 @@ def main(argv): # noqa: C901
limit_candidates=config.max_candidates,
# ignore_adv_refs=config.ignore_refs,
use_llm_repository_url=config.llm_service.use_llm_repository_url,
use_llm_rules=config.llm_service.use_llm_rules,
enabled_rules=config.enabled_rules,
)

if config.preprocess_only:
Expand Down
21 changes: 21 additions & 0 deletions prospector/config-sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@ redis_url: redis://redis:6379/0

# use_llm_repository_url: False # whether to use LLMs to obtain the repository URL

enabled_rules:
# Phase 1 Rules
- VULN_ID_IN_MESSAGE
- XREF_BUG
- XREF_GH
- COMMIT_IN_REFERENCE
- VULN_ID_IN_LINKED_ISSUE
- CHANGES_RELEVANT_FILES
- CHANGES_RELEVANT_CODE
- RELEVANT_WORDS_IN_MESSAGE
- ADV_KEYWORDS_IN_FILES
- ADV_KEYWORDS_IN_MSG
- SEC_KEYWORDS_IN_MESSAGE
- SEC_KEYWORDS_IN_LINKED_GH
- SEC_KEYWORDS_IN_LINKED_BUG
- GITHUB_ISSUE_IN_MESSAGE
- BUG_IN_MESSAGE
- COMMIT_HAS_TWINS
# Phase 2 Rules (llm_service required!):
# - COMMIT_IS_SECURITY_RELEVANT

# Report file format: "html", "json", "console" or "all"
# and the file name
report:
Expand Down
16 changes: 5 additions & 11 deletions prospector/core/prospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from git.version_to_tag import get_possible_tags
from llm.llm_service import LLMService
from log.logger import get_level, logger, pretty_log
from rules.rules import PHASE_1, PHASE_2, apply_rules
from rules.rules import apply_rules
from stats.execution import (
Counter,
ExecutionTimer,
Expand Down Expand Up @@ -66,11 +66,10 @@ def prospector( # noqa: C901
use_backend: str = USE_BACKEND_ALWAYS,
git_cache: str = "/tmp/git_cache",
limit_candidates: int = MAX_CANDIDATES,
rules: List[str] = [PHASE_1],
enabled_rules: List[str] = [],
tag_commits: bool = True,
silent: bool = False,
use_llm_repository_url: bool = False,
use_llm_rules: bool = False,
) -> Tuple[List[Commit], AdvisoryRecord] | Tuple[int, int]:
if silent:
logger.disabled = True
Expand Down Expand Up @@ -233,7 +232,7 @@ def prospector( # noqa: C901
logger.warning("Preprocessed commits are not being sent to backend")

ranked_candidates = evaluate_commits(
preprocessed_commits, advisory_record, rules, use_llm_rules
preprocessed_commits, advisory_record, enabled_rules
)

# ConsoleWriter.print("Commit ranking and aggregation...")
Expand Down Expand Up @@ -272,8 +271,7 @@ def filter(commits: Dict[str, RawCommit]) -> Dict[str, RawCommit]:
def evaluate_commits(
commits: List[Commit],
advisory: AdvisoryRecord,
rules: List[str],
use_llm_rules: bool,
enabled_rules: List[str],
) -> List[Commit]:
"""This function applies rule phases. Each phase is associated with a set of rules, for example:
- Phase 1: NLP Rules
Expand All @@ -283,18 +281,14 @@ def evaluate_commits(
commits: the list of candidate commits that rules should be applied to
advisory: the object containing all information about the advisory
enabled_rules: the IDs of the rules to run (a (sub)set of the available rules)
use_llm_rules: indication whether the user wishes to use the LLM supported rules
Returns:
a list of commits ranked according to their relevance score
Raises:
MissingMandatoryValue: if there is an error in the LLM configuration object
"""
with ExecutionTimer(core_statistics.sub_collection("candidates analysis")):
with ConsoleWriter("Candidate analysis") as _:
if use_llm_rules:
rules.append(PHASE_2)

ranked_commits = apply_rules(commits, advisory, rules=rules)
ranked_commits = apply_rules(commits, advisory, enabled_rules=enabled_rules)

return ranked_commits

Expand Down
33 changes: 11 additions & 22 deletions prospector/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
from stats.execution import Counter, execution_statistics
from util.lsh import build_lsh_index, decode_minhash

PHASE_1 = "phase_1"
PHASE_2 = "phase_2"

MAX_COMMITS_FOR_LLM_RULES = 3
MAX_COMMITS_FOR_LLM_RULES = 1


rule_statistics = execution_statistics.sub_collection("rules")
Expand Down Expand Up @@ -46,19 +43,24 @@ def as_dict(self):
def get_rule_as_tuple(self) -> Tuple[str, str, int]:
return (self.id, self.message, self.relevance)

def get_id(self):
return self.id


def apply_rules(
candidates: List[Commit],
advisory_record: AdvisoryRecord,
rules: List[str] = [PHASE_1],
enabled_rules: List[str] = [],
) -> List[Commit]:
"""Applies the selected set of rules and returns the ranked list of commits (uses apply_ranking())."""

Rule.lsh_index = build_lsh_index()
if PHASE_2 in rules:
Rule.llm_service = LLMService()

phase_1_rules = get_enabled_rules(rules)
phase_2_rules = get_enabled_rules(rules)
phase_1_rules = [rule for rule in RULES_PHASE_1 if rule.get_id() in enabled_rules]
phase_2_rules = [rule for rule in RULES_PHASE_2 if rule.get_id() in enabled_rules]

if phase_2_rules:
Rule.llm_service = LLMService()

rule_statistics.collect(
"active", len(phase_1_rules) + len(phase_2_rules), unit="rules"
Expand Down Expand Up @@ -457,16 +459,3 @@ def apply(
RULES_PHASE_2: List[Rule] = [
CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)
]


def get_enabled_rules(rules: List[str]) -> List[Rule]:

if PHASE_1 in rules:
rules.remove(PHASE_1) # signify phase 1 is done
return RULES_PHASE_1

if PHASE_2 in rules:
rules.remove(PHASE_2) # signify phase 2 is done
return RULES_PHASE_2

return []
13 changes: 1 addition & 12 deletions prospector/rules/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,7 @@

from datamodel.advisory import AdvisoryRecord
from datamodel.commit import Commit
from rules.rules import (
ChangesRelevantCode,
ChangesRelevantFiles,
CommitMentionedInReference,
CrossReferencedBug,
CrossReferencedGh,
GHSecurityAdvInMessage,
ReferencesGhIssue,
VulnIdInLinkedIssue,
VulnIdInMessage,
apply_rules,
)
from rules.rules import apply_rules
from util.lsh import get_encoded_minhash

# from datamodel.commit_features import CommitWithFeatures
Expand Down
7 changes: 5 additions & 2 deletions prospector/util/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sys
from dataclasses import MISSING, dataclass
from typing import Optional
from typing import List, Optional

from omegaconf import OmegaConf
from omegaconf.errors import (
Expand Down Expand Up @@ -183,7 +183,6 @@ class LLMServiceConfig:
model_name: str
ai_core_sk: str
use_llm_repository_url: bool
use_llm_rules: bool
temperature: float = 0.0


Expand All @@ -200,6 +199,7 @@ class ConfigSchema:
report: ReportConfig = MISSING
log_level: str = MISSING
git_cache: str = MISSING
enabled_rules: List[str] = MISSING
nvd_token: Optional[str] = None
database: DatabaseConfig = DatabaseConfig(
user="postgres",
Expand Down Expand Up @@ -237,6 +237,7 @@ def __init__(
ping: bool,
log_level: str,
git_cache: str,
enabled_rules: List[str],
ignore_refs: bool,
llm_service: LLMServiceConfig,
):
Expand All @@ -261,6 +262,7 @@ def __init__(
self.ping = ping
self.log_level = log_level
self.git_cache = git_cache
self.enabled_rules = enabled_rules
self.ignore_refs = ignore_refs


Expand Down Expand Up @@ -296,6 +298,7 @@ def get_configuration(argv):
report_filename=args.report_filename or conf.report.name,
ping=args.ping,
git_cache=conf.git_cache,
enabled_rules=conf.enabled_rules,
log_level=args.log_level or conf.log_level,
ignore_refs=args.ignore_refs,
)
Expand Down

0 comments on commit 70d9401

Please sign in to comment.