Skip to content

Commit

Permalink
final touch-up
Browse files Browse the repository at this point in the history
  • Loading branch information
lauraschauer committed Jun 20, 2024
1 parent 59722cd commit f4bf894
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 38 deletions.
10 changes: 9 additions & 1 deletion prospector/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@ def main(argv): # noqa: C901

# If at least one 'use_llm' option is set, then create an LLMService singleton
if any([True for x in dir(config.llm_service) if x.startswith("use_llm")]):
LLMService(config.llm_service)
try:
LLMService(config.llm_service)
except Exception as e:
logger.error(f"Problem with LLMService instantiation: {e}")
console.print(
"LLMService could not be created. Check logs.",
status=MessageStatus.ERROR,
)
return

config.pub_date = (
config.pub_date + "T00:00:00Z" if config.pub_date is not None else ""
Expand Down
39 changes: 22 additions & 17 deletions prospector/llm/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
# "gpt-4-turbo": OpenAI, # currently TBD
# "gpt-4o": OpenAI, # currently TBD
"gemini-1.0-pro": Gemini,
"mistralai--mixtral-8x7b-instruct-v01": Mistral,
"mistral-large": Mistral,
}


THIRD_PARTY_MAPPING = {
"gpt-4": ChatOpenAI,
"gpt-3.5-turbo": ChatOpenAI,
"gemini-pro": ChatVertexAI,
"mistral-large-latest": ChatMistralAI,
"gpt-4": (ChatOpenAI, "OPENAI_API_KEY"),
"gpt-3.5-turbo": (ChatOpenAI, "OPENAI_API_KEY"),
"gemini-pro": (ChatVertexAI, "GOOGLE_API_KEY"),
"mistral-large-latest": (ChatMistralAI, "MISTRAL_API_KEY"),
}


Expand All @@ -55,17 +55,18 @@ def create_model_instance(
LLM: An instance of the specified LLM model.
Raises:
ValueError: if there is a problem with deploymenturl, model_name or AI Core credentials
ValueError: if there is a problem with deployment_url, model_name or AI Core credentials
"""
# LASCHA: correct docstring

def create_sap_provider(
model_name: str, temperature: float, ai_core_sk_filepath: str
):
) -> LLM:

deployment_url = env.get("GPT_35_TURBO_URL", None)
deployment_url = env.get(model_name.upper().replace("-", "_") + "_URL", None)
if deployment_url is None:
raise ValueError(f"Deployment URL for {model_name} is not set.")
raise ValueError(
f"Deployment URL ({model_name.upper().replace('-', '_')}_URL) for {model_name} is not set."
)

model_class = SAP_MAPPING.get(model_name, None)
if model_class is None:
Expand All @@ -85,21 +86,22 @@ def create_sap_provider(

return model

def create_third_party_provider(model_name: str, temperature: float):
model_definition = THIRD_PARTY_MAPPING.get(model_name, None)
def create_third_party_provider(model_name: str, temperature: float) -> LLM:
model_class = THIRD_PARTY_MAPPING.get(model_name, None)[0]

if model_definition is None:
if model_class is None:
raise ValueError(f"Model '{model_name}' is not available.")

model = model_definition._class(
api_key_variable = THIRD_PARTY_MAPPING.get(model_name, None)[1]

model = model_class(
model=model_name,
api_key=model_definition.access_info,
api_key=api_key_variable,
temperature=temperature,
)

return model

# LLM Instantiation
try:
match model_type:
case "sap":
Expand All @@ -120,9 +122,12 @@ def create_third_party_provider(model_name: str, temperature: float):
return model


def get_headers(ai_core_sk_file_path: str):
def get_headers(ai_core_sk_file_path: str) -> Dict[str, str]:
"""Generate the request headers to use SAP AI Core. This method generates the authentication token and returns a Dict with headers.
Params:
ai_core_sk_file_path (str): the path to the file containing the SAP AI Core credentials.
Returns:
The headers object needed to send requests to the SAP AI Core.
"""
Expand Down
27 changes: 17 additions & 10 deletions prospector/llm/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,34 @@
from langchain_core.output_parsers import StrOutputParser

from llm.instantiation import create_model_instance
from llm.prompts import best_guess
from llm.prompts import prompt_best_guess
from log.logger import logger
from util.config_parser import LLMServiceConfig
from util.singleton import Singleton


class LLMService(metaclass=Singleton):
"""A wrapper class for all functions requiring an LLM. This class is also a singleton, as only one model
should be used throughout the program.
"""A wrapper class for all functions requiring an LLM. This class is also a singleton, as only a
single model should be used throughout the program.
"""

config: LLMServiceConfig = None

def __init__(self, config: LLMServiceConfig):
self.config = config
def __init__(self, config: LLMServiceConfig = None):

if self.config is None and config is not None:
self.config = config
elif self.config is None and config is None:
raise ValueError(
"On the first instantiation, a configuration object must be passed."
)

try:
self.model: LLM = create_model_instance(
config.type,
config.model_name,
config.ai_core_sk,
config.temperature,
self.config.type,
self.config.model_name,
self.config.ai_core_sk,
self.config.temperature,
)
except Exception:
raise
Expand All @@ -42,7 +49,7 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
ValueError if advisory information cannot be obtained or there is an error in the model invocation.
"""
try:
chain = best_guess | self.model | StrOutputParser()
chain = prompt_best_guess | self.model | StrOutputParser()

url = chain.invoke(
{
Expand Down
14 changes: 7 additions & 7 deletions prospector/llm/prompts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain.prompts import FewShotPromptTemplate, PromptTemplate

# example output for few-shot prompting
examples_without_num = [
# Get Repository URL, few-shot prompting examples
examples_data = [
{
"cve_description": "Apache Olingo versions 4.0.0 to 4.7.0 provide the AsyncRequestWrapperImpl class which reads a URL from the Location header, and then sends a GET or DELETE request to this URL. It may allow to implement a SSRF attack. If an attacker tricks a client to connect to a malicious server, the server can make the client call any URL including internal resources which are not directly accessible by the attacker.",
"cve_references": "https://www.zerodayinitiative.com/advisories/ZDI-24-196/",
Expand All @@ -20,26 +20,26 @@
]

# Formatter for the few-shot examples without CVE numbers
examples_prompt_without_num = PromptTemplate(
examples_formatted = PromptTemplate(
input_variables=["cve_references", "result"],
template="""<description> {cve_description} </description>
<references> {cve_references}</references>
<output> {result} </output>""",
)

best_guess = FewShotPromptTemplate(
prompt_best_guess = FewShotPromptTemplate(
prefix="""You will be provided with the ID, description and references of a vulnerability advisory (CVE). Return nothing but the URL of the repository the given CVE is concerned with.'.
Here are a few examples delimited with XML tags:""",
examples=examples_without_num,
example_prompt=examples_prompt_without_num,
examples=examples_data,
example_prompt=examples_formatted,
suffix="""Here is the CVE information:
<description> {description} </description>
<references> {references} </references>
If you cannot find the URL, return your best guess of what the repository URL could be. Use any hints (eg. the mention of GitHub or GitLab) in the CVE description and references. Return nothing but the URL.
""",
input_variables=["description", "references"],
metadata={"name": "best_guess"},
metadata={"name": "prompt_best_guess"},
)
23 changes: 20 additions & 3 deletions prospector/llm/test_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import pytest
from langchain_core.language_models.llms import LLM
from langchain_google_vertexai import ChatVertexAI
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from requests_cache import Optional

from llm.llm_service import LLMService # this is a singleton
Expand Down Expand Up @@ -61,12 +64,26 @@ def test_sap_gemini_instantiation(self):
assert isinstance(llm_service.model, Gemini)

def test_sap_mistral_instantiation(self):
config = Config(
"sap", "mistralai--mixtral-8x7b-instruct-v01", 0.0, "example.json"
)
config = Config("sap", "mistral-large", 0.0, "example.json")
llm_service = LLMService(config)
assert isinstance(llm_service.model, Mistral)

def test_gpt_instantiation(self):
config = Config("third_party", "gpt-4", 0.0, "example.json")
llm_service = LLMService(config)
assert isinstance(llm_service.model, ChatOpenAI)

# Disabled: Google throws an error on creation when no account is found
# def test_gemini_instantiation(self):
# config = Config("third_party", "gemini-pro", 0.0, "example.json")
# llm_service = LLMService(config)
# assert isinstance(llm_service.model, ChatVertexAI)

def test_mistral_instantiation(self):
config = Config("third_party", "mistral-large-latest", 0.0, "example.json")
llm_service = LLMService(config)
assert isinstance(llm_service.model, ChatMistralAI)

def test_singleton_instance_creation(self):
"""A second instantiation should return the existing instance."""
config = Config("sap", "gpt-4", 0.0, "example.json")
Expand Down

0 comments on commit f4bf894

Please sign in to comment.