From 0a099514c7f384c31f77a6169906b481a324ba3a Mon Sep 17 00:00:00 2001 From: I748376 Date: Tue, 11 Jun 2024 15:32:56 +0000 Subject: [PATCH] changes code structure now model gets instantiated and llm functions can be called with this model. This is because only one instantiation of the model is needed throughout the whole runtime of prospector --- prospector/cli/main.py | 10 ++++++---- prospector/llm/operations.py | 10 +++------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/prospector/cli/main.py b/prospector/cli/main.py index a51fcdb70..95b5ef723 100644 --- a/prospector/cli/main.py +++ b/prospector/cli/main.py @@ -8,6 +8,7 @@ from dotenv import load_dotenv import llm.operations as llm +from llm.model_instantiation import create_model_instance from util.http import ping_backend path_root = os.getcwd() @@ -54,6 +55,10 @@ def main(argv): # noqa: C901 ) return + # instantiate LLM model if set in config.yaml + if config.llm_service: + model = create_model_instance(llm_config=config.llm_service) + if not config.repository and not config.use_llm_repository_url: logger.error( "Either provide the repository URL or allow LLM usage to obtain it." @@ -76,11 +81,8 @@ def main(argv): # noqa: C901 logger.debug("Vulnerability ID: " + config.vuln_id) - # whether to use LLM support if not config.repository: - config.repository = llm.get_repository_url( - llm_config=config.llm_service, vuln_id=config.vuln_id - ) + config.repository = llm.get_repository_url(model=model, vuln_id=config.vuln_id) results, advisory_record = prospector( vulnerability_id=config.vuln_id, diff --git a/prospector/llm/operations.py b/prospector/llm/operations.py index 39a74ae27..f157e8590 100644 --- a/prospector/llm/operations.py +++ b/prospector/llm/operations.py @@ -2,22 +2,19 @@ from typing import Dict import validators +from langchain_core.language_models.llms import LLM from cli.console import ConsoleWriter, MessageStatus from datamodel.advisory import get_from_mitre -from llm.model_instantiation import create_model_instance from llm.prompts import best_guess from log.logger import logger -def get_repository_url(llm_config: Dict, vuln_id: str): +def get_repository_url(model: LLM, vuln_id: str): """Ask an LLM to obtain the repository URL given the advisory description and references. Args: - llm_config (dict): A dictionary containing the configuration for the LLM. Expected keys are: - - 'type' (str): Method for accessing the LLM API ('sap' for SAP's AI Core, 'third_party' for - external providers). - - 'model_name' (str): Which model to use, e.g. gpt-4. + model (LLM): The instantiated model (instantiated with create_model_instance()) vuln_id: The ID of the advisory, e.g. CVE-2020-1925. Returns: @@ -37,7 +34,6 @@ def get_repository_url(llm_config: Dict, vuln_id: str): sys.exit(1) try: - model = create_model_instance(llm_config=llm_config) chain = best_guess | model url = chain.invoke(