diff --git a/.gitignore b/.gitignore index 96042fb68..0b56f38a4 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,7 @@ prospector/install_fastext.sh prospector/nvd.ipynb prospector/data/nvd.pkl prospector/data/nvd.csv +prospector/data_sources/reports .vscode/settings.json prospector/cov_html/* prospector/client/cli/cov_html/* @@ -51,6 +52,7 @@ prospector/.coverage **/cov_html prospector/cov_html .coverage +prospector/.venv prospector/prospector.code-workspace prospector/requests-cache.sqlite prospector/prospector-report.html diff --git a/prospector/README.md b/prospector/README.md index 991409403..f7729c7a8 100644 --- a/prospector/README.md +++ b/prospector/README.md @@ -5,18 +5,29 @@ currently under development: the instructions below are intended for development :exclamation: Please note that **Windows is not supported** while WSL and WSL2 are fine. -## Description +## Table of Contents + +1. [Description](#description) +2. [Quick Setup & Run](#setup--run) +3. [Development Setup](#development-setup) +4. [Contributing](#contributing) +5. [History](#history) + +## 📖 Description Prospector is a tool to reduce the effort needed to find security fixes for *known* vulnerabilities in open source software repositories. Given an advisory expressed in natural language, Prospector processes the commits found in the target source code repository, ranks them based on a set of predefined rules, and produces a report that the user can inspect to determine which commits to retain as the actual fix. -## Setup & Run +## ⚡️ Quick Setup & Run + +Prerequisites: -:warning: The tool requires Docker and Docker-compose, as it employes Docker containers for certain functionalities. Make sure you have Docker installed and running before proceeding with the setup and usage of Prospector. +* Docker (make sure you have Docker installed and running before proceeding with the setup) +* Docker-compose -To quickly set up Prospector: +To quickly set up Prospector, follow these steps. 
This will run Prospector in its containerised version. If you wish to debug or run Prospector's components individually, follow the steps below at [Development Setup](#development-setup). 1. Clone the project KB repository ``` @@ -44,7 +55,52 @@ To quickly set up Prospector: By default, Prospector saves the results in a HTML file named *prospector-report.html*. Open this file in a web browser to view what Prospector was able to find! -## Development Setup +### 🤖 LLM Support + +To use Prospector with LLM support, set the `use_llm_<...>` parameters in `config.yaml`. Additionally, you must specify required parameters for API access to the LLM. These parameters can vary depending on your choice of provider, please follow what fits your needs: + +
Use SAP AI CORE SDK + +You will need the following parameters in `config.yaml`: + +```yaml +llm_service: + type: sap + model_name: <model_name> +``` + +`<model_name>` refers to the model names available in the Generative AI Hub in SAP AI Core. [Here](https://github.tools.sap/I343697/generative-ai-hub-readme#1-supported-models) you can find an overview of available models. + +In `.env`, you must set the deployment URL as an environment variable following this naming convention (for example, `GPT_4_URL` for the model `gpt-4`): +```yaml +<MODEL_NAME>_URL +``` + +
+ +
Use personal third party provider + +Implemented third party providers are **OpenAI**, **Google** and **Mistral**. + +1. You will need the following parameters in `config.yaml`: + ```yaml + llm_service: + type: third_party + model_name: <model_name> + ``` + + `<model_name>` refers to the model names available, for example `gpt-4` for OpenAI. You can find a list of available models here: + 1. [OpenAI](https://platform.openai.com/docs/models) + 2. [Google](https://ai.google.dev/gemini-api/docs/models/gemini) + 3. [Mistral](https://docs.mistral.ai/getting-started/models/) + +2. Make sure to add your provider's API key to your `.env` file as `[OPENAI|GOOGLE|MISTRAL]_API_KEY`. + +
+ +## 👩‍💻 Development Setup + +Following these steps allows you to run Prospector's components individually: [Backend database and worker containers](#starting-the-backend-database-and-the-job-workers), [RESTful Server](#starting-the-restful-server) for API endpoints, [Prospector CLI](#running-the-cli-version) and [Tests](#testing). Prerequisites: @@ -53,6 +109,8 @@ Prerequisites: * gcc g++ libffi-dev python3-dev libpq-dev * Docker & Docker-compose +### General + You can setup everything and install the dependencies by running: ``` make setup @@ -81,11 +139,13 @@ your editor so that autoformatting is enforced "on save". The pre-commit hook en black is run prior to committing anyway, but the auto-formatting might save you some time and avoid frustration. -If you use VSCode, this can be achieved by pasting these lines in your configuration file: +If you use VSCode, this can be achieved by installing the Black Formatter extension and pasting these lines in your configuration file: -``` - "python.formatting.provider": "black", - "editor.formatOnSave": true, +```json + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + } ``` ### Starting the backend database and the job workers @@ -94,17 +154,23 @@ If you run the client without running the backend you will get a warning and hav You can then start the necessary containers with the following command: -`make docker-setup` +```bash +make docker-setup +``` This also starts a convenient DB administration tool at http://localhost:8080 If you wish to cleanup docker to run a fresh version of the backend you can run: -`make docker-clean` +```bash +make docker-clean +``` ### Starting the RESTful server -`uvicorn api.main:app --reload` +```bash +uvicorn service.main:app --reload +``` Note, that it requires `POSTGRES_USER`, `POSTGRES_HOST`, `POSTGRES_PORT`, `POSTGRES_DBNAME` to be set in the .env file. 
@@ -113,7 +179,9 @@ You might also want to take a look at `http://127.0.0.1:8000/docs`. *Alternatively*, you can execute the RESTful server explicitly with: -`python api/main.py` +```bash +python api/main.py +``` which is equivalent but more convenient for debugging. @@ -127,11 +195,13 @@ Prospector makes use of `pytest`. :exclamation: **NOTE:** before using it please make sure to have running instances of the backend and the database. +## 🤝 Contributing + If you find a bug, please open an issue. If you can also fix the bug, please create a pull request (make sure it includes a test case that passes with your correction but fails without it) -## History +## 🕰️ History The high-level structure of Prospector follows the approach of its predecessor FixFinder, which is described in: diff --git a/prospector/cli/main.py b/prospector/cli/main.py index f29a4a59c..95b5ef723 100644 --- a/prospector/cli/main.py +++ b/prospector/cli/main.py @@ -7,6 +7,8 @@ from dotenv import load_dotenv +import llm.operations as llm +from llm.model_instantiation import create_model_instance from util.http import ping_backend path_root = os.getcwd() @@ -32,10 +34,12 @@ def main(argv): # noqa: C901 with ConsoleWriter("Initialization") as console: config = get_configuration(argv) if not config: - logger.error("No configuration file found. Cannot proceed.") + logger.error( + "No configuration file found, or error in configuration file. Cannot proceed." + ) console.print( - "No configuration file found.", + "No configuration file found, or error in configuration file. Check logs.", status=MessageStatus.ERROR, ) return @@ -51,6 +55,20 @@ def main(argv): # noqa: C901 ) return + # instantiate LLM model if set in config.yaml + if config.llm_service: + model = create_model_instance(llm_config=config.llm_service) + + if not config.repository and not config.use_llm_repository_url: + logger.error( + "Either provide the repository URL or allow LLM usage to obtain it." 
+ ) + console.print( + "Either provide the repository URL or allow LLM usage to obtain it.", + status=MessageStatus.ERROR, + ) + sys.exit(1) + # if config.ping: # return ping_backend(backend, get_level() < logging.INFO) @@ -63,6 +81,9 @@ def main(argv): # noqa: C901 logger.debug("Vulnerability ID: " + config.vuln_id) + if not config.repository: + config.repository = llm.get_repository_url(model=model, vuln_id=config.vuln_id) + results, advisory_record = prospector( vulnerability_id=config.vuln_id, repository_url=config.repository, @@ -88,7 +109,7 @@ def main(argv): # noqa: C901 ) execution_time = execution_statistics["core"]["execution time"][0] - ConsoleWriter.print(f"Execution time: {execution_time:.3f}s") + ConsoleWriter.print(f"Execution time: {execution_time:.3f}s\n") return diff --git a/prospector/config-sample.yaml b/prospector/config-sample.yaml index 92feb6596..7208bc3dd 100644 --- a/prospector/config-sample.yaml +++ b/prospector/config-sample.yaml @@ -1,5 +1,3 @@ - - # Wheter to preprocess only the repository's commits or fully run prospector preprocess_only: False @@ -12,13 +10,12 @@ fetch_references: False use_nvd: True # The NVD API token -nvd_token: Null +# nvd_token: # Wheter to use a backend or not: "always", "never", "optional" use_backend: optional -# Optional backend info to save/use already preprocessed data -#backend: http://backend:8000 +# Backend address; when in containerised version, use http://backend:8000, otherwise http://localhost:8000 backend: http://localhost:8000 database: @@ -30,6 +27,14 @@ database: redis_url: redis://redis:6379/0 +# LLM Usage (check README for help) +llm_service: + type: sap # use "sap" or "third_party" + model_name: gpt-4-turbo + # temperature: 0.0 # optional, default is 0.0 + +use_llm_repository_url: True # whether to use LLM's to obtain the repository URL + # Report file format: "html", "json", "console" or "all" # and the file name report: @@ -43,4 +48,4 @@ log_level: INFO git_cache: /tmp/gitcache # The GitHub 
API token -github_token: Null +# github_token: diff --git a/prospector/core/prospector.py b/prospector/core/prospector.py index d132eb9c0..a95603aae 100644 --- a/prospector/core/prospector.py +++ b/prospector/core/prospector.py @@ -1,10 +1,12 @@ # flake8: noqa import logging +import os import re import sys import time from typing import Dict, List, Set, Tuple +from urllib.parse import urlparse import requests from tqdm import tqdm @@ -37,6 +39,9 @@ MAX_CANDIDATES = 2000 DEFAULT_BACKEND = "http://localhost:8000" +USE_BACKEND_ALWAYS = "always" +USE_BACKEND_OPTIONAL = "optional" +USE_BACKEND_NEVER = "never" core_statistics = execution_statistics.sub_collection("core") @@ -57,7 +62,7 @@ def prospector( # noqa: C901 use_nvd: bool = True, nvd_rest_endpoint: str = "", backend_address: str = DEFAULT_BACKEND, - use_backend: str = "always", + use_backend: str = USE_BACKEND_ALWAYS, git_cache: str = "/tmp/git_cache", limit_candidates: int = MAX_CANDIDATES, rules: List[str] = ["ALL"], @@ -145,7 +150,7 @@ def prospector( # noqa: C901 ) as timer: with ConsoleWriter("\nProcessing commits") as writer: try: - if use_backend != "never": + if use_backend != USE_BACKEND_NEVER: missing, preprocessed_commits = retrieve_preprocessed_commits( repository_url, backend_address, @@ -156,10 +161,13 @@ def prospector( # noqa: C901 "Backend not reachable", exc_info=get_level() < logging.WARNING, ) - if use_backend == "always": - print("Backend not reachable: aborting") + if use_backend == USE_BACKEND_ALWAYS: + if not is_correct_backend_url(backend_address): + print( + "The backend address should be 'backend:8000' when running the containerised version of Prospector, and 'localhost:8000' otherwise: Aborting." 
+ ) sys.exit(1) - print("Backend not reachable: continuing") + print("Backend not reachable: Continuing.") if "missing" not in locals(): missing = list(candidates.values()) @@ -194,7 +202,7 @@ def prospector( # noqa: C901 payload = [c.to_dict() for c in preprocessed_commits] - if len(payload) > 0 and use_backend != "never" and len(missing) > 0: + if len(payload) > 0 and use_backend != USE_BACKEND_NEVER and len(missing) > 0: save_preprocessed_commits(backend_address, payload) else: logger.warning("Preprocessed commits are not being sent to backend") @@ -227,7 +235,7 @@ def preprocess_commits(commits: List[RawCommit], timer: ExecutionTimer) -> List[ def filter(commits: Dict[str, RawCommit]) -> Dict[str, RawCommit]: - with ConsoleWriter("\nCandidate filtering\n") as console: + with ConsoleWriter("\nCandidate filtering") as console: commits, rejected = filter_commits(commits) if rejected > 0: console.print(f"Dropped {rejected} candidates") @@ -420,6 +428,29 @@ def get_commits_no_tags(repository: Git, commit_ids: List[str]): return candidates +def is_correct_backend_url(backend_url: str) -> bool: + """Returns True if the backend URL set in the config file matches the way prospector is run. 
Returns False if + - Prospector is run containerised and backend_url is not 'backend:8000' + - Prospector is run locally and backend_url is not 'localhost:8000' + """ + parsed_config_url = urlparse(backend_url) + parsed_default_url = urlparse(DEFAULT_BACKEND) + + if parsed_config_url.port != 8000: + return False + + in_container = os.environ.get("IN_CONTAINER", "") == "1" + + if in_container: + if parsed_config_url.hostname != "backend": + return False + else: + if parsed_config_url.hostname != parsed_default_url.hostname: + return False + + return True + + # def prospector_find_twins( # advisory_record: AdvisoryRecord, # repository: Git, diff --git a/prospector/datamodel/nlp.py b/prospector/datamodel/nlp.py index 1c5ed76e9..150f4203f 100644 --- a/prospector/datamodel/nlp.py +++ b/prospector/datamodel/nlp.py @@ -139,23 +139,24 @@ def extract_ghissue_references(repository: str, text: str) -> Dict[str, str]: id = result.group(1) url = f"{repository}/issues/{id}" content = fetch_url(url=url, extract_text=False) - gh_ref_data = content.find_all( - attrs={ - "class": ["comment-body", "markdown-title"], - }, - recursive=False, - ) - # TODO: when an issue/pr is referenced somewhere, the page contains also the "message" of that reference (e.g. a commit). This may lead to unwanted detection of certain rules. - gh_ref_data.extend( - content.find_all( + if content is not None: + gh_ref_data = content.find_all( attrs={ - "id": re.compile(r"ref-issue|ref-pullrequest"), - } + "class": ["comment-body", "markdown-title"], + }, + recursive=False, + ) + # TODO: when an issue/pr is referenced somewhere, the page contains also the "message" of that reference (e.g. a commit). This may lead to unwanted detection of certain rules. 
+ gh_ref_data.extend( + content.find_all( + attrs={ + "id": re.compile(r"ref-issue|ref-pullrequest"), + } + ) + ) + refs[id] = " ".join( + [" ".join(block.get_text().split()) for block in gh_ref_data] ) - ) - refs[id] = " ".join( - [" ".join(block.get_text().split()) for block in gh_ref_data] - ) return refs diff --git a/prospector/docker/cli/Dockerfile b/prospector/docker/cli/Dockerfile index 50df300d2..058e4bcdf 100644 --- a/prospector/docker/cli/Dockerfile +++ b/prospector/docker/cli/Dockerfile @@ -21,4 +21,7 @@ WORKDIR /clirun VOLUME ["/clirun"] ENV PYTHONPATH "${PYTHONPATH}:/clirun" +# check if Prospector is running containerised +ENV IN_CONTAINER=1 + ENTRYPOINT [ "python","cli/main.py" ] diff --git a/prospector/llm/model_instantiation.py b/prospector/llm/model_instantiation.py new file mode 100644 index 000000000..2ca1560f1 --- /dev/null +++ b/prospector/llm/model_instantiation.py @@ -0,0 +1,116 @@ +from typing import Dict + +from dotenv import dotenv_values +from langchain_core.language_models.llms import LLM +from langchain_google_vertexai import ChatVertexAI +from langchain_mistralai import ChatMistralAI +from langchain_openai import ChatOpenAI + +from llm.models import Gemini, Mistral, OpenAI +from log.logger import logger + + +class ModelDef: + def __init__(self, access_info: str, _class: LLM): + self.access_info = ( + access_info # either deployment_url (for SAP) or API key (for Third Party) + ) + self._class = _class + + +env: Dict[str, str | None] = dotenv_values() + +SAP_MAPPING = { + "gpt-35-turbo": ModelDef(env.get("GPT_35_TURBO_URL", None), OpenAI), + "gpt-35-turbo-16k": ModelDef(env.get("GPT_35_TURBO_16K_URL", None), OpenAI), + "gpt-35-turbo-0125": ModelDef(env.get("GPT_35_TURBO_0125_URL", None), OpenAI), + "gpt-4": ModelDef(env.get("GPT_4_URL", None), OpenAI), + "gpt-4-32k": ModelDef(env.get("GPT_4_32K_URL", None), OpenAI), + # "gpt-4-turbo": env.get("GPT_4_TURBO_URL", None), # currently TBD: 
https://github.tools.sap/I343697/generative-ai-hub-readme + # "gpt-4o": env.get("GPT_4O_URL", None), # currently TBD: https://github.tools.sap/I343697/generative-ai-hub-readme + "gemini-1.0-pro": ModelDef(env.get("GEMINI_1_0_PRO_URL", None), Gemini), + "mistralai--mixtral-8x7b-instruct-v01": ModelDef( + env.get("MISTRALAI_MIXTRAL_8X7B_INSTRUCT_V01", None), Mistral + ), +} + +THIRD_PARTY_MAPPING = { + "gpt-4": ModelDef(env.get("OPENAI_API_KEY", None), ChatOpenAI), + "gpt-3.5-turbo": ModelDef(env.get("OPENAI_API_KEY", None), ChatOpenAI), + "gemini-pro": ModelDef(env.get("GOOGLE_API_KEY", None), ChatVertexAI), + "mistral-large-latest": ModelDef(env.get("MISTRAL_API_KEY", None), ChatMistralAI), +} + + +def create_model_instance(llm_config) -> LLM: + """Creates and returns the model object given the user's configuration. + + Args: + llm_config (dict): A dictionary containing the configuration for the LLM. Expected keys are: + - 'type' (str): Method for accessing the LLM API ('sap' for SAP's AI Core, 'third_party' for + external providers). + - 'model_name' (str): Which model to use, e.g. gpt-4. + - 'temperature' (Optional(float)): The temperature for the model, default 0.0. + + Returns: + LLM: An instance of the specified LLM model. 
+ """ + + def create_sap_provider(model_name: str, temperature: float): + d = SAP_MAPPING.get(model_name, None) + + if d is None: + raise ValueError(f"Model '{model_name}' is not available.") + + model = d._class( + model_name=model_name, + deployment_url=d.access_info, + temperature=temperature, + ) + + return model + + def create_third_party_provider(model_name: str, temperature: float): + # obtain definition from main mapping + d = THIRD_PARTY_MAPPING.get(model_name, None) + + if d is None: + logger.error(f"Model '{model_name}' is not available.") + raise ValueError(f"Model '{model_name}' is not available.") + + model = d._class( + model=model_name, + api_key=d.access_info, + temperature=temperature, + ) + + return model + + if llm_config is None: + raise ValueError( + "When using LLM support, please add necessary parameters to configuration file." + ) + + # LLM Instantiation + try: + match llm_config.type: + case "sap": + model = create_sap_provider( + llm_config.model_name, llm_config.temperature + ) + case "third_party": + model = create_third_party_provider( + llm_config.model_name, llm_config.temperature + ) + case _: + logger.error( + f"Invalid LLM type specified, '{llm_config.type}' is not available." + ) + raise ValueError( + f"Invalid LLM type specified, '{llm_config.type}' is not available." 
+ ) + except Exception as e: + logger.error(f"Problem when initialising model: {e}") + raise ValueError(f"Problem when initialising model: {e}") + + return model diff --git a/prospector/llm/models.py b/prospector/llm/models.py new file mode 100644 index 000000000..cb3fb0d81 --- /dev/null +++ b/prospector/llm/models.py @@ -0,0 +1,193 @@ +import json +from typing import Any, List, Mapping, Optional + +import requests +from dotenv import dotenv_values +from langchain_core.language_models.llms import LLM + +from log.logger import logger + + +class SAPProvider(LLM): + model_name: str + deployment_url: str + temperature: float + + @property + def _llm_type(self) -> str: + return "custom" + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return { + "model_name": self.model_name, + } + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: + """Run the LLM on the given input. + + Override this method to implement the LLM logic. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + If stop tokens are not supported consider raising NotImplementedError. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output as a string. Actual completions SHOULD NOT include the prompt. + """ + if self.deployment_url is None: + raise ValueError( + "Deployment URL not set. Maybe you forgot to create the environment variable." 
+ ) + if stop is not None: + raise ValueError("stop kwargs are not permitted.") + return "" + + +class OpenAI(SAPProvider): + def _call( + self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any + ) -> str: + # Call super() to make sure model_name is valid + super()._call(prompt, stop, **kwargs) + # Model specific request data + endpoint = f"{self.deployment_url}/chat/completions?api-version=2023-05-15" + headers = get_headers() + data = { + "messages": [ + { + "role": "user", + "content": f"{prompt}", + } + ], + "temperature": self.temperature, + } + + response = requests.post(endpoint, headers=headers, json=data) + + if not response.status_code == 200: + logger.error( + f"Invalid response from AI Core API with error code {response.status_code}" + ) + raise Exception("Invalid response from AI Core API.") + + return self.parse(response.json()) + + def parse(self, message) -> str: + """Parse the returned JSON object from OpenAI.""" + return message["choices"][0]["message"]["content"] + + +class Gemini(SAPProvider): + def _call( + self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any + ) -> str: + # Call super() to make sure model_name is valid + super()._call(prompt, stop, **kwargs) + # Model specific request data + endpoint = f"{self.deployment_url}/models/{self.model_name}:generateContent" + headers = get_headers() + data = { + "generation_config": { + "maxOutputTokens": 1000, + "temperature": self.temperature, + }, + "contents": [{"role": "user", "parts": [{"text": prompt}]}], + "safetySettings": [ + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE", + }, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + ], + } + + response = requests.post(endpoint, headers=headers, json=data) + + if not response.status_code == 200: + logger.error( 
+ f"Invalid response from AI Core API with error code {response.status_code}" + ) + raise Exception("Invalid response from AI Core API.") + + return self.parse(response.json()) + + def parse(self, message) -> str: + """Parse the returned JSON object from OpenAI.""" + return message["candidates"][0]["content"]["parts"][0]["text"] + + +class Mistral(SAPProvider): + def _call( + self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any + ) -> str: + # Call super() to make sure model_name is valid + super()._call(prompt, stop, **kwargs) + # Model specific request data + endpoint = f"{self.deployment_url}/chat/completions" + headers = get_headers() + data = { + "model": "mistralai--mixtral-8x7b-instruct-v01", + "max_tokens": 100, + "temperature": self.temperature, + "messages": [{"role": "user", "content": prompt}], + } + + response = requests.post(endpoint, headers=headers, json=data) + + if not response.status_code == 200: + logger.error( + f"Invalid response from AI Core API with error code {response.status_code}" + ) + raise Exception("Invalid response from AI Core API.") + + return self.parse(response.json()) + + def parse(self, message) -> str: + """Parse the returned JSON object from OpenAI.""" + return message["choices"][0]["message"]["content"] + + +def get_headers(): + """Generate the request headers to use SAP AI Core. This method generates the authentication token and returns a Dict with headers. + + Returns: + The headers object needed to send requests to the SAP AI Core. 
+ """ + with open(dotenv_values()["AI_CORE_KEY_FILEPATH"]) as f: + sk = json.load(f) + + auth_url = f"{sk['url']}/oauth/token" + client_id = sk["clientid"] + client_secret = sk["clientsecret"] + # api_base_url = f"{sk['serviceurls']['AI_API_URL']}/v2" + + response = requests.post( + auth_url, + data={"grant_type": "client_credentials"}, + auth=(client_id, client_secret), + timeout=8000, + ) + + headers = { + "AI-Resource-Group": "default", + "Content-Type": "application/json", + "Authorization": f"Bearer {response.json()['access_token']}", + } + return headers diff --git a/prospector/llm/operations.py b/prospector/llm/operations.py new file mode 100644 index 000000000..f157e8590 --- /dev/null +++ b/prospector/llm/operations.py @@ -0,0 +1,60 @@ +import sys +from typing import Dict + +import validators +from langchain_core.language_models.llms import LLM + +from cli.console import ConsoleWriter, MessageStatus +from datamodel.advisory import get_from_mitre +from llm.prompts import best_guess +from log.logger import logger + + +def get_repository_url(model: LLM, vuln_id: str): + """Ask an LLM to obtain the repository URL given the advisory description and references. + + Args: + model (LLM): The instantiated model (instantiated with create_model_instance()) + vuln_id: The ID of the advisory, e.g. CVE-2020-1925. + + Returns: + The repository URL as a string. + + Raises: + ValueError if advisory information cannot be obtained or there is an error in the model invocation. 
+ """ + with ConsoleWriter("Invoking LLM") as console: + details, _ = get_from_mitre(vuln_id) + if details is None: + logger.error("Error when getting advisory information from Mitre.") + console.print( + "Error when getting advisory information from Mitre.", + status=MessageStatus.ERROR, + ) + sys.exit(1) + + try: + chain = best_guess | model + + url = chain.invoke( + { + "description": details["descriptions"][0]["value"], + "references": details["references"], + } + ) + if not validators.url(url): + logger.error(f"LLM returned invalid URL: {url}") + console.print( + f"LLM returned invalid URL: {url}", + status=MessageStatus.ERROR, + ) + sys.exit(1) + except Exception as e: + logger.error(f"Prompt-model chain could not be invoked: {e}") + console.print( + "Prompt-model chain could not be invoked.", + status=MessageStatus.ERROR, + ) + sys.exit(1) + + return url diff --git a/prospector/llm/prompts.py b/prospector/llm/prompts.py new file mode 100644 index 000000000..57fd2444a --- /dev/null +++ b/prospector/llm/prompts.py @@ -0,0 +1,45 @@ +from langchain.prompts import FewShotPromptTemplate, PromptTemplate + +# example output for few-shot prompting +examples_without_num = [ + { + "cve_description": "Apache Olingo versions 4.0.0 to 4.7.0 provide the AsyncRequestWrapperImpl class which reads a URL from the Location header, and then sends a GET or DELETE request to this URL. It may allow to implement a SSRF attack. If an attacker tricks a client to connect to a malicious server, the server can make the client call any URL including internal resources which are not directly accessible by the attacker.", + "cve_references": "https://www.zerodayinitiative.com/advisories/ZDI-24-196/", + "result": "https://github.com/apache/olingo-odata4", + }, + { + "cve_description": "Open-source project Online Shopping System Advanced is vulnerable to Reflected Cross-Site Scripting (XSS). 
An attacker might trick somebody into using a crafted URL, which will cause a script to be run in user's browser.", + "cve_references": "https://cert.pl/en/posts/2024/05/CVE-2024-3579, https://cert.pl/posts/2024/05/CVE-2024-3579", + "result": "https://github.com/PuneethReddyHC/online-shopping-system-advanced", + }, + { + "cve_description": "The Hoppscotch Browser Extension is a browser extension for Hoppscotch, a community-driven end-to-end open-source API development ecosystem. Due to an oversight during a change made to the extension in the commit d4e8e4830326f46ba17acd1307977ecd32a85b58, a critical check for the origin list was missed and allowed for messages to be sent to the extension which the extension gladly processed and responded back with the results of, while this wasn't supposed to happen and be blocked by the origin not being present in the origin list.\n\nThis vulnerability exposes Hoppscotch Extension users to sites which call into Hoppscotch Extension APIs internally. This fundamentally allows any site running on the browser with the extension installed to bypass CORS restrictions if the user is running extensions with the given version. This security hole was patched in the commit 7e364b928ab722dc682d0fcad713a96cc38477d6 which was released along with the extension version `0.35`. As a workaround, Chrome users can use the Extensions Settings to disable the extension access to only the origins that you want. 
Firefox doesn't have an alternative to upgrading to a fixed version.", + "cve_references": "https://github.com/hoppscotch/hoppscotch-extension/commit/7e364b928ab722dc682d0fcad713a96cc38477d6, https://github.com/hoppscotch/hoppscotch-extension/commit/d4e8e4830326f46ba17acd1307977ecd32a85b58, https://github.com/hoppscotch/hoppscotch-extension/security/advisories/GHSA-jjh5-pvqx-gg5v, https://server.yadhu.in/poc/hoppscotch-poc.html", + "result": "https://github.com/hoppscotch/hoppscotch-extension", + }, +] + +# Formatter for the few-shot examples without CVE numbers +examples_prompt_without_num = PromptTemplate( + input_variables=["cve_references", "result"], + template=""" {cve_description} + {cve_references} + + {result} """, +) + +best_guess = FewShotPromptTemplate( + prefix="""You will be provided with the ID, description and references of a vulnerability advisory (CVE). Return nothing but the URL of the repository the given CVE is concerned with.'. + +Here are a few examples delimited with XML tags:""", + examples=examples_without_num, + example_prompt=examples_prompt_without_num, + suffix="""Here is the CVE information: + {description} + {references} + +If you cannot find the URL, return your best guess of what the repository URL could be. Use any hints (eg. the mention of GitHub or GitLab) in the CVE description and references. Return nothing but the URL. 
+""", + input_variables=["description", "references"], + metadata={"name": "best_guess"}, +) diff --git a/prospector/llm/test_llm.py b/prospector/llm/test_llm.py new file mode 100644 index 000000000..9ef091659 --- /dev/null +++ b/prospector/llm/test_llm.py @@ -0,0 +1,48 @@ +import pytest +import requests +from langchain_openai import ChatOpenAI + +from llm.models import Gemini, Mistral, OpenAI +from llm.operations import create_model_instance, get_repository_url + + +# Mock the llm_service configuration object +class Config: + type: str = None + model_name: str = None + + def __init__(self, type, model_name): + self.type = type + self.model_name = model_name + + +# Vulnerability ID +vuln_id = "CVE-2024-32480" + + +class TestModel: + def test_sap_gpt35_instantiation(self): + config = Config("sap", "gpt-35-turbo") + model = create_model_instance(config) + assert isinstance(model, OpenAI) + + def test_sap_gpt4_instantiation(self): + config = Config("sap", "gpt-4") + model = create_model_instance(config) + assert isinstance(model, OpenAI) + + def test_thirdparty_gpt35_instantiation(self): + config = Config("third_party", "gpt-3.5-turbo") + model = create_model_instance(config) + assert isinstance(model, ChatOpenAI) + + def test_thirdparty_gpt4_instantiation(self): + config = Config("third_party", "gpt-4") + model = create_model_instance(config) + assert isinstance(model, ChatOpenAI) + + def test_invoke_fail(self): + with pytest.raises(SystemExit): + config = Config("sap", "gpt-35-turbo") + vuln_id = "random" + get_repository_url(llm_config=config, vuln_id=vuln_id) diff --git a/prospector/pyproject.toml b/prospector/pyproject.toml index 3511de431..87fd3e89b 100644 --- a/prospector/pyproject.toml +++ b/prospector/pyproject.toml @@ -9,7 +9,8 @@ testpaths = [ "api", "filtering", "stats", - "util" + "util", + "llm", ] [tool.isort] diff --git a/prospector/requirements.in b/prospector/requirements.in index bf42e07cb..6d7d7f4b3 100644 --- a/prospector/requirements.in +++ 
b/prospector/requirements.in @@ -1,19 +1,27 @@ -beautifulsoup4==4.11.1 -colorama==0.4.6 -datasketch==1.5.8 -fastapi==0.85.1 -Jinja2==3.1.2 -pandas==1.5.1 -plac==1.3.5 -psycopg2==2.9.5 -pydantic==1.10.2 -pytest==7.2.0 -python-dotenv==0.21.0 -python_dateutil==2.8.2 -redis==4.3.4 -requests==2.28.1 + +beautifulsoup4 +colorama +datasketch +fastapi +Jinja2 +langchain +langchain_openai +langchain_google_vertexai +langchain_mistralai +langchain_community +omegaconf +pandas +plac +psycopg2 +pydantic +pytest +python_dateutil +python-dotenv +redis +requests requests_cache==0.9.6 -rq==1.11.1 -spacy==3.4.2 -tqdm==4.64.1 -uvicorn==0.19.0 +rq +spacy +tqdm +uvicorn +validators diff --git a/prospector/requirements.txt b/prospector/requirements.txt index 16385f316..dc864b5d7 100644 --- a/prospector/requirements.txt +++ b/prospector/requirements.txt @@ -1,76 +1,153 @@ # -# This file is autogenerated by pip-compile with python 3.10 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: # # pip-compile --no-annotate --strip-extras # -anyio==3.6.2 +--extra-index-url https://int.repositories.cloud.sap/artifactory/api/pypi/deploy-releases-pypi/simple +--extra-index-url https://int.repositories.cloud.sap/artifactory/api/pypi/proxy-deploy-releases-hyperspace-pypi/simple +--trusted-host int.repositories.cloud.sap + +aiohttp==3.9.5 +aiosignal==1.3.1 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.4.0 appdirs==1.4.4 -argparse==1.4.0 -async-timeout==4.0.2 -attrs==22.1.0 -beautifulsoup4==4.11.1 -blis==0.7.9 -catalogue==2.0.8 -cattrs==22.2.0 -certifi==2022.9.24 -charset-normalizer==2.1.1 -click==8.1.3 +async-timeout==4.0.3 +attrs==23.2.0 +beautifulsoup4==4.12.3 +blis==0.7.11 +cachetools==5.3.3 +catalogue==2.0.10 +cattrs==23.2.3 +certifi==2024.6.2 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpathlib==0.18.1 colorama==0.4.6 -confection==0.0.3 -cymem==2.0.7 -datasketch==1.5.8 -deprecated==1.2.13 -exceptiongroup==1.0.0rc9 
-fastapi==0.85.1 +confection==0.1.5 +cymem==2.0.8 +dataclasses-json==0.6.6 +datasketch==1.6.5 +distro==1.9.0 +dnspython==2.6.1 +docstring-parser==0.16 +email-validator==2.1.1 +exceptiongroup==1.2.1 +fastapi==0.111.0 +fastapi-cli==0.0.4 +filelock==3.14.0 +frozenlist==1.4.1 +fsspec==2024.6.0 +google-api-core==2.19.0 +google-auth==2.29.0 +google-cloud-aiplatform==1.53.0 +google-cloud-bigquery==3.24.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.12.3 +google-cloud-storage==2.16.0 +google-crc32c==1.5.0 +google-resumable-media==2.7.0 +googleapis-common-protos==1.63.1 +greenlet==3.0.3 +grpc-google-iam-v1==0.13.0 +grpcio==1.64.1 +grpcio-status==1.62.2 h11==0.14.0 -idna==3.4 -iniconfig==1.1.1 -jinja2==3.1.2 -langcodes==3.3.0 -markupsafe==2.1.1 -murmurhash==1.0.9 -numpy==1.23.4 -packaging==21.3 -pandas==1.5.1 -pathy==0.6.2 -plac==1.3.5 -pluggy==1.0.0 -preshed==3.0.8 -psycopg2==2.9.5 -pydantic==1.10.2 -pyparsing==3.0.9 -pytest==7.2.0 -python-dateutil==2.8.2 -python-dotenv==0.21.0 -pytz==2022.5 -redis==4.3.4 -requests==2.28.1 +httpcore==1.0.5 +httptools==0.6.1 +httpx==0.27.0 +httpx-sse==0.4.0 +huggingface-hub==0.23.3 +idna==3.7 +iniconfig==2.0.0 +jinja2==3.1.4 +jsonpatch==1.33 +jsonpointer==2.4 +langchain==0.2.2 +langchain-community==0.2.3 +langchain-core==0.2.4 +langchain-google-vertexai==1.0.5 +langchain-mistralai==0.1.8 +langchain-openai==0.1.8 +langchain-text-splitters==0.2.1 +langcodes==3.4.0 +langsmith==0.1.74 +language-data==1.2.0 +marisa-trie==1.2.0 +markdown-it-py==3.0.0 +markupsafe==2.1.5 +marshmallow==3.21.3 +mdurl==0.1.2 +multidict==6.0.5 +murmurhash==1.0.10 +mypy-extensions==1.0.0 +numpy==1.26.4 +omegaconf==2.3.0 +openai==1.31.1 +orjson==3.10.3 +packaging==23.2 +pandas==2.2.2 +plac==1.4.3 +pluggy==1.5.0 +preshed==3.0.9 +proto-plus==1.23.0 +protobuf==4.25.3 +psycopg2==2.9.9 +pyasn1==0.6.0 +pyasn1-modules==0.4.0 +pydantic==2.7.3 +pydantic-core==2.18.4 +pygments==2.18.0 +pytest==8.2.2 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 
+python-multipart==0.0.9 +pytz==2024.1 +pyyaml==6.0.1 +redis==5.0.5 +regex==2024.5.15 +requests==2.32.3 requests-cache==0.9.6 -rq==1.11.1 -scipy==1.9.3 +rich==13.7.1 +rq==1.16.2 +rsa==4.9 +scipy==1.13.1 +shapely==2.0.4 +shellingham==1.5.4 six==1.16.0 -smart-open==5.2.1 -sniffio==1.3.0 -soupsieve==2.3.2.post1 -spacy==3.4.2 -spacy-legacy==3.0.10 -spacy-loggers==1.0.3 -srsly==2.4.5 -starlette==0.20.4 -thinc==8.1.5 +smart-open==7.0.4 +sniffio==1.3.1 +soupsieve==2.5 +spacy==3.7.5 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +sqlalchemy==2.0.30 +srsly==2.4.8 +starlette==0.37.2 +tenacity==8.3.0 +thinc==8.2.4 +tiktoken==0.7.0 +tokenizers==0.19.1 tomli==2.0.1 -tqdm==4.64.1 -typer==0.4.2 -typing-extensions==4.4.0 +tqdm==4.66.4 +typer==0.12.3 +typing-extensions==4.12.1 +typing-inspect==0.9.0 +tzdata==2024.1 +ujson==5.10.0 url-normalize==1.4.3 -urllib3==1.26.12 -uvicorn==0.19.0 -validators==0.20.0 -wasabi==0.10.1 -wrapt==1.14.1 -python-multipart==0.0.5 -omegaconf==2.2.3 +urllib3==2.2.1 +uvicorn==0.30.1 +uvloop==0.19.0 +validators==0.28.3 +wasabi==1.1.3 +watchfiles==0.22.0 +weasel==0.4.1 +websockets==12.0 +wrapt==1.16.0 +yarl==1.9.4 # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/prospector/service/api/routers/endpoints.py b/prospector/service/api/routers/endpoints.py index d5cb0ff5c..9f446f86c 100644 --- a/prospector/service/api/routers/endpoints.py +++ b/prospector/service/api/routers/endpoints.py @@ -3,7 +3,6 @@ from datetime import datetime import redis -from api.rq_utils import get_all_jobs, queue from fastapi import APIRouter, FastAPI, Request from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates @@ -12,6 +11,7 @@ from starlette.responses import RedirectResponse from data_sources.nvd.job_creation import run_prospector +from service.api.rq_utils import get_all_jobs, queue from util.config_parser import parse_config_file # from core.report import generate_report diff --git 
a/prospector/service/api/routers/home.py b/prospector/service/api/routers/home.py index 46d80cdd6..020ea6469 100644 --- a/prospector/service/api/routers/home.py +++ b/prospector/service/api/routers/home.py @@ -1,14 +1,15 @@ -from api.rq_utils import queue, get_all_jobs -from fastapi import FastAPI, Request -from fastapi import APIRouter -from fastapi.templating import Jinja2Templates +import time + +import redis +from fastapi import APIRouter, FastAPI, Request from fastapi.responses import HTMLResponse -from starlette.responses import RedirectResponse -from util.config_parser import parse_config_file +from fastapi.templating import Jinja2Templates from rq import Connection, Queue from rq.job import Job -import redis -import time +from starlette.responses import RedirectResponse + +from service.api.rq_utils import get_all_jobs, queue +from util.config_parser import parse_config_file # from core.report import generate_report diff --git a/prospector/service/api/routers/jobs.py b/prospector/service/api/routers/jobs.py index 77f0f7667..0f759e485 100644 --- a/prospector/service/api/routers/jobs.py +++ b/prospector/service/api/routers/jobs.py @@ -5,9 +5,9 @@ from rq import Connection, Queue from rq.job import Job -from api.routers.nvd_feed_update import main from git.git import do_clone from log.logger import logger +from service.api.routers.nvd_feed_update import main from util.config_parser import parse_config_file config = parse_config_file() diff --git a/prospector/service/main.py b/prospector/service/main.py index 1756c2984..c33f41d1f 100644 --- a/prospector/service/main.py +++ b/prospector/service/main.py @@ -2,12 +2,13 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles -# from .dependencies import oauth2_scheme -from api.routers import jobs, nvd, preprocessed, users, endpoints, home from log.logger import logger + +# from 
.dependencies import oauth2_scheme +from service.api.routers import endpoints, home, jobs, nvd, preprocessed, users from util.config_parser import parse_config_file -from fastapi.staticfiles import StaticFiles api_metadata = [ {"name": "data", "description": "Operations with data used to train ML models."}, diff --git a/prospector/util/config_parser.py b/prospector/util/config_parser.py index 49b59e840..281b17bb3 100644 --- a/prospector/util/config_parser.py +++ b/prospector/util/config_parser.py @@ -1,15 +1,23 @@ import argparse import os import sys -from dataclasses import dataclass +from dataclasses import MISSING, dataclass +from typing import Optional from omegaconf import OmegaConf +from omegaconf.errors import ( + ConfigAttributeError, + ConfigKeyError, + ConfigTypeError, + MissingMandatoryValue, +) from log.logger import logger def parse_cli_args(args): parser = argparse.ArgumentParser(description="Prospector CLI") + parser.add_argument( "vuln_id", nargs="?", @@ -17,7 +25,9 @@ def parse_cli_args(args): help="ID of the vulnerability to analyze", ) - parser.add_argument("--repository", default="", type=str, help="Git repository url") + parser.add_argument( + "--repository", default=None, type=str, help="Git repository url" + ) parser.add_argument( "--preprocess-only", @@ -27,6 +37,7 @@ def parse_cli_args(args): parser.add_argument("--pub-date", type=str, help="Publication date of the advisory") + # Allow the user to manually supply advisory description parser.add_argument("--description", type=str, help="Advisory description") parser.add_argument( @@ -81,7 +92,6 @@ def parse_cli_args(args): parser.add_argument( "--use-backend", - default="always", choices=["always", "never", "optional"], type=str, help="Use the backend server", @@ -131,12 +141,72 @@ def parse_cli_args(args): def parse_config_file(filename: str = "config.yaml"): if os.path.isfile(filename): logger.info(f"Loading configuration from {filename}") + schema = OmegaConf.structured(ConfigSchema) 
config = OmegaConf.load(filename) - return config + try: + merged_config = OmegaConf.merge(schema, config) + return merged_config + except ConfigAttributeError as e: + logger.error(f"Attribute error in {filename}: {e}") + except ConfigKeyError as e: + logger.error(f"Key error in {filename}: {e}") + except ConfigTypeError as e: + logger.error(f"Type error in {filename}: {e}") + except Exception as e: + # General exception catch block for any other exceptions + logger.error(f"An unexpected error occurred when parsing config.yaml: {e}") + else: + logger.error("No configuration file found, cannot proceed.") + + +# Schema class for "database" configuration +@dataclass +class DatabaseConfig: + user: str + password: str + host: str + port: int + dbname: str + + +# Schema class for "report" configuration +@dataclass +class ReportConfig: + format: str + name: str + - return None +# Schema class for "llm_service" configuration +@dataclass +class LLMServiceConfig: + type: str + model_name: str + temperature: float = 0.0 + + +# Schema class for config.yaml parameters +@dataclass +class ConfigSchema: + redis_url: str = MISSING + preprocess_only: bool = MISSING + max_candidates: int = MISSING + fetch_references: bool = MISSING + use_nvd: bool = MISSING + use_backend: str = MISSING + backend: str = MISSING + use_llm_repository_url: bool = MISSING + report: ReportConfig = MISSING + log_level: str = MISSING + git_cache: str = MISSING + nvd_token: Optional[str] = None + database: DatabaseConfig = DatabaseConfig( + user="postgres", password="example", host="db", port=5432, dbname="postgres" + ) + llm_service: Optional[LLMServiceConfig] = None + github_token: Optional[str] = None +# Prospector's own Configuration object (combining args and config.yaml) @dataclass class Config: def __init__( @@ -154,17 +224,21 @@ def __init__( keywords: str, use_nvd: bool, fetch_references: bool, - backend: str, use_backend: str, - report: str, + backend: str, + use_llm_repository_url: bool, + report: 
ReportConfig, report_filename: str, ping: bool, log_level: str, git_cache: str, ignore_refs: bool, + llm_service: LLMServiceConfig, ): self.vuln_id = vuln_id self.repository = repository + self.use_llm_repository_url = use_llm_repository_url + self.llm_service = llm_service self.preprocess_only = preprocess_only self.pub_date = pub_date self.description = description @@ -190,27 +264,39 @@ def get_configuration(argv): args = parse_cli_args(argv) conf = parse_config_file(args.config) if conf is None: - sys.exit("No configuration file found") - return Config( - vuln_id=args.vuln_id, - repository=args.repository, - preprocess_only=args.preprocess_only or conf.preprocess_only, - pub_date=args.pub_date, - description=args.description, - modified_files=args.modified_files, - keywords=args.keywords, - max_candidates=args.max_candidates or conf.max_candidates, - # tag_interval=args.tag_interval, - version_interval=args.version_interval, - filter_extensions=args.filter_extensions, - use_nvd=args.use_nvd or conf.use_nvd, - fetch_references=args.fetch_references or conf.fetch_references, - backend=args.backend or conf.backend, - use_backend=args.use_backend or conf.use_backend, - report=args.report or conf.report.format, - report_filename=args.report_filename or conf.report.name, - ping=args.ping, - git_cache=conf.git_cache, - log_level=args.log_level or conf.log_level, - ignore_refs=args.ignore_refs, - ) + sys.exit( + "No configuration file found, or error in configuration file. Check logs." 
+ ) + try: + config = Config( + vuln_id=args.vuln_id, + repository=args.repository, + use_llm_repository_url=conf.use_llm_repository_url, + llm_service=conf.llm_service, + preprocess_only=args.preprocess_only or conf.preprocess_only, + pub_date=args.pub_date, + description=args.description, + modified_files=args.modified_files, + keywords=args.keywords, + max_candidates=args.max_candidates or conf.max_candidates, + # tag_interval=args.tag_interval, + version_interval=args.version_interval, + filter_extensions=args.filter_extensions, + use_nvd=args.use_nvd or conf.use_nvd, + fetch_references=args.fetch_references or conf.fetch_references, + backend=args.backend or conf.backend, + use_backend=args.use_backend or conf.use_backend, + report=args.report or conf.report.format, + report_filename=args.report_filename or conf.report.name, + ping=args.ping, + git_cache=conf.git_cache, + log_level=args.log_level or conf.log_level, + ignore_refs=args.ignore_refs, + ) + return config + except MissingMandatoryValue as e: + logger.error(e) + sys.exit(f"'{e.full_key}' is missing in {args.config}.") + except Exception as e: + logger.error(f"Error in {args.config}: {e}.") + sys.exit(f"Error in {args.config}. Check logs.") diff --git a/prospector/util/http.py b/prospector/util/http.py index cfdb78f7d..5f3678594 100644 --- a/prospector/util/http.py +++ b/prospector/util/http.py @@ -9,6 +9,20 @@ def fetch_url(url: str, params=None, extract_text=True) -> Union[str, BeautifulSoup]: + """ + Fetches the content of a web page located at the specified URL and optionally extracts text from it. + + Parameters: + - url (str): The URL of the web page to fetch. + - params (dict, optional): Optional parameters to be sent with the request (default: None). + - extract_text (bool, optional): Whether to extract text content from the HTML (default: True). + + Returns: + - Union[str, BeautifulSoup]: If `extract_text` is True, returns the text content of the web page as a string. 
+ If `extract_text` is False, returns the parsed HTML content as a BeautifulSoup object. + + If an exception occurs during the HTTP request, None is returned. + """ try: session = requests_cache.CachedSession("requests-cache", expire_after=604800) if params is None: @@ -17,7 +31,7 @@ def fetch_url(url: str, params=None, extract_text=True) -> Union[str, BeautifulS content = session.get(url, params=params).content except Exception: logger.debug(f"cannot retrieve url content: {url}", exc_info=True) - return "" + return None soup = BeautifulSoup(content, "html.parser") if extract_text: