From 70d94012b6372ba31d0d03ea7af1f7ad89e52f07 Mon Sep 17 00:00:00 2001 From: I748376 Date: Mon, 24 Jun 2024 09:52:54 +0000 Subject: [PATCH] implements enabled_rules parameter in config.yaml --- prospector/cli/main.py | 13 ++++--------- prospector/config-sample.yaml | 21 ++++++++++++++++++++ prospector/core/prospector.py | 16 +++++----------- prospector/rules/rules.py | 33 +++++++++++--------------------- prospector/rules/rules_test.py | 13 +------------ prospector/util/config_parser.py | 7 +++++-- 6 files changed, 47 insertions(+), 56 deletions(-) diff --git a/prospector/cli/main.py b/prospector/cli/main.py index 18c1b94a9..2cdeac7d4 100644 --- a/prospector/cli/main.py +++ b/prospector/cli/main.py @@ -58,10 +58,7 @@ def main(argv): # noqa: C901 # Whether to use the LLMService if config.llm_service: - if ( - not config.repository - and not config.llm_service.use_llm_repository_url - ): + if not config.repository and not config.llm_service.use_llm_repository_url: logger.error( "Repository URL was neither specified nor allowed to obtain with LLM support. One must be set." ) @@ -71,7 +68,7 @@ def main(argv): # noqa: C901 ) return - # Create the LLMService singleton for later use + # Create the LLMService Singleton for later use try: LLMService(config.llm_service) except Exception as e: @@ -83,9 +80,7 @@ def main(argv): # noqa: C901 return config.pub_date = ( - config.pub_date + "T00:00:00Z" - if config.pub_date is not None - else "" + config.pub_date + "T00:00:00Z" if config.pub_date is not None else "" ) logger.debug("Using the following configuration:") @@ -109,7 +104,7 @@ def main(argv): # noqa: C901 limit_candidates=config.max_candidates, # ignore_adv_refs=config.ignore_refs, use_llm_repository_url=config.llm_service.use_llm_repository_url, - use_llm_rules=config.llm_service.use_llm_rules, + enabled_rules=config.enabled_rules, ) if config.preprocess_only: diff --git a/prospector/config-sample.yaml b/prospector/config-sample.yaml index 4faa61c8a..f89f4b369 100644 --- a/prospector/config-sample.yaml +++ b/prospector/config-sample.yaml @@ -36,6 +36,27 @@ redis_url: redis://redis:6379/0 # use_llm_repository_url: False # whether to use LLM's to obtain the repository URL +enabled_rules: + # Phase 1 Rules + - VULN_ID_IN_MESSAGE + - XREF_BUG + - XREF_GH + - COMMIT_IN_REFERENCE + - VULN_ID_IN_LINKED_ISSUE + - CHANGES_RELEVANT_FILES + - CHANGES_RELEVANT_CODE + - RELEVANT_WORDS_IN_MESSAGE + - ADV_KEYWORDS_IN_FILES + - ADV_KEYWORDS_IN_MSG + - SEC_KEYWORDS_IN_MESSAGE + - SEC_KEYWORDS_IN_LINKED_GH + - SEC_KEYWORDS_IN_LINKED_BUG + - GITHUB_ISSUE_IN_MESSAGE + - BUG_IN_MESSAGE + - COMMIT_HAS_TWINS + # Phase 2 Rules (llm_service required!): + # - COMMIT_IS_SECURITY_RELEVANT + # Report file format: "html", "json", "console" or "all" # and the file name report: diff --git a/prospector/core/prospector.py b/prospector/core/prospector.py index 29fe1655a..b92e53b7d 100644 --- a/prospector/core/prospector.py +++ b/prospector/core/prospector.py @@ -20,7 +20,7 @@ from git.version_to_tag import get_possible_tags from llm.llm_service import LLMService from log.logger import get_level, logger, pretty_log -from rules.rules import PHASE_1, PHASE_2, apply_rules +from rules.rules import apply_rules from stats.execution import ( Counter, ExecutionTimer, @@ -66,11 +66,10 @@ def prospector( # noqa: C901 use_backend: str = USE_BACKEND_ALWAYS, git_cache: str = "/tmp/git_cache", limit_candidates: int = MAX_CANDIDATES, - rules: List[str] = [PHASE_1], + enabled_rules: List[str] = [], tag_commits: bool = True, silent: bool = False, use_llm_repository_url: bool = False, - use_llm_rules: bool = False, ) -> Tuple[List[Commit], AdvisoryRecord] | Tuple[int, int]: if silent: logger.disabled = True @@ -233,7 +232,7 @@ def prospector( # noqa: C901 logger.warning("Preprocessed commits are not being sent to backend") ranked_candidates = evaluate_commits( - preprocessed_commits, advisory_record, rules, use_llm_rules + preprocessed_commits, advisory_record, enabled_rules ) # ConsoleWriter.print("Commit ranking and aggregation...") @@ -272,8 +271,7 @@ def filter(commits: Dict[str, RawCommit]) -> Dict[str, RawCommit]: def evaluate_commits( commits: List[Commit], advisory: AdvisoryRecord, - rules: List[str], - use_llm_rules: bool, + enabled_rules: List[str], ) -> List[Commit]: """This function applies rule phases. Each phase is associated with a set of rules, for example: - Phase 1: NLP Rules @@ -283,7 +281,6 @@ def evaluate_commits( commits: the list of candidate commits that rules should be applied to advisory: the object containing all information about the advisory rules: a (sub)set of rules to run - use_llm_rules: indication whether the user wishes to use the LLM supported rules Returns: a list of commits ranked according to their relevance score Raises: @@ -291,10 +288,7 @@ def evaluate_commits( """ with ExecutionTimer(core_statistics.sub_collection("candidates analysis")): with ConsoleWriter("Candidate analysis") as _: - if use_llm_rules: - rules.append(PHASE_2) - - ranked_commits = apply_rules(commits, advisory, rules=rules) + ranked_commits = apply_rules(commits, advisory, enabled_rules=enabled_rules) return ranked_commits diff --git a/prospector/rules/rules.py b/prospector/rules/rules.py index 93016123b..0c083e31c 100644 --- a/prospector/rules/rules.py +++ b/prospector/rules/rules.py @@ -11,10 +11,7 @@ from stats.execution import Counter, execution_statistics from util.lsh import build_lsh_index, decode_minhash -PHASE_1 = "phase_1" -PHASE_2 = "phase_2" - -MAX_COMMITS_FOR_LLM_RULES = 3 +MAX_COMMITS_FOR_LLM_RULES = 1 rule_statistics = execution_statistics.sub_collection("rules") @@ -46,19 +43,24 @@ def as_dict(self): def get_rule_as_tuple(self) -> Tuple[str, str, int]: return (self.id, self.message, self.relevance) + def get_id(self): + return self.id + def apply_rules( candidates: List[Commit], advisory_record: AdvisoryRecord, - rules: List[str] = [PHASE_1], + enabled_rules: List[str] = [], ) -> List[Commit]: + """Applies the selected set of rules and returns the ranked list of commits (uses apply_ranking()).""" Rule.lsh_index = build_lsh_index() - if PHASE_2 in rules: - Rule.llm_service = LLMService() - phase_1_rules = get_enabled_rules(rules) - phase_2_rules = get_enabled_rules(rules) + phase_1_rules = [rule for rule in RULES_PHASE_1 if rule.get_id() in enabled_rules] + phase_2_rules = [rule for rule in RULES_PHASE_2 if rule.get_id() in enabled_rules] + + if phase_2_rules: + Rule.llm_service = LLMService() rule_statistics.collect( "active", len(phase_1_rules) + len(phase_2_rules), unit="rules" @@ -457,16 +459,3 @@ def apply( RULES_PHASE_2: List[Rule] = [ CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32) ] - - -def get_enabled_rules(rules: List[str]) -> List[Rule]: - - if PHASE_1 in rules: - rules.remove(PHASE_1) # signify phase 1 is done - return RULES_PHASE_1 - - if PHASE_2 in rules: - rules.remove(PHASE_2) # signify phase 2 is done - return RULES_PHASE_2 - - return [] diff --git a/prospector/rules/rules_test.py b/prospector/rules/rules_test.py index cfeccfb90..5b5abf730 100644 --- a/prospector/rules/rules_test.py +++ b/prospector/rules/rules_test.py @@ -5,18 +5,7 @@ from datamodel.advisory import AdvisoryRecord from datamodel.commit import Commit -from rules.rules import ( - ChangesRelevantCode, - ChangesRelevantFiles, - CommitMentionedInReference, - CrossReferencedBug, - CrossReferencedGh, - GHSecurityAdvInMessage, - ReferencesGhIssue, - VulnIdInLinkedIssue, - VulnIdInMessage, - apply_rules, -) +from rules.rules import apply_rules from util.lsh import get_encoded_minhash # from datamodel.commit_features import CommitWithFeatures diff --git a/prospector/util/config_parser.py b/prospector/util/config_parser.py index 2eb6205ab..7bc83a8e0 100644 --- a/prospector/util/config_parser.py +++ b/prospector/util/config_parser.py @@ -2,7 +2,7 @@ import os import sys from dataclasses import MISSING, dataclass -from typing import Optional +from typing import List, Optional from omegaconf import OmegaConf from omegaconf.errors import ( @@ -183,7 +183,6 @@ class LLMServiceConfig: model_name: str ai_core_sk: str use_llm_repository_url: bool - use_llm_rules: bool temperature: float = 0.0 @@ -200,6 +199,7 @@ class ConfigSchema: report: ReportConfig = MISSING log_level: str = MISSING git_cache: str = MISSING + enabled_rules: List[str] = MISSING nvd_token: Optional[str] = None database: DatabaseConfig = DatabaseConfig( user="postgres", @@ -237,6 +237,7 @@ def __init__( ping: bool, log_level: str, git_cache: str, + enabled_rules: List[str], ignore_refs: bool, llm_service: LLMServiceConfig, ): @@ -261,6 +262,7 @@ def __init__( self.ping = ping self.log_level = log_level self.git_cache = git_cache + self.enabled_rules = enabled_rules self.ignore_refs = ignore_refs @@ -296,6 +298,7 @@ def get_configuration(argv): report_filename=args.report_filename or conf.report.name, ping=args.ping, git_cache=conf.git_cache, + enabled_rules=conf.enabled_rules, log_level=args.log_level or conf.log_level, ignore_refs=args.ignore_refs, )