diff --git a/prospector/llm/llm_service.py b/prospector/llm/llm_service.py
index cbc4e69e4..685bf79a0 100644
--- a/prospector/llm/llm_service.py
+++ b/prospector/llm/llm_service.py
@@ -3,9 +3,11 @@
 import validators
 from langchain_core.language_models.llms import LLM
 from langchain_core.output_parsers import StrOutputParser
+from requests import HTTPError
 
 from llm.instantiation import create_model_instance
-from llm.prompts import prompt_best_guess
+from llm.prompts.classify_commit import zero_shot as cc_zero_shot
+from llm.prompts.get_repository_url import prompt_best_guess
 from log.logger import logger
 from util.config_parser import LLMServiceConfig
 from util.singleton import Singleton
@@ -74,3 +76,62 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
             raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
 
         return url
+
+    def classify_commit(
+        self, diff: str, repository_name: str, commit_message: str
+    ) -> bool:
+        """Ask an LLM whether a commit is security relevant. The response is
+        parsed into a boolean.
+
+        Args:
+            diff (str): The diff of the commit.
+            repository_name (str): The name of the repository the commit belongs to.
+            commit_message (str): The commit message.
+
+        Returns:
+            True if the commit is deemed security relevant, False if not (or
+            if the diff was too large for the model to process).
+
+        Raises:
+            RuntimeError if the model invocation fails or the response is not valid.
+        """
+        try:
+            chain = cc_zero_shot | self.model | StrOutputParser()
+
+            is_relevant = chain.invoke(
+                {
+                    "diff": diff,
+                    "repository_name": repository_name,
+                    "commit_message": commit_message,
+                }
+            )
+            logger.info(f"LLM returned is_relevant={is_relevant}")
+
+        except HTTPError as e:
+            # If the diff is too big, a 400 error is returned -> silently
+            # ignore it by returning False for this commit.
+            status_code = e.response.status_code
+            if status_code == 400:
+                return False
+            raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
+        except Exception as e:
+            raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
+
+        # Models wrap the verdict in slightly different ways; accept the known
+        # variants and fail loudly on anything else.
+        if is_relevant in [
+            "True",
+            "ANSWER:True",
+            "```ANSWER:True```",
+        ]:
+            return True
+        elif is_relevant in [
+            "False",
+            "ANSWER:False",
+            "```ANSWER:False```",
+        ]:
+            return False
+        else:
+            raise RuntimeError(
+                f"The model returned an invalid response: {is_relevant}"
+            )
diff --git a/prospector/llm/prompts/classify_commit.py b/prospector/llm/prompts/classify_commit.py
new file mode 100644
index 000000000..80a99afe9
--- /dev/null
+++ b/prospector/llm/prompts/classify_commit.py
@@ -0,0 +1,16 @@
+from langchain.prompts import PromptTemplate
+
+zero_shot = PromptTemplate.from_template(
+    """Is the following commit security relevant or not?
+Please provide the output as a boolean value, either True or False.
+If it is security relevant, just answer True; otherwise answer False. Do not return anything else.
+
+To provide you with some context, the name of the repository is: {repository_name}, and the
+commit message is: {commit_message}.
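Note on the chain above: `zero_shot | model | StrOutputParser()` is plain LCEL composition, so the prompt can be exercised without a live backend. The sketch below is illustrative only and not part of the patch: it assumes the `prospector/` package root is on `sys.path`, and it uses `langchain_community`'s `FakeListLLM` as a stand-in for the model that `LLMService` would normally instantiate from its configuration.

```python
from langchain_community.llms.fake import FakeListLLM
from langchain_core.output_parsers import StrOutputParser

from llm.prompts.classify_commit import zero_shot

# Canned model that always answers "True" -- a stand-in for the real backend.
model = FakeListLLM(responses=["True"])

# Same composition as LLMService.classify_commit: prompt -> model -> string.
chain = zero_shot | model | StrOutputParser()

answer = chain.invoke(
    {
        "diff": "-eval(user_input)\n+ast.literal_eval(user_input)",
        "repository_name": "example/repo",
        "commit_message": "Avoid eval on untrusted input",
    }
)
assert answer == "True"
```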
+
+Finally, here is the diff of the commit:
+{diff}\n
+
+
+Your answer:\n"""
+)
diff --git a/prospector/llm/prompts.py b/prospector/llm/prompts/get_repository_url.py
similarity index 100%
rename from prospector/llm/prompts.py
rename to prospector/llm/prompts/get_repository_url.py
diff --git a/prospector/rules/rules.py b/prospector/rules/rules.py
index 80496c812..2ba5a16e9 100644
--- a/prospector/rules/rules.py
+++ b/prospector/rules/rules.py
@@ -413,6 +413,18 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
         return False
 
 
+class CommitIsSecurityRelevant(Rule):
+    """Matches commits that are deemed security relevant by the commit classification service."""
+
+    def apply(
+        self,
+        candidate: Commit,
+    ) -> bool:
+        return LLMService().classify_commit(
+            candidate.diff, candidate.repository, candidate.message
+        )
+
+
 RULES_PHASE_1: List[Rule] = [
     VulnIdInMessage("VULN_ID_IN_MESSAGE", 64),
     # CommitMentionedInAdv("COMMIT_IN_ADVISORY", 64),
@@ -433,4 +445,6 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
     CommitHasTwins("COMMIT_HAS_TWINS", 2),
 ]
 
-RULES_PHASE_2: List[Rule] = []
+RULES_PHASE_2: List[Rule] = [
+    CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)
+]
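Unlike the phase-1 rules, `CommitIsSecurityRelevant.apply` takes no `advisory_record` and defers entirely to the singleton `LLMService`, so the service must have been configured before phase 2 runs. A hypothetical driver loop, for illustration only (`surviving_candidates` is an assumed name, and the snippet presumes `LLMService` was initialised during pipeline setup):

```python
from rules.rules import CommitIsSecurityRelevant

# Weight 32 mirrors the RULES_PHASE_2 registration above.
rule = CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)

for candidate in surviving_candidates:  # hypothetical: commits ranked in phase 1
    # Each call sends one prompt to the model; oversized diffs come back False.
    if rule.apply(candidate):
        print(f"{candidate.commit_id}: flagged as security relevant")
```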
diff --git a/prospector/rules/rules_test.py b/prospector/rules/rules_test.py
index 230c351e0..93c246ef4 100644
--- a/prospector/rules/rules_test.py
+++ b/prospector/rules/rules_test.py
@@ -89,7 +89,9 @@ def candidates():
         changed_files={
             "core/src/main/java/org/apache/cxf/workqueue/AutomaticWorkQueueImpl.java"
         },
-        minhash=get_encoded_minhash(get_msg("Insecure deserialization", 50)),
+        minhash=get_encoded_minhash(
+            get_msg("Insecure deserialization", 50)
+        ),
     ),
     # TODO: Not matched by existing tests: GHSecurityAdvInMessage, ReferencesBug, ChangesRelevantCode, TwinMentionedInAdv, VulnIdInLinkedIssue, SecurityKeywordInLinkedGhIssue, SecurityKeywordInLinkedBug, CrossReferencedBug, CrossReferencedGh, CommitHasTwins, ChangesRelevantFiles, CommitMentionedInAdv, RelevantWordsInMessage
 ]
@@ -109,7 +111,9 @@ def advisory_record():
 )
 
 
-def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: AdvisoryRecord):
+def test_apply_phase_1_rules(
+    candidates: List[Commit], advisory_record: AdvisoryRecord
+):
     annotated_candidates = apply_rules(
         candidates, advisory_record, enabled_rules=enabled_rules_from_config
     )
@@ -117,7 +121,9 @@ def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: Advisory
 
     # Repo 5: Should match: AdvKeywordsInFiles, SecurityKeywordsInMsg, CommitMentionedInReference
     assert len(annotated_candidates[0].matched_rules) == 3
-    matched_rules_names = [item["id"] for item in annotated_candidates[0].matched_rules]
+    matched_rules_names = [
+        item["id"] for item in annotated_candidates[0].matched_rules
+    ]
     assert "ADV_KEYWORDS_IN_FILES" in matched_rules_names
     assert "COMMIT_IN_REFERENCE" in matched_rules_names
     assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names
@@ -125,21 +131,27 @@ def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: Advisory
 
     # Repo 1: Should match: VulnIdInMessage, ReferencesGhIssue
     assert len(annotated_candidates[1].matched_rules) == 2
-    matched_rules_names = [item["id"] for item in annotated_candidates[1].matched_rules]
+    matched_rules_names = [
+        item["id"] for item in annotated_candidates[1].matched_rules
+    ]
     assert "VULN_ID_IN_MESSAGE" in matched_rules_names
     assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names
 
     # Repo 3: Should match: VulnIdInMessage, ReferencesGhIssue
     assert len(annotated_candidates[2].matched_rules) == 2
-    matched_rules_names = [item["id"] for item in annotated_candidates[2].matched_rules]
+    matched_rules_names = [
+        item["id"] for item in annotated_candidates[2].matched_rules
+    ]
     assert "VULN_ID_IN_MESSAGE" in matched_rules_names
     assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names
 
     # Repo 4: Should match: SecurityKeywordsInMsg
     assert len(annotated_candidates[3].matched_rules) == 1
-    matched_rules_names = [item["id"] for item in annotated_candidates[3].matched_rules]
+    matched_rules_names = [
+        item["id"] for item in annotated_candidates[3].matched_rules
+    ]
     assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names
 
     # Repo 2: Matches nothing
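The reworked tests above still only exercise phase 1. Covering the new phase-2 rule offline would mean stubbing the model call; one possible sketch follows. It is a hypothetical test, not part of the patch: it reuses the `candidates` fixture from `rules_test.py` and assumes `rules.py` imports `LLMService` at module scope, as the hunk above implies.

```python
from rules.rules import CommitIsSecurityRelevant


class _StubLLMService:
    """Stands in for LLMService so no model backend is needed."""

    def classify_commit(self, diff, repository_name, commit_message):
        return True


def test_commit_is_security_relevant(candidates, monkeypatch):
    # Replace the name rules.py resolves at call time with the stub above.
    monkeypatch.setattr("rules.rules.LLMService", lambda: _StubLLMService())

    rule = CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)
    assert rule.apply(candidates[0]) is True
```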