Skip to content

Commit

Permalink
changes code structure now model gets instantiated and llm functions …
Browse files Browse the repository at this point in the history
…can be called with this model. This is because only one instantiation of the model is needed throughout the whole runtime of prospector
  • Loading branch information
lauraschauer committed Jun 11, 2024
1 parent 32d7b97 commit 0a09951
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
10 changes: 6 additions & 4 deletions prospector/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dotenv import load_dotenv

import llm.operations as llm
from llm.model_instantiation import create_model_instance
from util.http import ping_backend

path_root = os.getcwd()
Expand Down Expand Up @@ -54,6 +55,10 @@ def main(argv): # noqa: C901
)
return

# instantiate LLM model if set in config.yaml
if config.llm_service:
model = create_model_instance(llm_config=config.llm_service)

if not config.repository and not config.use_llm_repository_url:
logger.error(
"Either provide the repository URL or allow LLM usage to obtain it."
Expand All @@ -76,11 +81,8 @@ def main(argv): # noqa: C901

logger.debug("Vulnerability ID: " + config.vuln_id)

# whether to use LLM support
if not config.repository:
config.repository = llm.get_repository_url(
llm_config=config.llm_service, vuln_id=config.vuln_id
)
config.repository = llm.get_repository_url(model=model, vuln_id=config.vuln_id)

results, advisory_record = prospector(
vulnerability_id=config.vuln_id,
Expand Down
10 changes: 3 additions & 7 deletions prospector/llm/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,19 @@
from typing import Dict

import validators
from langchain_core.language_models.llms import LLM

from cli.console import ConsoleWriter, MessageStatus
from datamodel.advisory import get_from_mitre
from llm.model_instantiation import create_model_instance
from llm.prompts import best_guess
from log.logger import logger


def get_repository_url(llm_config: Dict, vuln_id: str):
def get_repository_url(model: LLM, vuln_id: str):
"""Ask an LLM to obtain the repository URL given the advisory description and references.
Args:
llm_config (dict): A dictionary containing the configuration for the LLM. Expected keys are:
- 'type' (str): Method for accessing the LLM API ('sap' for SAP's AI Core, 'third_party' for
external providers).
- 'model_name' (str): Which model to use, e.g. gpt-4.
model (LLM): The instantiated model (instantiated with create_model_instance())
vuln_id: The ID of the advisory, e.g. CVE-2020-1925.
Returns:
Expand All @@ -37,7 +34,6 @@ def get_repository_url(llm_config: Dict, vuln_id: str):
sys.exit(1)

try:
model = create_model_instance(llm_config=llm_config)
chain = best_guess | model

url = chain.invoke(
Expand Down

0 comments on commit 0a09951

Please sign in to comment.