diff --git a/prospector/README.md b/prospector/README.md
index 1d0cf72c5..968e9be13 100644
--- a/prospector/README.md
+++ b/prospector/README.md
@@ -57,24 +57,31 @@ To quickly set up Prospector, follow these steps. This will run Prospector in it
### 🤖 LLM Support
-To use Prospector with LLM support, set the `use_llm_<...>` parameters in `config.yaml`. Additionally, you must specify required parameters for API access to the LLM. These parameters can vary depending on your choice of provider, please follow what fits your needs:
+To use Prospector with LLM support, set the required parameters for API access to the LLM in *config.yaml*. These parameters vary depending on your choice of provider; follow the instructions that fit your needs (drop-downs below). If you do not want to use LLM support, keep the `llm_service` block in your *config.yaml* file commented out.
Use SAP AI CORE SDK
-You will need the following parameters in `config.yaml`:
+You will need the following parameters in *config.yaml*:
```yaml
llm_service:
type: sap
model_name:
+ temperature: 0.0
+ ai_core_sk:
```
-`<model_name>` refers to the model names available in the Generative AI Hub in SAP AI Core. [Here](https://github.tools.sap/I343697/generative-ai-hub-readme#1-supported-models) you can find an overview of available models.
+`<model_name>` refers to the model names available in the Generative AI Hub in SAP AI Core. You can find an overview of available models on the Generative AI Hub GitHub page.
In `.env`, you must set the deployment URL as an environment variable following this naming convention:
```yaml
-<MODEL_NAME>_URL
+<MODEL_NAME>_URL # model name in capitals, with "-" changed to "_"
```
+For example, for gpt-4's deployment URL, set an environment variable called `GPT_4_URL`.
+
+The `temperature` parameter is optional and defaults to 0.0; set it explicitly if you want a different value.
+
+You also need to point the `ai_core_sk` parameter to the file containing your SAP AI Core credentials (secret key).
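+
+For illustration, a complete `llm_service` block for SAP AI Core might then look like this (the model name and key file path are placeholders; the matching `GPT_4_URL` variable must be set in `.env` as described above):
+
+```yaml
+llm_service:
+  type: sap
+  model_name: gpt-4
+  temperature: 0.0
+  ai_core_sk: /path/to/your/ai-core-service-key.json
+```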
@@ -82,11 +89,12 @@ In `.env`, you must set the deployment URL as an environment variable following
Implemented third party providers are **OpenAI**, **Google** and **Mistral**.
-1. You will need the following parameters in `config.yaml`:
+1. You will need the following parameters in *config.yaml*:
```yaml
llm_service:
type: third_party
model_name:
+ temperature: 0.0
```
`<model_name>` refers to the model names available, for example `gpt-4o` for OpenAI. You can find a list of available models here:
@@ -94,10 +102,19 @@ Implemented third party providers are **OpenAI**, **Google** and **Mistral**.
2. [Google](https://ai.google.dev/gemini-api/docs/models/gemini)
3. [Mistral](https://docs.mistral.ai/getting-started/models/)
+   The `temperature` parameter is optional and defaults to 0.0; set it explicitly if you want a different value.
+
2. Make sure to add your provider's API key to your `.env` file as `[OPENAI|GOOGLE|MISTRAL]_API_KEY`.
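+
+For example, when using an OpenAI model, the corresponding `.env` entry would be (the value is a placeholder):
+
+```yaml
+OPENAI_API_KEY=<your OpenAI API key>
+```
+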
+#### Fine-grained control over LLM usage
+
+The `use_llm_<...>` parameters in *config.yaml* give you fine-grained control over where Prospector uses the LLM. Each `use_llm_<...>` parameter enables or disables LLM support for one specific step:
+
+- **`use_llm_repository_url`**: Choose whether the LLM should be used to obtain the repository URL. With this option enabled, you can omit the `--repository` command-line flag and run Prospector with `./run_prospector.sh CVE-2020-1925`, as shown in the example below.
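+
+For example (the model name is illustrative), enabling this together with a third-party model in *config.yaml* could look like:
+
+```yaml
+llm_service:
+  type: third_party
+  model_name: gpt-4o
+  temperature: 0.0
+  use_llm_repository_url: True
+```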
+
+
## 👩💻 Development Setup
Following these steps allows you to run Prospector's components individually: [Backend database and worker containers](#starting-the-backend-database-and-the-job-workers), [RESTful Server](#starting-the-restful-server) for API endpoints, [Prospector CLI](#running-the-cli-version) and [Tests](#testing).
@@ -125,7 +142,7 @@ Afterwards, you will just have to set the environment variables using the `.env`
set -a; source .env; set +a
```
-You can configure prospector from CLI or from the `config.yaml` file. The (recommended) API Keys for Github and the NVD can be configured from the `.env` file (which must then be sourced with `set -a; source .env; set +a`)
+You can configure Prospector from the CLI or from the *config.yaml* file. The (recommended) API keys for GitHub and the NVD can be configured in the `.env` file (which must then be sourced with `set -a; source .env; set +a`).
If at any time you wish to use a different version of the python interpreter, beware that the `requirements.txt` file contains the exact versioning for `python 3.10.6`.
diff --git a/prospector/cli/main.py b/prospector/cli/main.py
index 95b5ef723..696a51e06 100644
--- a/prospector/cli/main.py
+++ b/prospector/cli/main.py
@@ -1,5 +1,4 @@
#!/usr/bin/python3
-import logging
import os
import signal
import sys
@@ -7,8 +6,7 @@
from dotenv import load_dotenv
-import llm.operations as llm
-from llm.model_instantiation import create_model_instance
+from llm.llm_service import LLMService
from util.http import ping_backend
path_root = os.getcwd()
@@ -55,23 +53,32 @@ def main(argv): # noqa: C901
)
return
- # instantiate LLM model if set in config.yaml
- if config.llm_service:
- model = create_model_instance(llm_config=config.llm_service)
-
- if not config.repository and not config.use_llm_repository_url:
- logger.error(
- "Either provide the repository URL or allow LLM usage to obtain it."
- )
- console.print(
- "Either provide the repository URL or allow LLM usage to obtain it.",
- status=MessageStatus.ERROR,
- )
- sys.exit(1)
-
# if config.ping:
# return ping_backend(backend, get_level() < logging.INFO)
+ # Whether to use the LLMService
+ if config.llm_service:
+ if not config.repository and not config.llm_service.use_llm_repository_url:
+            logger.error(
+                "Repository URL was neither specified nor allowed to be obtained with LLM support. One of the two must be set."
+            )
+ console.print(
+ "Please set the `--repository` parameter or enable LLM support to infer the repository URL.",
+ status=MessageStatus.ERROR,
+ )
+ return
+
+ # Create the LLMService singleton for later use
+ try:
+ LLMService(config.llm_service)
+ except Exception as e:
+ logger.error(f"Problem with LLMService instantiation: {e}")
+ console.print(
+ "LLMService could not be created. Check logs.",
+ status=MessageStatus.ERROR,
+ )
+ return
+
config.pub_date = (
config.pub_date + "T00:00:00Z" if config.pub_date is not None else ""
)
@@ -81,9 +88,6 @@ def main(argv): # noqa: C901
logger.debug("Vulnerability ID: " + config.vuln_id)
- if not config.repository:
- config.repository = llm.get_repository_url(model=model, vuln_id=config.vuln_id)
-
results, advisory_record = prospector(
vulnerability_id=config.vuln_id,
repository_url=config.repository,
@@ -99,6 +103,7 @@ def main(argv): # noqa: C901
git_cache=config.git_cache,
limit_candidates=config.max_candidates,
# ignore_adv_refs=config.ignore_refs,
+ use_llm_repository_url=config.llm_service.use_llm_repository_url,
)
if config.preprocess_only:
diff --git a/prospector/config-sample.yaml b/prospector/config-sample.yaml
index 7208bc3dd..4faa61c8a 100644
--- a/prospector/config-sample.yaml
+++ b/prospector/config-sample.yaml
@@ -28,12 +28,13 @@ database:
redis_url: redis://redis:6379/0
# LLM Usage (check README for help)
-llm_service:
- type: sap # use "sap" or "third_party"
- model_name: gpt-4-turbo
- # temperature: 0.0 # optional, default is 0.0
+# llm_service:
+# type: sap # use "sap" or "third_party"
+# model_name: gpt-4-turbo
+# temperature: 0.0 # optional, default is 0.0
+# ai_core_sk: # needed for type: sap
-use_llm_repository_url: True # whether to use LLM's to obtain the repository URL
+# use_llm_repository_url: False # whether to use LLM's to obtain the repository URL
# Report file format: "html", "json", "console" or "all"
# and the file name
diff --git a/prospector/core/prospector.py b/prospector/core/prospector.py
index a95603aae..0fd22f199 100644
--- a/prospector/core/prospector.py
+++ b/prospector/core/prospector.py
@@ -5,7 +5,7 @@
import re
import sys
import time
-from typing import Dict, List, Set, Tuple
+from typing import DefaultDict, Dict, List, Set, Tuple
from urllib.parse import urlparse
import requests
@@ -18,6 +18,7 @@
from git.git import Git
from git.raw_commit import RawCommit
from git.version_to_tag import get_possible_tags
+from llm.llm_service import LLMService
from log.logger import get_level, logger, pretty_log
from rules.rules import apply_rules
from stats.execution import (
@@ -51,7 +52,7 @@
@measure_execution_time(execution_statistics, name="core")
def prospector( # noqa: C901
vulnerability_id: str,
- repository_url: str,
+ repository_url: str = None,
publication_date: str = "",
vuln_descr: str = "",
version_interval: str = "",
@@ -68,6 +69,7 @@ def prospector( # noqa: C901
rules: List[str] = ["ALL"],
tag_commits: bool = True,
silent: bool = False,
+ use_llm_repository_url: bool = False,
) -> Tuple[List[Commit], AdvisoryRecord] | Tuple[int, int]:
if silent:
logger.disabled = True
@@ -89,7 +91,28 @@ def prospector( # noqa: C901
if advisory_record is None:
return None, -1
- fixing_commit = advisory_record.get_fixing_commit(repository_url)
+ if use_llm_repository_url:
+ with ConsoleWriter("LLM Usage (Repo URL)") as console:
+ try:
+ repository_url = LLMService().get_repository_url(
+ advisory_record.description, advisory_record.references
+ )
+ console.print(
+ f"\n Repository URL: {repository_url}",
+ status=MessageStatus.OK,
+ )
+ except Exception as e:
+ logger.error(
+ e,
+ exc_info=get_level() < logging.INFO,
+ )
+ console.print(
+ e,
+ status=MessageStatus.ERROR,
+ )
+ sys.exit(1)
+
+ fixing_commit = advisory_record.get_fixing_commit()
# print(advisory_record.references)
# obtain a repository object
repository = Git(repository_url, git_cache)
@@ -140,9 +163,7 @@ def prospector( # noqa: C901
if len(candidates) > limit_candidates:
logger.error(f"Number of candidates exceeds {limit_candidates}, aborting.")
-        ConsoleWriter.print(
-            f"Candidates limit exceeded: {len(candidates)}.",
-        )
+        ConsoleWriter.print(f"Candidates limit exceeded: {len(candidates)}.")
return None, len(candidates)
with ExecutionTimer(
@@ -177,7 +198,10 @@ def prospector( # noqa: C901
# preprocessed_commits += preprocess_commits(missing, timer)
pbar = tqdm(
- missing, desc="Processing commits", unit="commit", disable=silent
+ missing,
+ desc="Processing commits",
+ unit="commit",
+ disable=silent,
)
start_time = time.time()
with Counter(
@@ -319,7 +343,7 @@ def retrieve_preprocessed_commits(
)
]
- logger.error(f"Missing {len(missing)} commits")
+ logger.info(f"{len(missing)} commits not found in backend")
commits = [Commit.parse_obj(rc) for rc in retrieved_commits]
# Sets the tags
# for commit in commits:
diff --git a/prospector/core/report_test.py b/prospector/core/report_test.py
index 7658a0209..387dc2304 100644
--- a/prospector/core/report_test.py
+++ b/prospector/core/report_test.py
@@ -2,7 +2,7 @@
import os.path
from random import randint
-import prospector.core.report as report
+import core.report as report
from datamodel.advisory import build_advisory_record
from datamodel.commit import Commit
from util.sample_data_generation import ( # random_list_of_url,
diff --git a/prospector/datamodel/advisory.py b/prospector/datamodel/advisory.py
index 15569b35b..d1e9e49a6 100644
--- a/prospector/datamodel/advisory.py
+++ b/prospector/datamodel/advisory.py
@@ -10,6 +10,7 @@
import validators
from dateutil.parser import isoparse
+from llm.llm_service import LLMService
from log.logger import get_level, logger, pretty_log
from util.http import extract_from_webpage, fetch_url, get_urls
@@ -69,6 +70,7 @@ def __init__(
reserved_timestamp: int = 0,
published_timestamp: int = 0,
updated_timestamp: int = 0,
+ repository_url: str = None,
references: DefaultDict[str, int] = None,
affected_products: List[str] = None,
versions: Dict[str, List[str]] = None,
@@ -81,6 +83,7 @@ def __init__(
self.reserved_timestamp = reserved_timestamp
self.published_timestamp = published_timestamp
self.updated_timestamp = updated_timestamp
+ self.repository_url = repository_url
self.references = references or defaultdict(lambda: 0)
self.affected_products = affected_products or list()
self.versions = versions or dict()
@@ -133,6 +136,7 @@ def parse_references_from_third_party(self):
self.references[self.extract_hashes(ref)] += 2
def get_advisory(self):
+ """Fills the advisory record with information obtained from an advisory API."""
details, metadata = get_from_mitre(self.cve_id)
if metadata is None:
raise Exception("MITRE API Error")
@@ -176,7 +180,7 @@ def parse_advisory(self, data):
]
self.versions["fixed"] = [v for v in self.versions["fixed"] if v is not None]
- def get_fixing_commit(self, repository) -> List[str]:
+ def get_fixing_commit(self) -> List[str]:
self.references = dict(
sorted(self.references.items(), key=lambda item: item[1], reverse=True)
)
diff --git a/prospector/git/git_test.py b/prospector/git/git_test.py
index 97bf640b5..fbfdbcc86 100644
--- a/prospector/git/git_test.py
+++ b/prospector/git/git_test.py
@@ -42,7 +42,8 @@ def test_get_tags_for_commit(repository: Git):
commit = commits.get(OPENCAST_COMMIT)
if commit is not None:
tags = commit.find_tags()
- assert len(tags) == 75
+ print(tags)
+ assert len(tags) >= 106
assert "10.2" in tags and "11.3" in tags and "9.4" in tags
diff --git a/prospector/git/raw_commit_test.py b/prospector/git/raw_commit_test.py
index 4eb28dc95..534431e94 100644
--- a/prospector/git/raw_commit_test.py
+++ b/prospector/git/raw_commit_test.py
@@ -26,7 +26,7 @@ def commit():
def test_find_tags(commit: RawCommit):
tags = commit.find_tags()
- assert len(tags) == 75
+ assert len(tags) >= 106
assert "10.2" in tags and "11.3" in tags and "9.4" in tags
diff --git a/prospector/llm/instantiation.py b/prospector/llm/instantiation.py
new file mode 100644
index 000000000..db924e722
--- /dev/null
+++ b/prospector/llm/instantiation.py
@@ -0,0 +1,159 @@
+import json
+import os
+from typing import Dict
+
+import requests
+from dotenv import load_dotenv
+from langchain_core.language_models.llms import LLM
+from langchain_google_vertexai import ChatVertexAI
+from langchain_mistralai import ChatMistralAI
+from langchain_openai import ChatOpenAI
+
+from llm.models.gemini import Gemini
+from llm.models.mistral import Mistral
+from llm.models.openai import OpenAI
+
+load_dotenv()
+
+
+SAP_MAPPING = {
+ "gpt-35-turbo": OpenAI,
+ "gpt-35-turbo-16k": OpenAI,
+ "gpt-35-turbo-0125": OpenAI,
+ "gpt-4": OpenAI,
+ "gpt-4-32k": OpenAI,
+ # "gpt-4-turbo": OpenAI, # currently TBD
+ # "gpt-4o": OpenAI, # currently TBD
+ "gemini-1.0-pro": Gemini,
+ "mistral-large": Mistral,
+}
+
+
+THIRD_PARTY_MAPPING = {
+ "gpt-4-turbo": (ChatOpenAI, "OPENAI_API_KEY"),
+ "gpt-4o": (ChatOpenAI, "OPENAI_API_KEY"),
+ "gpt-4": (ChatOpenAI, "OPENAI_API_KEY"),
+ "gpt-3.5-turbo-0125": (ChatOpenAI, "OPENAI_API_KEY"),
+ "gpt-3.5-turbo": (ChatOpenAI, "OPENAI_API_KEY"),
+ "gemini-pro": (ChatVertexAI, "GOOGLE_API_KEY"),
+ "mistral-large-latest": (ChatMistralAI, "MISTRAL_API_KEY"),
+}
+
+
+def create_model_instance(
+ model_type: str,
+ model_name: str,
+ ai_core_sk_filepath: str,
+ temperature: float = 0.0,
+) -> LLM:
+ """Creates and returns the model object given the user's configuration.
+
+ Args:
+ model_type: the way of accessing the LLM API ('sap' for SAP's AI Core, 'third_party' for
+ external providers).
+ model_name: which model to use, e.g. gpt-4.
+ temperature: the temperature for the model, default 0.0.
+ ai_core_sk_filepath: The path to the file containing AI Core credentials
+
+ Returns:
+ LLM: An instance of the specified LLM model.
+
+ Raises:
+ ValueError: if there is a problem with deployment_url, model_name or AI Core credentials
+ """
+
+ def create_sap_provider(
+ model_name: str, temperature: float, ai_core_sk_filepath: str
+ ) -> LLM:
+
+ deployment_url = os.getenv(model_name.upper().replace("-", "_") + "_URL", None)
+ if deployment_url is None:
+ raise ValueError(
+ f"Deployment URL ({model_name.upper().replace('-', '_')}_URL) for {model_name} is not set."
+ )
+
+ model_class = SAP_MAPPING.get(model_name, None)
+ if model_class is None:
+ raise ValueError(f"Model '{model_name}' is not available.")
+
+ if ai_core_sk_filepath is None:
+ raise ValueError(
+ f"AI Core credentials file couldn't be found: '{ai_core_sk_filepath}'"
+ )
+
+ model = model_class(
+ model_name=model_name,
+ deployment_url=deployment_url,
+ temperature=temperature,
+ ai_core_sk_filepath=ai_core_sk_filepath,
+ )
+
+ return model
+
+    def create_third_party_provider(model_name: str, temperature: float) -> LLM:
+        # Look up the chat model class and the environment variable holding its API key
+        entry = THIRD_PARTY_MAPPING.get(model_name, None)
+        if entry is None:
+            raise ValueError(f"Model '{model_name}' is not available.")
+
+        model_class, api_key_variable = entry
+        api_key = os.getenv(api_key_variable)
+        if api_key is None:
+            raise ValueError(f"API key for {model_name} is not set.")
+
+ model = model_class(
+ model=model_name,
+ api_key=api_key,
+ temperature=temperature,
+ )
+
+ return model
+
+ try:
+ match model_type:
+ case "sap":
+ model = create_sap_provider(
+ model_name,
+ temperature,
+ ai_core_sk_filepath,
+ )
+ case "third_party":
+ model = create_third_party_provider(model_name, temperature)
+ case _:
+ raise ValueError(
+ f"Invalid LLM type specified (either sap or third_party). '{model_type}' is not available."
+ )
+ except Exception:
+ raise # re-raise exceptions from create_[sap|third_party]_provider
+
+ return model
+
+
+def get_headers(ai_core_sk_file_path: str) -> Dict[str, str]:
+ """Generate the request headers to use SAP AI Core. This method generates the authentication token and returns a Dict with headers.
+
+ Params:
+ ai_core_sk_file_path (str): the path to the file containing the SAP AI Core credentials.
+
+ Returns:
+ The headers object needed to send requests to the SAP AI Core.
+ """
+ with open(ai_core_sk_file_path) as f:
+ sk = json.load(f)
+
+ auth_url = f"{sk['url']}/oauth/token"
+ client_id = sk["clientid"]
+ client_secret = sk["clientsecret"]
+
+ response = requests.post(
+ auth_url,
+ data={"grant_type": "client_credentials"},
+ auth=(client_id, client_secret),
+ timeout=8000,
+ )
+
+ headers = {
+ "AI-Resource-Group": "default",
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {response.json()['access_token']}",
+ }
+ return headers
diff --git a/prospector/llm/llm_service.py b/prospector/llm/llm_service.py
new file mode 100644
index 000000000..cbc4e69e4
--- /dev/null
+++ b/prospector/llm/llm_service.py
@@ -0,0 +1,76 @@
+import re
+
+import validators
+from langchain_core.language_models.llms import LLM
+from langchain_core.output_parsers import StrOutputParser
+
+from llm.instantiation import create_model_instance
+from llm.prompts import prompt_best_guess
+from log.logger import logger
+from util.config_parser import LLMServiceConfig
+from util.singleton import Singleton
+
+
+class LLMService(metaclass=Singleton):
+ """A wrapper class for all functions requiring an LLM. This class is also a singleton, as only a
+ single model should be used throughout the program.
+ """
+
+ config: LLMServiceConfig = None
+
+ def __init__(self, config: LLMServiceConfig = None):
+
+ if self.config is None and config is not None:
+ self.config = config
+ elif self.config is None and config is None:
+ raise ValueError(
+ "On the first instantiation, a configuration object must be passed."
+ )
+
+ try:
+ self.model: LLM = create_model_instance(
+ self.config.type,
+ self.config.model_name,
+ self.config.ai_core_sk,
+ self.config.temperature,
+ )
+ except Exception:
+ raise
+
+ def get_repository_url(self, advisory_description, advisory_references) -> str:
+ """Ask an LLM to obtain the repository URL given the advisory description and references.
+
+ Args:
+ advisory_description (str): The advisory description
+ advisory_references (dict): The advisory's references
+
+ Returns:
+ The repository URL as a string.
+
+ Raises:
+            RuntimeError if advisory information cannot be obtained or there is an error in the model invocation.
+ """
+ try:
+ chain = prompt_best_guess | self.model | StrOutputParser()
+
+ url = chain.invoke(
+ {
+ "description": advisory_description,
+ "references": advisory_references,
+ }
+ )
+ logger.info(f"LLM returned the following URL: {url}")
+
+ # delimiters are often returned by the LLM, remove them, if the case
+ pattern = r""
+ match = re.search(pattern, url)
+ if match:
+ return match.group(1)
+
+ if not validators.url(url):
+ raise TypeError(f"LLM returned invalid URL: {url}")
+
+ except Exception as e:
+ raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
+
+ return url
diff --git a/prospector/llm/llm_service_test.py b/prospector/llm/llm_service_test.py
new file mode 100644
index 000000000..0b43af860
--- /dev/null
+++ b/prospector/llm/llm_service_test.py
@@ -0,0 +1,131 @@
+from typing import Any, List
+
+import pytest
+from langchain_core.language_models.llms import LLM
+from langchain_google_vertexai import ChatVertexAI
+from langchain_mistralai import ChatMistralAI
+from langchain_openai import ChatOpenAI
+from requests_cache import Optional
+
+from llm.llm_service import LLMService # this is a singleton
+from llm.models.gemini import Gemini
+from llm.models.mistral import Mistral
+from llm.models.openai import OpenAI
+from util.singleton import Singleton
+
+
+# Mock the llm_service configuration object
+class Config:
+ type: str = None
+ model_name: str = None
+    temperature: float = None
+ ai_core_sk: str = None
+
+ def __init__(self, type, model_name, temperature, ai_core_sk):
+ self.type = type
+ self.model_name = model_name
+ self.temperature = temperature
+ self.ai_core_sk = ai_core_sk
+
+
+@pytest.fixture(autouse=True)
+def reset_singletons():
+ # Clean up singleton instances after each test
+ Singleton._instances = {}
+
+
+@pytest.fixture(autouse=True)
+def mock_environment_variables():
+ mp = pytest.MonkeyPatch()
+ mp.setenv("GPT_4_URL", "https://deployment.url.com")
+ mp.setenv("MISTRAL_LARGE_URL", "https://deployment.url.com")
+ mp.setenv("GEMINI_1.0_PRO_URL", "https://deployment.url.com")
+ mp.setenv("OPENAI_API_KEY", "https://deployment.url.com")
+ mp.setenv("GOOGLE_API_KEY", "https://deployment.url.com")
+ mp.setenv("MISTRAL_API_KEY", "https://deployment.url.com")
+
+
+class TestModel:
+ def test_sap_gpt_instantiation(self):
+ config = Config("sap", "gpt-4", 0.0, "example.json")
+ llm_service = LLMService(config)
+ assert isinstance(llm_service.model, OpenAI)
+
+ def test_sap_gemini_instantiation(self):
+ config = Config("sap", "gemini-1.0-pro", 0.0, "example.json")
+ llm_service = LLMService(config)
+ assert isinstance(llm_service.model, Gemini)
+
+ def test_sap_mistral_instantiation(self):
+ config = Config("sap", "mistral-large", 0.0, "example.json")
+ llm_service = LLMService(config)
+ assert isinstance(llm_service.model, Mistral)
+
+ def test_gpt_instantiation(self):
+ config = Config("third_party", "gpt-4", 0.0, "example.json")
+ llm_service = LLMService(config)
+ assert isinstance(llm_service.model, ChatOpenAI)
+
+ # Google throws an error on creation, when no account is found
+ # def test_gemini_instantiation(self):
+ # config = Config("third_party", "gemini-pro", 0.0, "example.json")
+ # llm_service = LLMService(config)
+ # assert isinstance(llm_service.model, ChatVertexAI)
+
+ def test_mistral_instantiation(self):
+ config = Config("third_party", "mistral-large-latest", 0.0, "example.json")
+ llm_service = LLMService(config)
+ assert isinstance(llm_service.model, ChatMistralAI)
+
+ def test_singleton_instance_creation(self):
+ """A second instantiation should return the exisiting instance."""
+ config = Config("sap", "gpt-4", 0.0, "example.json")
+ llm_service = LLMService(config)
+ same_service = LLMService(config)
+ assert (
+ llm_service is same_service
+ ), "LLMService should return the same instance."
+
+ def test_singleton_same_instance(self):
+ """A second instantiation with different parameters should return the existing instance unchanged."""
+ config = Config("sap", "gpt-4", 0.0, "example.json")
+ llm_service = LLMService(config)
+ config = Config(
+ "sap", "gpt-35-turbo", 0.0, "example.json"
+ ) # This instantiation should not work, but instead return the already existing instance
+ same_service = LLMService(config)
+ assert llm_service is same_service
+ assert llm_service.model.model_name == "gpt-4"
+
+ def test_singleton_retains_state(self):
+ """Reassigning a field variable of the instance should be allowed and reflected
+ across instantiations."""
+ config = Config("sap", "gpt-4", 0.0, "example.json")
+ service = LLMService(config)
+
+ service.model = OpenAI(
+ model_name="gpt-35-turbo",
+ deployment_url="deployment_url_placeholder",
+ temperature=0.7,
+ ai_core_sk_filepath="example.json",
+ )
+ same_service = LLMService(config)
+
+ assert same_service.model == OpenAI(
+ model_name="gpt-35-turbo",
+ deployment_url="deployment_url_placeholder",
+ temperature=0.7,
+ ai_core_sk_filepath="example.json",
+ ), "LLMService should retain state between instantiations"
+
+ def test_reuse_singleton_without_config(self):
+ config = Config("sap", "gpt-4", 0.0, "example.json")
+ service = LLMService(config)
+
+ same_service = LLMService()
+
+ assert service is same_service
+
+ def test_fail_first_instantiation_without_config(self):
+ with pytest.raises(Exception):
+ LLMService()
diff --git a/prospector/llm/model_instantiation.py b/prospector/llm/model_instantiation.py
deleted file mode 100644
index 2ca1560f1..000000000
--- a/prospector/llm/model_instantiation.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from typing import Dict
-
-from dotenv import dotenv_values
-from langchain_core.language_models.llms import LLM
-from langchain_google_vertexai import ChatVertexAI
-from langchain_mistralai import ChatMistralAI
-from langchain_openai import ChatOpenAI
-
-from llm.models import Gemini, Mistral, OpenAI
-from log.logger import logger
-
-
-class ModelDef:
- def __init__(self, access_info: str, _class: LLM):
- self.access_info = (
- access_info # either deployment_url (for SAP) or API key (for Third Party)
- )
- self._class = _class
-
-
-env: Dict[str, str | None] = dotenv_values()
-
-SAP_MAPPING = {
- "gpt-35-turbo": ModelDef(env.get("GPT_35_TURBO_URL", None), OpenAI),
- "gpt-35-turbo-16k": ModelDef(env.get("GPT_35_TURBO_16K_URL", None), OpenAI),
- "gpt-35-turbo-0125": ModelDef(env.get("GPT_35_TURBO_0125_URL", None), OpenAI),
- "gpt-4": ModelDef(env.get("GPT_4_URL", None), OpenAI),
- "gpt-4-32k": ModelDef(env.get("GPT_4_32K_URL", None), OpenAI),
- # "gpt-4-turbo": env.get("GPT_4_TURBO_URL", None), # currently TBD: https://github.tools.sap/I343697/generative-ai-hub-readme
- # "gpt-4o": env.get("GPT_4O_URL", None), # currently TBD: https://github.tools.sap/I343697/generative-ai-hub-readme
- "gemini-1.0-pro": ModelDef(env.get("GEMINI_1_0_PRO_URL", None), Gemini),
- "mistralai--mixtral-8x7b-instruct-v01": ModelDef(
- env.get("MISTRALAI_MIXTRAL_8X7B_INSTRUCT_V01", None), Mistral
- ),
-}
-
-THIRD_PARTY_MAPPING = {
- "gpt-4": ModelDef(env.get("OPENAI_API_KEY", None), ChatOpenAI),
- "gpt-3.5-turbo": ModelDef(env.get("OPENAI_API_KEY", None), ChatOpenAI),
- "gemini-pro": ModelDef(env.get("GOOGLE_API_KEY", None), ChatVertexAI),
- "mistral-large-latest": ModelDef(env.get("MISTRAL_API_KEY", None), ChatMistralAI),
-}
-
-
-def create_model_instance(llm_config) -> LLM:
- """Creates and returns the model object given the user's configuration.
-
- Args:
- llm_config (dict): A dictionary containing the configuration for the LLM. Expected keys are:
- - 'type' (str): Method for accessing the LLM API ('sap' for SAP's AI Core, 'third_party' for
- external providers).
- - 'model_name' (str): Which model to use, e.g. gpt-4.
- - 'temperature' (Optional(float)): The temperature for the model, default 0.0.
-
- Returns:
- LLM: An instance of the specified LLM model.
- """
-
- def create_sap_provider(model_name: str, temperature: float):
- d = SAP_MAPPING.get(model_name, None)
-
- if d is None:
- raise ValueError(f"Model '{model_name}' is not available.")
-
- model = d._class(
- model_name=model_name,
- deployment_url=d.access_info,
- temperature=temperature,
- )
-
- return model
-
- def create_third_party_provider(model_name: str, temperature: float):
- # obtain definition from main mapping
- d = THIRD_PARTY_MAPPING.get(model_name, None)
-
- if d is None:
- logger.error(f"Model '{model_name}' is not available.")
- raise ValueError(f"Model '{model_name}' is not available.")
-
- model = d._class(
- model=model_name,
- api_key=d.access_info,
- temperature=temperature,
- )
-
- return model
-
- if llm_config is None:
- raise ValueError(
- "When using LLM support, please add necessary parameters to configuration file."
- )
-
- # LLM Instantiation
- try:
- match llm_config.type:
- case "sap":
- model = create_sap_provider(
- llm_config.model_name, llm_config.temperature
- )
- case "third_party":
- model = create_third_party_provider(
- llm_config.model_name, llm_config.temperature
- )
- case _:
- logger.error(
- f"Invalid LLM type specified, '{llm_config.type}' is not available."
- )
- raise ValueError(
- f"Invalid LLM type specified, '{llm_config.type}' is not available."
- )
- except Exception as e:
- logger.error(f"Problem when initialising model: {e}")
- raise ValueError(f"Problem when initialising model: {e}")
-
- return model
diff --git a/prospector/llm/models.py b/prospector/llm/models.py
deleted file mode 100644
index cb3fb0d81..000000000
--- a/prospector/llm/models.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import json
-from typing import Any, List, Mapping, Optional
-
-import requests
-from dotenv import dotenv_values
-from langchain_core.language_models.llms import LLM
-
-from log.logger import logger
-
-
-class SAPProvider(LLM):
- model_name: str
- deployment_url: str
- temperature: float
-
- @property
- def _llm_type(self) -> str:
- return "custom"
-
- @property
- def _identifying_params(self) -> Mapping[str, Any]:
- """Get the identifying parameters."""
- return {
- "model_name": self.model_name,
- }
-
- def _call(
- self,
- prompt: str,
- stop: Optional[List[str]] = None,
- **kwargs: Any,
- ) -> str:
- """Run the LLM on the given input.
-
- Override this method to implement the LLM logic.
-
- Args:
- prompt: The prompt to generate from.
- stop: Stop words to use when generating. Model output is cut off at the
- first occurrence of any of the stop substrings.
- If stop tokens are not supported consider raising NotImplementedError.
- run_manager: Callback manager for the run.
- **kwargs: Arbitrary additional keyword arguments. These are usually passed
- to the model provider API call.
-
- Returns:
- The model output as a string. Actual completions SHOULD NOT include the prompt.
- """
- if self.deployment_url is None:
- raise ValueError(
- "Deployment URL not set. Maybe you forgot to create the environment variable."
- )
- if stop is not None:
- raise ValueError("stop kwargs are not permitted.")
- return ""
-
-
-class OpenAI(SAPProvider):
- def _call(
- self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
- ) -> str:
- # Call super() to make sure model_name is valid
- super()._call(prompt, stop, **kwargs)
- # Model specific request data
- endpoint = f"{self.deployment_url}/chat/completions?api-version=2023-05-15"
- headers = get_headers()
- data = {
- "messages": [
- {
- "role": "user",
- "content": f"{prompt}",
- }
- ],
- "temperature": self.temperature,
- }
-
- response = requests.post(endpoint, headers=headers, json=data)
-
- if not response.status_code == 200:
- logger.error(
- f"Invalid response from AI Core API with error code {response.status_code}"
- )
- raise Exception("Invalid response from AI Core API.")
-
- return self.parse(response.json())
-
- def parse(self, message) -> str:
- """Parse the returned JSON object from OpenAI."""
- return message["choices"][0]["message"]["content"]
-
-
-class Gemini(SAPProvider):
- def _call(
- self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
- ) -> str:
- # Call super() to make sure model_name is valid
- super()._call(prompt, stop, **kwargs)
- # Model specific request data
- endpoint = f"{self.deployment_url}/models/{self.model_name}:generateContent"
- headers = get_headers()
- data = {
- "generation_config": {
- "maxOutputTokens": 1000,
- "temperature": self.temperature,
- },
- "contents": [{"role": "user", "parts": [{"text": prompt}]}],
- "safetySettings": [
- {
- "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
- "threshold": "BLOCK_NONE",
- },
- {
- "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
- "threshold": "BLOCK_NONE",
- },
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
- ],
- }
-
- response = requests.post(endpoint, headers=headers, json=data)
-
- if not response.status_code == 200:
- logger.error(
- f"Invalid response from AI Core API with error code {response.status_code}"
- )
- raise Exception("Invalid response from AI Core API.")
-
- return self.parse(response.json())
-
- def parse(self, message) -> str:
- """Parse the returned JSON object from OpenAI."""
- return message["candidates"][0]["content"]["parts"][0]["text"]
-
-
-class Mistral(SAPProvider):
- def _call(
- self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
- ) -> str:
- # Call super() to make sure model_name is valid
- super()._call(prompt, stop, **kwargs)
- # Model specific request data
- endpoint = f"{self.deployment_url}/chat/completions"
- headers = get_headers()
- data = {
- "model": "mistralai--mixtral-8x7b-instruct-v01",
- "max_tokens": 100,
- "temperature": self.temperature,
- "messages": [{"role": "user", "content": prompt}],
- }
-
- response = requests.post(endpoint, headers=headers, json=data)
-
- if not response.status_code == 200:
- logger.error(
- f"Invalid response from AI Core API with error code {response.status_code}"
- )
- raise Exception("Invalid response from AI Core API.")
-
- return self.parse(response.json())
-
- def parse(self, message) -> str:
- """Parse the returned JSON object from OpenAI."""
- return message["choices"][0]["message"]["content"]
-
-
-def get_headers():
- """Generate the request headers to use SAP AI Core. This method generates the authentication token and returns a Dict with headers.
-
- Returns:
- The headers object needed to send requests to the SAP AI Core.
- """
- with open(dotenv_values()["AI_CORE_KEY_FILEPATH"]) as f:
- sk = json.load(f)
-
- auth_url = f"{sk['url']}/oauth/token"
- client_id = sk["clientid"]
- client_secret = sk["clientsecret"]
- # api_base_url = f"{sk['serviceurls']['AI_API_URL']}/v2"
-
- response = requests.post(
- auth_url,
- data={"grant_type": "client_credentials"},
- auth=(client_id, client_secret),
- timeout=8000,
- )
-
- headers = {
- "AI-Resource-Group": "default",
- "Content-Type": "application/json",
- "Authorization": f"Bearer {response.json()['access_token']}",
- }
- return headers
diff --git a/prospector/llm/models/gemini.py b/prospector/llm/models/gemini.py
new file mode 100644
index 000000000..147086254
--- /dev/null
+++ b/prospector/llm/models/gemini.py
@@ -0,0 +1,87 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.language_models.llms import LLM
+
+import llm.instantiation as instantiation
+from log.logger import logger
+
+
+class Gemini(LLM):
+ model_name: str
+ deployment_url: str
+ temperature: float
+ ai_core_sk_filepath: str
+
+ @property
+ def _llm_type(self) -> str:
+ return "SAP Gemini"
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Return a dictionary of identifying parameters."""
+ return {
+ "model_name": self.model_name,
+ "deployment_url": self.deployment_url,
+ "temperature": self.temperature,
+ "ai_core_sk_filepath": self.ai_core_sk_filepath,
+ }
+
+ def _call(
+ self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> str:
+ endpoint = f"{self.deployment_url}/models/{self.model_name}:generateContent"
+ headers = instantiation.get_headers(self.ai_core_sk_filepath)
+ data = {
+ "generation_config": {
+ "maxOutputTokens": 1000,
+ "temperature": self.temperature,
+ },
+ "contents": [{"role": "user", "parts": [{"text": prompt}]}],
+ "safetySettings": [
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "threshold": "BLOCK_NONE",
+ },
+ ],
+ }
+
+ try:
+ response = requests.post(endpoint, headers=headers, json=data)
+ return self.parse(response.json())
+ except requests.exceptions.HTTPError as http_error:
+ logger.error(
+ f"HTTP error occurred when sending a request through AI Core: {http_error}"
+ )
+ raise
+        except requests.exceptions.Timeout as timeout_err:
+            logger.error(
+                f"Timeout error occurred when sending a request through AI Core: {timeout_err}"
+            )
+            raise
+        except requests.exceptions.ConnectionError as conn_err:
+            logger.error(
+                f"Connection error occurred when sending a request through AI Core: {conn_err}"
+            )
+            raise
+        except requests.exceptions.RequestException as req_err:
+            logger.error(
+                f"A request error occurred when sending a request through AI Core: {req_err}"
+            )
+            raise
+
+ def parse(self, message) -> str:
+ """Parse the returned JSON object from OpenAI."""
+ return message["candidates"][0]["content"]["parts"][0]["text"]
diff --git a/prospector/llm/models/mistral.py b/prospector/llm/models/mistral.py
new file mode 100644
index 000000000..9708d8e31
--- /dev/null
+++ b/prospector/llm/models/mistral.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.language_models.llms import LLM
+
+import llm.instantiation as instantiation
+from log.logger import logger
+
+
+class Mistral(LLM):
+ model_name: str
+ deployment_url: str
+ temperature: float
+ ai_core_sk_filepath: str
+
+ @property
+ def _llm_type(self) -> str:
+ return "SAP Mistral"
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Return a dictionary of identifying parameters."""
+ return {
+ "model_name": self.model_name,
+ "deployment_url": self.deployment_url,
+ "temperature": self.temperature,
+ "ai_core_sk_filepath": self.ai_core_sk_filepath,
+ }
+
+ def _call(
+ self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> str:
+ endpoint = f"{self.deployment_url}/chat/completions"
+ headers = instantiation.get_headers(self.ai_core_sk_filepath)
+ data = {
+ "model": "mistralai--mixtral-8x7b-instruct-v01",
+ "max_tokens": 100,
+ "temperature": self.temperature,
+ "messages": [{"role": "user", "content": prompt}],
+ }
+
+ try:
+ response = requests.post(endpoint, headers=headers, json=data)
+ return self.parse(response.json())
+ except requests.exceptions.HTTPError as http_error:
+ logger.error(
+ f"HTTP error occurred when sending a request through AI Core: {http_error}"
+ )
+ raise
+        except requests.exceptions.Timeout as timeout_err:
+            logger.error(
+                f"Timeout error occurred when sending a request through AI Core: {timeout_err}"
+            )
+            raise
+        except requests.exceptions.ConnectionError as conn_err:
+            logger.error(
+                f"Connection error occurred when sending a request through AI Core: {conn_err}"
+            )
+            raise
+        except requests.exceptions.RequestException as req_err:
+            logger.error(
+                f"A request error occurred when sending a request through AI Core: {req_err}"
+            )
+            raise
+
+ def parse(self, message) -> str:
+ """Parse the returned JSON object from OpenAI."""
+ return message["choices"][0]["message"]["content"]
diff --git a/prospector/llm/models/openai.py b/prospector/llm/models/openai.py
new file mode 100644
index 000000000..76d95ef5b
--- /dev/null
+++ b/prospector/llm/models/openai.py
@@ -0,0 +1,71 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.language_models.llms import LLM
+
+import llm.instantiation as instantiation
+from log.logger import logger
+
+
+class OpenAI(LLM):
+ model_name: str
+ deployment_url: str
+ temperature: float
+ ai_core_sk_filepath: str
+
+ @property
+ def _llm_type(self) -> str:
+ return "SAP OpenAI"
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Return a dictionary of identifying parameters."""
+ return {
+ "model_name": self.model_name,
+ "deployment_url": self.deployment_url,
+ "temperature": self.temperature,
+ "ai_core_sk_filepath": self.ai_core_sk_filepath,
+ }
+
+ def _call(
+ self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+ ) -> str:
+ endpoint = f"{self.deployment_url}/chat/completions?api-version=2023-05-15"
+ headers = instantiation.get_headers(self.ai_core_sk_filepath)
+ data = {
+ "messages": [
+ {
+ "role": "user",
+ "content": f"{prompt}",
+ }
+ ],
+ "temperature": self.temperature,
+ }
+
+ try:
+ response = requests.post(endpoint, headers=headers, json=data)
+ return self.parse(response.json())
+ except requests.exceptions.HTTPError as http_error:
+ logger.error(
+ f"HTTP error occurred when sending a request through AI Core: {http_error}"
+ )
+ raise
+        except requests.exceptions.Timeout as timeout_err:
+            logger.error(
+                f"Timeout error occurred when sending a request through AI Core: {timeout_err}"
+            )
+            raise
+        except requests.exceptions.ConnectionError as conn_err:
+            logger.error(
+                f"Connection error occurred when sending a request through AI Core: {conn_err}"
+            )
+            raise
+        except requests.exceptions.RequestException as req_err:
+            logger.error(
+                f"A request error occurred when sending a request through AI Core: {req_err}"
+            )
+            raise
+
+ def parse(self, message) -> str:
+ """Parse the returned JSON object from OpenAI."""
+ return message["choices"][0]["message"]["content"]
diff --git a/prospector/llm/operations.py b/prospector/llm/operations.py
deleted file mode 100644
index f157e8590..000000000
--- a/prospector/llm/operations.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import sys
-from typing import Dict
-
-import validators
-from langchain_core.language_models.llms import LLM
-
-from cli.console import ConsoleWriter, MessageStatus
-from datamodel.advisory import get_from_mitre
-from llm.prompts import best_guess
-from log.logger import logger
-
-
-def get_repository_url(model: LLM, vuln_id: str):
- """Ask an LLM to obtain the repository URL given the advisory description and references.
-
- Args:
- model (LLM): The instantiated model (instantiated with create_model_instance())
- vuln_id: The ID of the advisory, e.g. CVE-2020-1925.
-
- Returns:
- The repository URL as a string.
-
- Raises:
- ValueError if advisory information cannot be obtained or there is an error in the model invocation.
- """
- with ConsoleWriter("Invoking LLM") as console:
- details, _ = get_from_mitre(vuln_id)
- if details is None:
- logger.error("Error when getting advisory information from Mitre.")
- console.print(
- "Error when getting advisory information from Mitre.",
- status=MessageStatus.ERROR,
- )
- sys.exit(1)
-
- try:
- chain = best_guess | model
-
- url = chain.invoke(
- {
- "description": details["descriptions"][0]["value"],
- "references": details["references"],
- }
- )
- if not validators.url(url):
- logger.error(f"LLM returned invalid URL: {url}")
- console.print(
- f"LLM returned invalid URL: {url}",
- status=MessageStatus.ERROR,
- )
- sys.exit(1)
- except Exception as e:
- logger.error(f"Prompt-model chain could not be invoked: {e}")
- console.print(
- "Prompt-model chain could not be invoked.",
- status=MessageStatus.ERROR,
- )
- sys.exit(1)
-
- return url
diff --git a/prospector/llm/prompts.py b/prospector/llm/prompts.py
index 57fd2444a..dd1c9c663 100644
--- a/prospector/llm/prompts.py
+++ b/prospector/llm/prompts.py
@@ -1,7 +1,7 @@
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
-# example output for few-shot prompting
-examples_without_num = [
+# Get Repository URL, few-shot prompting examples
+examples_data = [
{
"cve_description": "Apache Olingo versions 4.0.0 to 4.7.0 provide the AsyncRequestWrapperImpl class which reads a URL from the Location header, and then sends a GET or DELETE request to this URL. It may allow to implement a SSRF attack. If an attacker tricks a client to connect to a malicious server, the server can make the client call any URL including internal resources which are not directly accessible by the attacker.",
"cve_references": "https://www.zerodayinitiative.com/advisories/ZDI-24-196/",
@@ -20,7 +20,7 @@
]
# Formatter for the few-shot examples without CVE numbers
-examples_prompt_without_num = PromptTemplate(
+examples_formatted = PromptTemplate(
input_variables=["cve_references", "result"],
template=""" {cve_description}
{cve_references}
@@ -28,18 +28,18 @@
""",
)
-best_guess = FewShotPromptTemplate(
+prompt_best_guess = FewShotPromptTemplate(
prefix="""You will be provided with the ID, description and references of a vulnerability advisory (CVE). Return nothing but the URL of the repository the given CVE is concerned with.'.
Here are a few examples delimited with XML tags:""",
- examples=examples_without_num,
- example_prompt=examples_prompt_without_num,
+ examples=examples_data,
+ example_prompt=examples_formatted,
suffix="""Here is the CVE information:
{description}
{references}
-If you cannot find the URL, return your best guess of what the repository URL could be. Use any hints (eg. the mention of GitHub or GitLab) in the CVE description and references. Return nothing but the URL.
+If you cannot find the URL, return your best guess of what the repository URL could be. Use any hints (e.g. the mention of GitHub or GitLab) in the CVE description and references. Do not return the delimiters. Return nothing but the URL.
""",
input_variables=["description", "references"],
- metadata={"name": "best_guess"},
+ metadata={"name": "prompt_best_guess"},
)
diff --git a/prospector/llm/test_llm.py b/prospector/llm/test_llm.py
deleted file mode 100644
index 9ef091659..000000000
--- a/prospector/llm/test_llm.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import pytest
-import requests
-from langchain_openai import ChatOpenAI
-
-from llm.models import Gemini, Mistral, OpenAI
-from llm.operations import create_model_instance, get_repository_url
-
-
-# Mock the llm_service configuration object
-class Config:
- type: str = None
- model_name: str = None
-
- def __init__(self, type, model_name):
- self.type = type
- self.model_name = model_name
-
-
-# Vulnerability ID
-vuln_id = "CVE-2024-32480"
-
-
-class TestModel:
- def test_sap_gpt35_instantiation(self):
- config = Config("sap", "gpt-35-turbo")
- model = create_model_instance(config)
- assert isinstance(model, OpenAI)
-
- def test_sap_gpt4_instantiation(self):
- config = Config("sap", "gpt-4")
- model = create_model_instance(config)
- assert isinstance(model, OpenAI)
-
- def test_thirdparty_gpt35_instantiation(self):
- config = Config("third_party", "gpt-3.5-turbo")
- model = create_model_instance(config)
- assert isinstance(model, ChatOpenAI)
-
- def test_thirdparty_gpt4_instantiation(self):
- config = Config("third_party", "gpt-4")
- model = create_model_instance(config)
- assert isinstance(model, ChatOpenAI)
-
- def test_invoke_fail(self):
- with pytest.raises(SystemExit):
- config = Config("sap", "gpt-35-turbo")
- vuln_id = "random"
- get_repository_url(llm_config=config, vuln_id=vuln_id)
diff --git a/prospector/requirements.txt b/prospector/requirements.txt
index dc864b5d7..0ca435446 100644
--- a/prospector/requirements.txt
+++ b/prospector/requirements.txt
@@ -4,9 +4,6 @@
#
# pip-compile --no-annotate --strip-extras
#
---extra-index-url https://int.repositories.cloud.sap/artifactory/api/pypi/deploy-releases-pypi/simple
---extra-index-url https://int.repositories.cloud.sap/artifactory/api/pypi/proxy-deploy-releases-hyperspace-pypi/simple
---trusted-host int.repositories.cloud.sap
aiohttp==3.9.5
aiosignal==1.3.1
diff --git a/prospector/util/config_parser.py b/prospector/util/config_parser.py
index 281b17bb3..aace08cbf 100644
--- a/prospector/util/config_parser.py
+++ b/prospector/util/config_parser.py
@@ -181,6 +181,8 @@ class ReportConfig:
class LLMServiceConfig:
type: str
model_name: str
+ use_llm_repository_url: bool
+ ai_core_sk: str
temperature: float = 0.0
@@ -194,7 +196,6 @@ class ConfigSchema:
use_nvd: bool = MISSING
use_backend: str = MISSING
backend: str = MISSING
- use_llm_repository_url: bool = MISSING
report: ReportConfig = MISSING
log_level: str = MISSING
git_cache: str = MISSING
@@ -226,7 +227,6 @@ def __init__(
fetch_references: bool,
use_backend: str,
backend: str,
- use_llm_repository_url: bool,
report: ReportConfig,
report_filename: str,
ping: bool,
@@ -237,7 +237,6 @@ def __init__(
):
self.vuln_id = vuln_id
self.repository = repository
- self.use_llm_repository_url = use_llm_repository_url
self.llm_service = llm_service
self.preprocess_only = preprocess_only
self.pub_date = pub_date
@@ -267,11 +266,13 @@ def get_configuration(argv):
sys.exit(
"No configuration file found, or error in configuration file. Check logs."
)
+    # --repository on the command line overrides config.yaml settings for LLM usage
+ if args.repository:
+ conf.llm_service.use_llm_repository_url = False
try:
config = Config(
vuln_id=args.vuln_id,
repository=args.repository,
- use_llm_repository_url=conf.use_llm_repository_url,
llm_service=conf.llm_service,
preprocess_only=args.preprocess_only or conf.preprocess_only,
pub_date=args.pub_date,
diff --git a/prospector/util/singleton.py b/prospector/util/singleton.py
new file mode 100644
index 000000000..cbee215f0
--- /dev/null
+++ b/prospector/util/singleton.py
@@ -0,0 +1,16 @@
+from log.logger import logger
+
+
+class Singleton(type):
+ """Singleton class to ensure that any class inheriting from this one can only be instantiated once."""
+
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+ else:
+ logger.info(
+ f"Cannot instantiate a Singleton twice. Returning already existing instance of class {cls}."
+ )
+ return cls._instances[cls]