From f4bf8949f446d20b79a5c83e299ce98a2f6c292e Mon Sep 17 00:00:00 2001
From: I748376
Date: Thu, 20 Jun 2024 12:55:19 +0000
Subject: [PATCH] final touch-up

---
 prospector/cli/main.py             | 10 +++++++-
 prospector/llm/instantiation.py    | 39 +++++++++++++++++-------------
 prospector/llm/llm_service.py      | 27 +++++++++++++--------
 prospector/llm/prompts.py          | 14 +++++------
 prospector/llm/test_llm_service.py | 23 +++++++++++++++---
 5 files changed, 75 insertions(+), 38 deletions(-)

diff --git a/prospector/cli/main.py b/prospector/cli/main.py
index d281cbe5c..fd1c8b138 100644
--- a/prospector/cli/main.py
+++ b/prospector/cli/main.py
@@ -70,7 +70,15 @@ def main(argv):  # noqa: C901

     # If at least one 'use_llm' option is set, then create an LLMService singleton
     if any([True for x in dir(config.llm_service) if x.startswith("use_llm")]):
-        LLMService(config.llm_service)
+        try:
+            LLMService(config.llm_service)
+        except Exception as e:
+            logger.error(f"Problem with LLMService instantiation: {e}")
+            console.print(
+                "LLMService could not be created. Check logs.",
+                status=MessageStatus.ERROR,
+            )
+            return

     config.pub_date = (
         config.pub_date + "T00:00:00Z" if config.pub_date is not None else ""
diff --git a/prospector/llm/instantiation.py b/prospector/llm/instantiation.py
index 6653b6e9d..1794f9cb8 100644
--- a/prospector/llm/instantiation.py
+++ b/prospector/llm/instantiation.py
@@ -24,15 +24,15 @@
     # "gpt-4-turbo": OpenAI,  # currently TBD
     # "gpt-4o": OpenAI,  # currently TBD
     "gemini-1.0-pro": Gemini,
-    "mistralai--mixtral-8x7b-instruct-v01": Mistral,
+    "mistral-large": Mistral,
 }

 THIRD_PARTY_MAPPING = {
-    "gpt-4": ChatOpenAI,
-    "gpt-3.5-turbo": ChatOpenAI,
-    "gemini-pro": ChatVertexAI,
-    "mistral-large-latest": ChatMistralAI,
+    "gpt-4": (ChatOpenAI, "OPENAI_API_KEY"),
+    "gpt-3.5-turbo": (ChatOpenAI, "OPENAI_API_KEY"),
+    "gemini-pro": (ChatVertexAI, "GOOGLE_API_KEY"),
+    "mistral-large-latest": (ChatMistralAI, "MISTRAL_API_KEY"),
 }


@@ -55,17 +55,18 @@ def create_model_instance(
         LLM: An instance of the specified LLM model.

     Raises:
-        ValueError: if there is a problem with deploymenturl, model_name or AI Core credentials
+        ValueError: if there is a problem with deployment_url, model_name or AI Core credentials
     """
-    # LASCHA: correct docstring

     def create_sap_provider(
         model_name: str, temperature: float, ai_core_sk_filepath: str
-    ):
+    ) -> LLM:

-        deployment_url = env.get("GPT_35_TURBO_URL", None)
+        deployment_url = env.get(model_name.upper().replace("-", "_") + "_URL", None)
         if deployment_url is None:
-            raise ValueError(f"Deployment URL for {model_name} is not set.")
+            raise ValueError(
+                f"Deployment URL ({model_name.upper().replace('-', '_')}_URL) for {model_name} is not set."
+            )

         model_class = SAP_MAPPING.get(model_name, None)
         if model_class is None:
@@ -85,21 +86,22 @@ def create_sap_provider(

         return model

-    def create_third_party_provider(model_name: str, temperature: float):
-        model_definition = THIRD_PARTY_MAPPING.get(model_name, None)
+    def create_third_party_provider(model_name: str, temperature: float) -> LLM:
+        model_class = THIRD_PARTY_MAPPING.get(model_name, (None, None))[0]

-        if model_definition is None:
+        if model_class is None:
             raise ValueError(f"Model '{model_name}' is not available.")

-        model = model_definition._class(
+        api_key_variable = THIRD_PARTY_MAPPING.get(model_name, (None, None))[1]
+
+        model = model_class(
             model=model_name,
-            api_key=model_definition.access_info,
+            api_key=api_key_variable,
             temperature=temperature,
         )

         return model

-    # LLM Instantiation
     try:
         match model_type:
             case "sap":
@@ -120,9 +122,12 @@ def create_third_party_provider(model_name: str, temperature: float):
     return model


-def get_headers(ai_core_sk_file_path: str):
+def get_headers(ai_core_sk_file_path: str) -> Dict[str, str]:
     """Generate the request headers to use SAP AI Core. This method generates the authentication token and returns a Dict with headers.

+    Params:
+        ai_core_sk_file_path (str): the path to the file containing the SAP AI Core credentials.
+
     Returns:
         The headers object needed to send requests to the SAP AI Core.
     """
diff --git a/prospector/llm/llm_service.py b/prospector/llm/llm_service.py
index 5ee631427..3804a1012 100644
--- a/prospector/llm/llm_service.py
+++ b/prospector/llm/llm_service.py
@@ -3,27 +3,34 @@
 from langchain_core.output_parsers import StrOutputParser

 from llm.instantiation import create_model_instance
-from llm.prompts import best_guess
+from llm.prompts import prompt_best_guess
 from log.logger import logger
 from util.config_parser import LLMServiceConfig
 from util.singleton import Singleton


 class LLMService(metaclass=Singleton):
-    """A wrapper class for all functions requiring an LLM. This class is also a singleton, as only one model
-    should be used throughout the program.
+    """A wrapper class for all functions requiring an LLM. This class is also a singleton, as only a
+    single model should be used throughout the program.
     """

     config: LLMServiceConfig = None

-    def __init__(self, config: LLMServiceConfig):
-        self.config = config
+    def __init__(self, config: LLMServiceConfig = None):
+
+        if self.config is None and config is not None:
+            self.config = config
+        elif self.config is None and config is None:
+            raise ValueError(
+                "On the first instantiation, a configuration object must be passed."
+            )
+
         try:
             self.model: LLM = create_model_instance(
-                config.type,
-                config.model_name,
-                config.ai_core_sk,
-                config.temperature,
+                self.config.type,
+                self.config.model_name,
+                self.config.ai_core_sk,
+                self.config.temperature,
             )
         except Exception:
             raise
@@ -42,7 +49,7 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
             ValueError if advisory information cannot be obtained or there is an error in the model invocation.
""" try: - chain = best_guess | self.model | StrOutputParser() + chain = prompt_best_guess | self.model | StrOutputParser() url = chain.invoke( { diff --git a/prospector/llm/prompts.py b/prospector/llm/prompts.py index 57fd2444a..f9c083599 100644 --- a/prospector/llm/prompts.py +++ b/prospector/llm/prompts.py @@ -1,7 +1,7 @@ from langchain.prompts import FewShotPromptTemplate, PromptTemplate -# example output for few-shot prompting -examples_without_num = [ +# Get Repository URL, few-shot prompting examples +examples_data = [ { "cve_description": "Apache Olingo versions 4.0.0 to 4.7.0 provide the AsyncRequestWrapperImpl class which reads a URL from the Location header, and then sends a GET or DELETE request to this URL. It may allow to implement a SSRF attack. If an attacker tricks a client to connect to a malicious server, the server can make the client call any URL including internal resources which are not directly accessible by the attacker.", "cve_references": "https://www.zerodayinitiative.com/advisories/ZDI-24-196/", @@ -20,7 +20,7 @@ ] # Formatter for the few-shot examples without CVE numbers -examples_prompt_without_num = PromptTemplate( +examples_formatted = PromptTemplate( input_variables=["cve_references", "result"], template=""" {cve_description} {cve_references} @@ -28,12 +28,12 @@ {result} """, ) -best_guess = FewShotPromptTemplate( +prompt_best_guess = FewShotPromptTemplate( prefix="""You will be provided with the ID, description and references of a vulnerability advisory (CVE). Return nothing but the URL of the repository the given CVE is concerned with.'. Here are a few examples delimited with XML tags:""", - examples=examples_without_num, - example_prompt=examples_prompt_without_num, + examples=examples_data, + example_prompt=examples_formatted, suffix="""Here is the CVE information: {description} {references} @@ -41,5 +41,5 @@ If you cannot find the URL, return your best guess of what the repository URL could be. Use any hints (eg. the mention of GitHub or GitLab) in the CVE description and references. Return nothing but the URL. 
""", input_variables=["description", "references"], - metadata={"name": "best_guess"}, + metadata={"name": "prompt_best_guess"}, ) diff --git a/prospector/llm/test_llm_service.py b/prospector/llm/test_llm_service.py index 2030e77aa..fdd82f208 100644 --- a/prospector/llm/test_llm_service.py +++ b/prospector/llm/test_llm_service.py @@ -2,6 +2,9 @@ import pytest from langchain_core.language_models.llms import LLM +from langchain_google_vertexai import ChatVertexAI +from langchain_mistralai import ChatMistralAI +from langchain_openai import ChatOpenAI from requests_cache import Optional from llm.llm_service import LLMService # this is a singleton @@ -61,12 +64,26 @@ def test_sap_gemini_instantiation(self): assert isinstance(llm_service.model, Gemini) def test_sap_mistral_instantiation(self): - config = Config( - "sap", "mistralai--mixtral-8x7b-instruct-v01", 0.0, "example.json" - ) + config = Config("sap", "mistral-large", 0.0, "example.json") llm_service = LLMService(config) assert isinstance(llm_service.model, Mistral) + def test_gpt_instantiation(self): + config = Config("third_party", "gpt-4", 0.0, "example.json") + llm_service = LLMService(config) + assert isinstance(llm_service.model, ChatOpenAI) + + # Google throws an error on creation, when no account is found + # def test_gemini_instantiation(self): + # config = Config("third_party", "gemini-pro", 0.0, "example.json") + # llm_service = LLMService(config) + # assert isinstance(llm_service.model, ChatVertexAI) + + def test_mistral_instantiation(self): + config = Config("third_party", "mistral-large-latest", 0.0, "example.json") + llm_service = LLMService(config) + assert isinstance(llm_service.model, ChatMistralAI) + def test_singleton_instance_creation(self): """A second instantiation should return the exisiting instance.""" config = Config("sap", "gpt-4", 0.0, "example.json")