diff --git a/prospector/cli/main.py b/prospector/cli/main.py
index a51fcdb70..95b5ef723 100644
--- a/prospector/cli/main.py
+++ b/prospector/cli/main.py
@@ -8,6 +8,7 @@
 from dotenv import load_dotenv
 
 import llm.operations as llm
+from llm.model_instantiation import create_model_instance
 from util.http import ping_backend
 
 path_root = os.getcwd()
@@ -54,6 +55,11 @@ def main(argv):  # noqa: C901
         )
         return
 
+    # instantiate LLM model if set in config.yaml; default to None so the name is always bound
+    model = None
+    if config.llm_service:
+        model = create_model_instance(llm_config=config.llm_service)
+
     if not config.repository and not config.use_llm_repository_url:
         logger.error(
             "Either provide the repository URL or allow LLM usage to obtain it."
@@ -76,11 +82,8 @@
 
     logger.debug("Vulnerability ID: " + config.vuln_id)
 
-    # whether to use LLM support
    if not config.repository:
-        config.repository = llm.get_repository_url(
-            llm_config=config.llm_service, vuln_id=config.vuln_id
-        )
+        config.repository = llm.get_repository_url(model=model, vuln_id=config.vuln_id)
 
    results, advisory_record = prospector(
        vulnerability_id=config.vuln_id,
diff --git a/prospector/llm/operations.py b/prospector/llm/operations.py
index 39a74ae27..f157e8590 100644
--- a/prospector/llm/operations.py
+++ b/prospector/llm/operations.py
@@ -2,22 +2,19 @@
 from typing import Dict
 
 import validators
+from langchain_core.language_models.llms import LLM
 
 from cli.console import ConsoleWriter, MessageStatus
 from datamodel.advisory import get_from_mitre
-from llm.model_instantiation import create_model_instance
 from llm.prompts import best_guess
 from log.logger import logger
 
 
-def get_repository_url(llm_config: Dict, vuln_id: str):
+def get_repository_url(model: LLM, vuln_id: str):
     """Ask an LLM to obtain the repository URL given the advisory description and references.
 
     Args:
-        llm_config (dict): A dictionary containing the configuration for the LLM. Expected keys are:
-            - 'type' (str): Method for accessing the LLM API ('sap' for SAP's AI Core, 'third_party' for
-            external providers).
-            - 'model_name' (str): Which model to use, e.g. gpt-4.
+        model (LLM): The instantiated model (instantiated with create_model_instance())
         vuln_id: The ID of the advisory, e.g. CVE-2020-1925.
 
     Returns:
@@ -37,7 +34,6 @@ def get_repository_url(llm_config: Dict, vuln_id: str):
         sys.exit(1)
 
     try:
-        model = create_model_instance(llm_config=llm_config)
         chain = best_guess | model
 
         url = chain.invoke(