diff --git a/conda/environments/all_cuda-121_arch-x86_64.yaml b/conda/environments/all_cuda-121_arch-x86_64.yaml
index 3b310995fb..f320d3ac86 100644
--- a/conda/environments/all_cuda-121_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-121_arch-x86_64.yaml
@@ -13,7 +13,7 @@ dependencies:
 - appdirs
 - arxiv=1.4
 - automake
-- beautifulsoup4
+- beautifulsoup4=4.12
 - benchmark=1.8.3
 - boost-cpp=1.84
 - boto3
@@ -69,7 +69,7 @@ dependencies:
 - numexpr
 - numpydoc=1.5
 - nvtabular=23.08.00
-- onnx
+- onnx=1.15
 - openai=1.13
 - papermill=2.4.0
 - pip
@@ -95,11 +95,11 @@ dependencies:
 - rdma-core>=48
 - requests
 - requests-cache=1.1
-- requests-toolbelt
+- requests-toolbelt=1.0
 - s3fs=2023.12.2
 - scikit-build=0.17.6
 - scikit-learn=1.3.2
-- sentence-transformers
+- sentence-transformers=2.7
 - sphinx
 - sphinx_rtd_theme
 - sqlalchemy
@@ -117,15 +117,18 @@ dependencies:
 - pip:
   - --find-links https://data.dgl.ai/wheels-test/repo.html
   - --find-links https://data.dgl.ai/wheels/cu121/repo.html
+  - PyMuPDF==1.23.*
   - PyMuPDF==1.23.21
   - databricks-cli < 0.100
   - databricks-connect
   - dgl==2.0.0
   - dglgo
-  - google-search-results==2.4
-  - langchain==0.1.9
+  - faiss-gpu==1.7.*
+  - google-search-results==2.4.*
+  - langchain-nvidia-ai-endpoints==0.0.11
+  - langchain==0.1.16
   - milvus==2.3.5
-  - nemollm
+  - nemollm==0.3.5
   - pymilvus==2.3.6
   - pytest-kafka==0.6.0
name: all_cuda-121_arch-x86_64
diff --git a/conda/environments/dev_cuda-121_arch-x86_64.yaml b/conda/environments/dev_cuda-121_arch-x86_64.yaml
index 23ff2c707e..55ada60795 100644
--- a/conda/environments/dev_cuda-121_arch-x86_64.yaml
+++ b/conda/environments/dev_cuda-121_arch-x86_64.yaml
@@ -11,7 +11,6 @@ channels:
 dependencies:
 - appdirs
 - automake
-- beautifulsoup4
 - benchmark=1.8.3
 - boost-cpp=1.84
 - breathe=4.35.0
@@ -78,7 +77,6 @@ dependencies:
 - rdma-core>=48
 - requests
 - requests-cache=1.1
-- requests-toolbelt
 - scikit-build=0.17.6
 - scikit-learn=1.3.2
 - sphinx
diff --git a/conda/environments/examples_cuda-121_arch-x86_64.yaml b/conda/environments/examples_cuda-121_arch-x86_64.yaml
index 11d5e535ce..cda5d37df4 100644
--- a/conda/environments/examples_cuda-121_arch-x86_64.yaml
+++ b/conda/environments/examples_cuda-121_arch-x86_64.yaml
@@ -12,7 +12,7 @@ dependencies:
 - anyio>=3.7
 - appdirs
 - arxiv=1.4
-- beautifulsoup4
+- beautifulsoup4=4.12
 - boto3
 - click >=8
 - cuml=24.02.*
@@ -35,7 +35,7 @@ dependencies:
 - numexpr
 - numpydoc=1.5
 - nvtabular=23.08.00
-- onnx
+- onnx=1.15
 - openai=1.13
 - papermill=2.4.0
 - pip
@@ -48,10 +48,10 @@ dependencies:
 - pytorch=*=*cuda*
 - requests
 - requests-cache=1.1
-- requests-toolbelt
+- requests-toolbelt=1.0
 - s3fs=2023.12.2
 - scikit-learn=1.3.2
-- sentence-transformers
+- sentence-transformers=2.7
 - sqlalchemy
 - tqdm=4
 - transformers=4.36.2
@@ -61,14 +61,16 @@ dependencies:
 - pip:
   - --find-links https://data.dgl.ai/wheels-test/repo.html
   - --find-links https://data.dgl.ai/wheels/cu121/repo.html
-  - PyMuPDF==1.23.21
+  - PyMuPDF==1.23.*
   - databricks-cli < 0.100
   - databricks-connect
   - dgl==2.0.0
   - dglgo
-  - google-search-results==2.4
-  - langchain==0.1.9
+  - faiss-gpu==1.7.*
+  - google-search-results==2.4.*
+  - langchain-nvidia-ai-endpoints==0.0.11
+  - langchain==0.1.16
   - milvus==2.3.5
-  - nemollm
+  - nemollm==0.3.5
   - pymilvus==2.3.6
name: examples_cuda-121_arch-x86_64
diff --git a/conda/environments/runtime_cuda-121_arch-x86_64.yaml b/conda/environments/runtime_cuda-121_arch-x86_64.yaml
index 80f6f995d2..e6b76b43aa 100644
--- a/conda/environments/runtime_cuda-121_arch-x86_64.yaml
+++ b/conda/environments/runtime_cuda-121_arch-x86_64.yaml
@@ -10,7 +10,6 @@ channels:
 - pytorch
 dependencies:
 - appdirs
-- beautifulsoup4
 - click >=8
 - datacompy=0.10
 - dill=0.3.7
@@ -30,7 +29,6 @@ dependencies:
 - pytorch=*=*cuda*
 - requests
 - requests-cache=1.1
-- requests-toolbelt
 - scikit-learn=1.3.2
 - sqlalchemy
 - tqdm=4
diff --git a/dependencies.yaml b/dependencies.yaml
index 7f1f9145ef..8d41be4f50 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -32,10 +32,7 @@ files:
       - docs
      - example-dfp-prod
       - example-gnn
-      - example-llm-agents
-      - example-llm-completion
-      - example-llm-rag
-      - example-llm-vdb-upload
+      - example-llms
       - python
       - runtime
       - test_python_morpheus
@@ -86,10 +83,7 @@ files:
      - development
       - example-dfp-prod
       - example-gnn
-      - example-llm-agents
-      - example-llm-completion
-      - example-llm-rag
-      - example-llm-vdb-upload
+      - example-llms
       - python
       - runtime
       - test_python_morpheus
@@ -107,10 +101,7 @@ files:
       - docs
       - example-dfp-prod
       - example-gnn
-      - example-llm-agents
-      - example-llm-completion
-      - example-llm-rag
-      - example-llm-vdb-upload
+      - example-llms
       - python
       - runtime
@@ -132,10 +123,7 @@ files:
       - cve-mitigation
       - example-dfp-prod
       - example-gnn
-      - example-llm-agents
-      - example-llm-completion
-      - example-llm-rag
-      - example-llm-vdb-upload
+      - example-llms
       - python
       - runtime
@@ -249,7 +237,6 @@ dependencies:
       - &dill dill=0.3.7
       - &scikit-learn scikit-learn=1.3.2
       - appdirs
-      - beautifulsoup4
       - datacompy=0.10
       - elasticsearch==8.9.0
       - feedparser=6.0.10
@@ -264,7 +251,6 @@ dependencies:
       - pytorch=*=*cuda*
       - requests
       - requests-cache=1.1
-      - requests-toolbelt # Transitive dep needed by nemollm, specified here to ensure we get a compatible version
      - sqlalchemy
       - tqdm=4
       - typing_utils=0.1
@@ -318,55 +304,32 @@ dependencies:
           - dgl==2.0.0
           - dglgo
 
-  example-llm-agents:
+  example-llms:
     common:
       - output_types: [conda]
         packages:
-          - &grpcio-status grpcio-status==1.59
           - &transformers transformers=4.36.2 # newer versions are incompatible with our pinned version of huggingface_hub
-          - huggingface_hub=0.20.2 # work-around for https://github.com/UKPLab/sentence-transformers/issues/1762
-          - numexpr
-          - sentence-transformers
-          - pip
-          - pip:
-              - &langchain langchain==0.1.9
-              - nemollm
-
-  example-llm-completion:
-    common:
-      - output_types: [conda]
-        packages:
-          - *grpcio-status
-          - &arxiv arxiv=1.4
-          - &newspaper3k newspaper3k=0.2
-          - &pypdf pypdf=3.17.4
-
-  example-llm-rag:
-    common:
-      - output_types: [conda]
-        packages:
-          - *grpcio-status
           - anyio>=3.7
+          - arxiv=1.4
+          - beautifulsoup4=4.12
+          - grpcio-status==1.59
+          - huggingface_hub=0.20.2 # work-around for https://github.com/UKPLab/sentence-transformers/issues/1762
           - jsonpatch>=1.33
+          - newspaper3k=0.2
+          - numexpr
+          - onnx=1.15
           - openai=1.13
+          - pypdf=3.17.4
+          - requests-toolbelt=1.0 # Transitive dep needed by nemollm, specified here to ensure we get a compatible version
+          - sentence-transformers=2.7
           - pip
           - pip:
-              - *langchain
-              - google-search-results==2.4
-
-  example-llm-vdb-upload:
-    common:
-      - output_types: [conda]
-        packages:
-          - *arxiv
-          - *grpcio-status
-          - *newspaper3k
-          - *pypdf
-          - onnx
-          - pip
-          - pip:
-              - PyMuPDF==1.23.21
-              - *langchain
+              - faiss-gpu==1.7.*
+              - google-search-results==2.4.*
+              - langchain-nvidia-ai-endpoints==0.0.11
+              - langchain==0.1.16
+              - nemollm==0.3.5
+              - PyMuPDF==1.23.*
 
   model-training-tuning:
     common:
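Note on the change above: the four `example-llm-*` dependency groups collapse into a single `example-llms` group, so one solved environment now carries every LLM example's pins. A quick, illustrative way to confirm a solved environment actually picked up the new pip pins (the script below is a sketch, not part of the PR; package names are the pip distribution names):

```python
# Sanity-check a few of the consolidated `example-llms` pins.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "langchain": "0.1.16",
    "langchain-nvidia-ai-endpoints": "0.0.11",
    "nemollm": "0.3.5",
    "sentence-transformers": "2.7",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: NOT INSTALLED (expected {expected})")
        continue
    status = "ok" if installed.startswith(expected) else f"MISMATCH (expected {expected})"
    print(f"{name}: {installed} {status}")
```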
"python.testing.pytestEnabled": true, - "python.testing.unittestEnabled": true, + "python.testing.unittestEnabled": false, "rewrap.wrappingColumn": 120, "testMate.cpp.debug.configTemplate": { "args": "${argsArray}", diff --git a/morpheus/llm/services/nvfoundation_llm_service.py b/morpheus/llm/services/nvfoundation_llm_service.py index 2f95e15900..62bc355662 100644 --- a/morpheus/llm/services/nvfoundation_llm_service.py +++ b/morpheus/llm/services/nvfoundation_llm_service.py @@ -13,16 +13,18 @@ # limitations under the License. import logging -import os +import typing from morpheus.llm.services.llm_service import LLMClient from morpheus.llm.services.llm_service import LLMService +from morpheus.utils.env_config_value import EnvConfigValue logger = logging.getLogger(__name__) IMPORT_EXCEPTION = None IMPORT_ERROR_MESSAGE = ( - "The `langchain-nvidia-ai-endpoints` package was not found. Install it and other additional dependencies by running the following command:\n" + "The `langchain-nvidia-ai-endpoints` package was not found. Install it and other additional dependencies by " + "running the following command:" "`conda env update --solver=libmamba -n morpheus " "--file morpheus/conda/environments/dev_cuda-121_arch-x86_64.yaml --prune`") @@ -39,7 +41,7 @@ class NVFoundationLLMClient(LLMClient): `NeMoLLMService.get_client` method. Parameters ---------- - parent : NeMoLLMService + parent : NVFoundationMService The parent service for this client. model_name : str The name of the model to interact with. @@ -62,8 +64,8 @@ def __init__(self, parent: "NVFoundationLLMService", *, model_name: str, **model chat_kwargs = { "model": model_name, - "api_key": self._parent._api_key, - "base_url": self._parent._base_url, + "api_key": self._parent._api_key.value, + "base_url": self._parent._base_url.value, } # Remove None values set by the environment in the kwargs @@ -77,10 +79,12 @@ def __init__(self, parent: "NVFoundationLLMService", *, model_name: str, **model self._client = ChatNVIDIA(**{**chat_kwargs, **model_kwargs}) # type: ignore def get_input_names(self) -> list[str]: - schema = self._client.get_input_schema() - return [self._prompt_key] + @property + def model_kwargs(self): + return self._model_kwargs + def generate(self, **input_dict) -> str: """ Issue a request to generate a response based on a given prompt. @@ -111,70 +115,129 @@ async def generate_async(self, **input_dict) -> str: return (await self.generate_batch_async(inputs=inputs, **input_dict))[0] - def generate_batch(self, inputs: dict[str, list], **kwargs) -> list[str]: + def generate_batch(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True, + **kwargs) -> list[str] | list[str | BaseException]: """ Issue a request to generate a list of responses based on a list of prompts. Parameters ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. + **kwargs + Additional keyword arguments for generate batch. 
""" - prompts = [StringPromptValue(text=p) for p in inputs[self._prompt_key]] + prompts = [StringPromptValue(text=p) for p in inputs[self._prompt_key]] final_kwargs = {**self._model_kwargs, **kwargs} - responses = self._client.generate_prompt(prompts=prompts, **final_kwargs) # type: ignore - - return [g[0].text for g in responses.generations] - - async def generate_batch_async(self, inputs: dict[str, list], **kwargs) -> list[str]: + responses = [] + try: + generated_responses = self._client.generate_prompt(prompts=prompts, **final_kwargs) # type: ignore + responses = [g[0].text for g in generated_responses.generations] + except Exception as e: + if return_exceptions: + responses.append(e) + else: + raise e + + return responses + + @typing.overload + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... + + @typing.overload + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions=True, + **kwargs) -> list[str] | list[str | BaseException]: """ Issue an asynchronous request to generate a list of responses based on a list of prompts. + Parameters ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. + **kwargs + Additional keyword arguments for generate batch async. """ prompts = [StringPromptValue(text=p) for p in inputs[self._prompt_key]] - final_kwargs = {**self._model_kwargs, **kwargs} - responses = await self._client.agenerate_prompt(prompts=prompts, **final_kwargs) # type: ignore + responses = [] + try: + generated_responses = await self._client.agenerate_prompt(prompts=prompts, **final_kwargs) # type: ignore + responses = [g[0].text for g in generated_responses.generations] + except Exception as e: + if return_exceptions: + responses.append(e) + else: + raise e - return [g[0].text for g in responses.generations] + return responses class NVFoundationLLMService(LLMService): """ A service for interacting with NeMo LLM models, this class should be used to create a client for a specific model. + Parameters ---------- api_key : str, optional - The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY` - environment variable. If neither are present an error will be raised. - org_id : str, optional - The organization ID for the LLM service, by default None. If `None` the organization ID will be read from the - `NGC_ORG_ID` environment variable. This value is only required if the account associated with the `api_key` is - a member of multiple NGC organizations. + The API key for the LLM service, by default None. If `None` the API key will be read from the `NVIDIA_API_KEY` + environment variable. If neither are present an error will be raised, by default None base_url : str, optional - The api host url, by default None. If `None` the url will be read from the `NVAI_BASE_URL` environment - variable. If neither are present `https://api.nvcf.nvidia.com/v2/nvcf` will be used by langchain. + The api host url, by default None. If `None` the url will be read from the `NVIDIA_API_BASE` environment + variable. 
 
-    def __init__(self, *, api_key: str = None, base_url: str = None) -> None:
+    class APIKey(EnvConfigValue):
+        _ENV_KEY: str = "NVIDIA_API_KEY"
+
+    class BaseURL(EnvConfigValue):
+        _ENV_KEY: str = "NVIDIA_API_BASE"
+        _ALLOW_NONE: bool = True
+
+    def __init__(self, *, api_key: APIKey | str = None, base_url: BaseURL | str = None, **model_kwargs) -> None:
 
         if IMPORT_EXCEPTION is not None:
             raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION
 
         super().__init__()
 
+        if not isinstance(api_key, NVFoundationLLMService.APIKey):
+            api_key = NVFoundationLLMService.APIKey(api_key)
+
+        if not isinstance(base_url, NVFoundationLLMService.BaseURL):
+            base_url = NVFoundationLLMService.BaseURL(base_url)
+
         self._api_key = api_key
+        self._base_url = base_url
+        self._default_model_kwargs = model_kwargs
+
+    def _merge_model_kwargs(self, model_kwargs: dict) -> dict:
+        return {**self._default_model_kwargs, **model_kwargs}
 
-        # Set the base url from the environment if not provided. Default to None to allow the client to set the url.
-        if base_url is None:
-            self._base_url = os.getenv('NVAI_BASE_URL', None)
-        else:
-            self._base_url = base_url
+    @property
+    def api_key(self):
+        return self._api_key.value
+
+    @property
+    def base_url(self):
+        return self._base_url.value
 
     def get_client(self, *, model_name: str, **model_kwargs) -> NVFoundationLLMClient:
         """
@@ -187,4 +250,6 @@ def get_client(self, *, model_name: str, **model_kwargs) -> NVFoundationLLMClien
             Additional keyword arguments to pass to the model when generating text.
         """
 
-        return NVFoundationLLMClient(self, model_name=model_name, **model_kwargs)
+        final_model_kwargs = self._merge_model_kwargs(model_kwargs)
+
+        return NVFoundationLLMClient(self, model_name=model_name, **final_model_kwargs)
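With this change the service resolves credentials through `EnvConfigValue` subclasses instead of reading `os.environ` directly, and service-level model kwargs are merged under per-client kwargs. A minimal usage sketch, assuming `NVIDIA_API_KEY` is exported and `langchain-nvidia-ai-endpoints` is installed (the model name below is illustrative, not prescribed by this PR):

```python
from morpheus.llm.services.nvfoundation_llm_service import NVFoundationLLMService

# With no arguments, APIKey/BaseURL fall back to the NVIDIA_API_KEY and
# NVIDIA_API_BASE environment variables.
service = NVFoundationLLMService(temperature=0.0)  # service-level default model kwargs

# Per-client kwargs are merged over the service defaults via _merge_model_kwargs().
client = service.get_client(model_name="ai-mixtral-8x7b-instruct", temperature=0.2)

print(client.generate(prompt="What is the capital of France?"))
```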
diff --git a/morpheus/llm/services/openai_chat_service.py b/morpheus/llm/services/openai_chat_service.py
index 8fe1919a90..3b2c87b4f2 100644
--- a/morpheus/llm/services/openai_chat_service.py
+++ b/morpheus/llm/services/openai_chat_service.py
@@ -137,6 +137,31 @@ def __init__(self,
                                                    api_key=self._parent._api_key.value,
                                                    base_url=self._parent._base_url.value)
 
+    @property
+    def model_name(self):
+        """
+        Get the name of the model associated with this client.
+
+        Returns
+        -------
+        str
+            The name of the model.
+        """
+        return self._model_name
+
+    @property
+    def model_kwargs(self):
+        """
+        Get the keyword args that will be passed to the model when calling generation functions.
+
+        Returns
+        -------
+        dict
+            The keyword arguments dictionary.
+        """
+        # Return a copy to avoid modification of the original
+        return self._model_kwargs.copy()
+
     def get_input_names(self) -> list[str]:
         input_names = [self._prompt_key]
         if self._set_assistant:
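The new `model_kwargs` property deliberately returns a copy so that callers cannot mutate the client's configured generation arguments. A small sketch of that guarantee (assumes `OPENAI_API_KEY` is set in the environment):

```python
from morpheus.llm.services.openai_chat_service import OpenAIChatService

client = OpenAIChatService().get_client(model_name="test_model", temperature=0.7)

kwargs = client.model_kwargs
kwargs["temperature"] = 2.0  # mutates only the local copy

assert client.model_kwargs["temperature"] == 0.7
```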
diff --git a/morpheus/service/vdb/faiss_vdb_service.py b/morpheus/service/vdb/faiss_vdb_service.py
index 81f63aef5b..7b7d3362bd 100644
--- a/morpheus/service/vdb/faiss_vdb_service.py
+++ b/morpheus/service/vdb/faiss_vdb_service.py
@@ -13,13 +13,9 @@
 # limitations under the License.
 
 import asyncio
-import copy
-import json
 import logging
-import threading
 import time
 import typing
-from functools import wraps
 
 import pandas as pd
 
@@ -31,9 +27,10 @@
 logger = logging.getLogger(__name__)
 
 IMPORT_EXCEPTION = None
-IMPORT_ERROR_MESSAGE = "MilvusVectorDBResourceService requires the milvus and pymilvus packages to be installed."
+IMPORT_ERROR_MESSAGE = "FaissDBResourceService requires the FAISS library to be installed."
 
 try:
+    from langchain.embeddings.base import Embeddings
     from langchain.vectorstores.faiss import FAISS
 except ImportError as import_exc:
     IMPORT_EXCEPTION = import_exc
@@ -41,14 +38,14 @@
 
 class FaissVectorDBResourceService(VectorDBResourceService):
     """
-    Represents a service for managing resources in a Milvus Vector Database.
+    Represents a service for managing resources in a FAISS Vector Database.
 
     Parameters
     ----------
+    parent : FaissVectorDBService
+        The parent service for this resource.
     name : str
-        Name of the resource.
-    client : MilvusClient
-        An instance of the MilvusClient for interaction with the Milvus Vector Database.
+        The name of the resource.
     """
 
     def __init__(self, parent: "FaissVectorDBService", *, name: str) -> None:
@@ -58,14 +55,15 @@ def __init__(self, parent: "FaissVectorDBService", *, name: str) -> None:
         super().__init__()
 
         self._parent = parent
-        self._name = name
+        self._folder_path = self._parent._local_dir
+        self._index_name = name
 
         self._index = FAISS.load_local(folder_path=self._parent._local_dir,
                                        embeddings=self._parent._embeddings,
-                                       index_name=self._name,
+                                       index_name=self._index_name,
                                        allow_dangerous_deserialization=True)
 
-    def insert(self, data: list[list] | list[dict], **kwargs: dict[str, typing.Any]) -> dict:
+    def insert(self, data: list[list] | list[dict], **kwargs) -> dict:
         """
         Insert data into the vector database.
 
@@ -73,7 +71,7 @@
         ----------
         data : list[list] | list[dict]
             Data to be inserted into the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -83,7 +81,7 @@
         dict
             Returns response content as a dictionary.
         """
         raise NotImplementedError("Insert operation is not supported in FAISS")
 
-    def insert_dataframe(self, df: typing.Union[cudf.DataFrame, pd.DataFrame], **kwargs: dict[str, typing.Any]) -> dict:
+    def insert_dataframe(self, df: typing.Union[cudf.DataFrame, pd.DataFrame], **kwargs) -> dict:
         """
         Insert a dataframe entires into the vector database.
 
@@ -91,7 +89,7 @@
         ----------
         df : typing.Union[cudf.DataFrame, pd.DataFrame]
             Dataframe to be inserted into the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -101,13 +99,13 @@
         dict
             Returns response content as a dictionary.
         """
         raise NotImplementedError("Insert operation is not supported in FAISS")
 
-    def describe(self, **kwargs: dict[str, typing.Any]) -> dict:
+    def describe(self, **kwargs) -> dict:
         """
         Provides a description of the collection.
 
         Parameters
         ----------
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -115,41 +113,32 @@
         dict
             Returns response content as a dictionary.
         """
-        raise NotImplementedError("Describe operation is not supported in FAISS")
+        return {
+            "index_name": self._index_name,
+            "folder_path": self._folder_path,
+        }
 
-    def query(self, query: str, **kwargs: dict[str, typing.Any]) -> typing.Any:
+    def query(self, query: str, **kwargs) -> typing.Any:
         """
-        Query data in a collection in the Milvus vector database.
-
-        This method performs a search operation in the specified collection/partition in the Milvus vector database.
+        Query data in a collection in the vector database.
 
         Parameters
         ----------
         query : str, optional
             The search query, which can be a filter expression, by default None.
-        **kwargs : dict
+        **kwargs
             Additional keyword arguments for the search operation.
 
         Returns
         -------
         typing.Any
             The search result, which can vary depending on the query and options.
-
-        Raises
-        ------
-        RuntimeError
-            If an error occurs during the search operation.
-            If query argument is `None` and `data` keyword argument doesn't exist.
-            If `data` keyword arguement is `None`.
         """
         raise NotImplementedError("Query operation is not supported in FAISS")
 
-    async def similarity_search(self,
-                                embeddings: list[list[float]],
-                                k: int = 4,
-                                **kwargs: dict[str, typing.Any]) -> list[list[dict]]:
+    async def similarity_search(self, embeddings: list[list[float]], k: int = 4, **kwargs) -> list[list[dict]]:
         """
-        Perform a similarity search within the collection.
+        Perform a similarity search within the FAISS docstore.
 
         Parameters
         ----------
@@ -157,7 +146,7 @@ async def similarity_search(self,
             Embeddings for which to perform the similarity search.
         k : int, optional
             The number of nearest neighbors to return, by default 4.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -173,7 +162,7 @@ async def single_search(single_embedding):
 
         return list(await asyncio.gather(*[single_search(embedding) for embedding in embeddings]))
 
-    def update(self, data: list[typing.Any], **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    def update(self, data: list[typing.Any], **kwargs) -> dict[str, typing.Any]:
         """
         Update data in the collection.
 
@@ -181,7 +170,7 @@ def update(self, data: list[typing.Any], **kwargs: dict[str, typing.Any]) -> dic
         ----------
         data : list[typing.Any]
             Data to be updated in the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to upsert operation.
 
         Returns
@@ -191,7 +180,7 @@ def update(self, data: list[typing.Any], **kwargs: dict[str, typing.Any]) -> dic
         dict[str, typing.Any]
             Returns result of the updated operation stats.
         """
         raise NotImplementedError("Update operation is not supported in FAISS")
 
-    def delete_by_keys(self, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> typing.Any:
+    def delete_by_keys(self, keys: int | str | list, **kwargs) -> typing.Any:
         """
         Delete vectors by keys from the collection.
 
@@ -199,7 +188,7 @@ def delete_by_keys(self, keys: int | str | list, **kwargs: dict[str, typing.Any]
         ----------
         keys : int | str | list
             Primary keys to delete vectors.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -209,15 +198,15 @@ def delete_by_keys(self, keys: int | str | list, **kwargs: dict[str, typing.Any]
         typing.Any
             Returns result of the given keys that are deleted from the collection.
         """
         raise NotImplementedError("Delete by keys operation is not supported in FAISS")
 
-    def delete(self, expr: str, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    def delete(self, expr: str, **kwargs) -> dict[str, typing.Any]:
         """
-        Delete vectors from the collection using expressions.
+        Delete vectors by giving a list of IDs.
 
         Parameters
         ----------
         expr : str
             Delete expression.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
         -------
         dict[str, typing.Any]
             Returns result of the given keys that are deleted from the collection.
""" - raise NotImplementedError("Delete operation is not supported in FAISS") + raise NotImplementedError("delete operation is not supported in FAISS") - def retrieve_by_keys(self, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> list[typing.Any]: + def retrieve_by_keys(self, keys: int | str | list, **kwargs) -> list[typing.Any]: """ Retrieve the inserted vectors using their primary keys. @@ -236,7 +225,7 @@ def retrieve_by_keys(self, keys: int | str | list, **kwargs: dict[str, typing.An keys : int | str | list Primary keys to get vectors for. Depending on pk_field type it can be int or str or a list of either. - **kwargs : dict[str, typing.Any] + **kwargs Additional keyword arguments for the retrieval operation. Returns @@ -246,13 +235,13 @@ def retrieve_by_keys(self, keys: int | str | list, **kwargs: dict[str, typing.An """ raise NotImplementedError("Retrieve by keys operation is not supported in FAISS") - def count(self, **kwargs: dict[str, typing.Any]) -> int: + def count(self, **kwargs) -> int: """ Returns number of rows/entities. Parameters ---------- - **kwargs : dict[str, typing.Any] + **kwargs Additional keyword arguments for the count operation. Returns @@ -260,17 +249,17 @@ def count(self, **kwargs: dict[str, typing.Any]) -> int: int Returns number of entities in the collection. """ - raise NotImplementedError("Count operation is not supported in FAISS") + return self._index.index.ntotal - def drop(self, **kwargs: dict[str, typing.Any]) -> None: + def drop(self, **kwargs) -> None: """ - Drop a collection, index, or partition in the Milvus vector database. + Drops the resource from the vector database service. This function allows you to drop a collection. Parameters ---------- - **kwargs : dict + **kwargs Additional keyword arguments for specifying the type and partition name (if applicable). """ raise NotImplementedError("Drop operation is not supported in FAISS") @@ -278,26 +267,22 @@ def drop(self, **kwargs: dict[str, typing.Any]) -> None: class FaissVectorDBService(VectorDBService): """ - Service class for Milvus Vector Database implementation. This class provides functions for interacting - with a Milvus vector database. + Service class for FAISS Vector Database implementation. This class provides functions for interacting + with a FAISS vector database. Parameters ---------- - host : str - The hostname or IP address of the Milvus server. - port : str - The port number for connecting to the Milvus server. - alias : str, optional - Alias for the Milvus connection, by default "default". - **kwargs : dict - Additional keyword arguments specific to the Milvus connection configuration. + local_dir : str + The local directory where the FAISS index files are stored. + embeddings : Embeddings + The embeddings object to use for embedding text. 
""" _collection_locks = {} _cleanup_interval = 600 # 10mins _last_cleanup_time = time.time() - def __init__(self, local_dir: str, embeddings, **kwargs: dict[str, typing.Any]): + def __init__(self, local_dir: str, embeddings: "Embeddings"): if IMPORT_EXCEPTION is not None: raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION @@ -305,48 +290,67 @@ def __init__(self, local_dir: str, embeddings, **kwargs: dict[str, typing.Any]): self._local_dir = local_dir self._embeddings = embeddings - def load_resource(self, name: str = "index", **kwargs: dict[str, typing.Any]) -> FaissVectorDBResourceService: + @property + def embeddings(self): + return self._embeddings + + def load_resource(self, name: str = "index", **kwargs) -> FaissVectorDBResourceService: + """ + Loads a VDB resource into memory for use. + + Parameters + ---------- + name : str, optional + The VDB resource to load. For FAISS, this corresponds to the index name, by default "index" + **kwargs + Additional keyword arguments specific to the resource service. + + Returns + ------- + FaissVectorDBResourceService + The loaded resource service. + """ return FaissVectorDBResourceService(self, name=name, **kwargs) def has_store_object(self, name: str) -> bool: """ - Check if a collection exists in the Milvus vector database. + Check if specific index file name exists by attempting to load FAISS index, docstore, + and index_to_docstore_id from disk with the index file name. Parameters ---------- name : str - Name of the collection to check. + Name of the FAISS index file to check. Returns ------- bool - True if the collection exists, False otherwise. + True if the file exists, False otherwise. """ - return self._client.has_collection(collection_name=name) + try: + FAISS.load_local(folder_path=self._local_dir, + embeddings=self._embeddings, + index_name=name, + allow_dangerous_deserialization=True) + return True + except Exception: + return False - def list_store_objects(self, **kwargs: dict[str, typing.Any]) -> list[str]: + def list_store_objects(self, **kwargs) -> list[str]: """ - List the names of all collections in the Milvus vector database. + List the names of all resources in the vector database. Returns ------- list[str] A list of collection names. """ - return self._client.list_collections(**kwargs) - - def _create_schema_field(self, field_conf: dict) -> "pymilvus.FieldSchema": - - field_schema = pymilvus.FieldSchema.construct_from_dict(field_conf) - - return field_schema + raise NotImplementedError("list_store_objects operation is not supported in FAISS") - def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing.Any]): + def create(self, name: str, overwrite: bool = False, **kwargs): """ - Create a collection in the Milvus vector database with the specified name and configuration. This method - creates a new collection in the Milvus vector database with the provided name and configuration options. - If the collection already exists, it can be overwritten if the `overwrite` parameter is set to True. + Create a collection. Parameters ---------- @@ -354,7 +358,7 @@ def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing. Name of the collection to be created. overwrite : bool, optional If True, the collection will be overwritten if it already exists, by default False. - **kwargs : dict + **kwargs Additional keyword arguments containing collection configuration. Raises @@ -362,48 +366,13 @@ def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing. 
         ValueError
             If the provided schema fields configuration is empty.
         """
-        logger.debug("Creating collection: %s, overwrite=%s, kwargs=%s", name, overwrite, kwargs)
-
-        # Preserve original configuration.
-        collection_conf = copy.deepcopy(kwargs)
-
-        auto_id = collection_conf.get("auto_id", False)
-        index_conf = collection_conf.get("index_conf", None)
-        partition_conf = collection_conf.get("partition_conf", None)
-
-        schema_conf = collection_conf.get("schema_conf")
-        schema_fields_conf = schema_conf.pop("schema_fields")
-
-        if not self.has_store_object(name) or overwrite:
-            if overwrite and self.has_store_object(name):
-                self.drop(name)
-
-            if len(schema_fields_conf) == 0:
-                raise ValueError("Cannot create collection as provided empty schema_fields configuration")
-
-            schema_fields = [FieldSchemaEncoder.from_dict(field_conf) for field_conf in schema_fields_conf]
-
-            schema = pymilvus.CollectionSchema(fields=schema_fields, **schema_conf)
-
-            self._client.create_collection_with_schema(collection_name=name,
-                                                       schema=schema,
-                                                       index_params=index_conf,
-                                                       auto_id=auto_id,
-                                                       shards_num=collection_conf.get("shards", 2),
-                                                       consistency_level=collection_conf.get(
-                                                           "consistency_level", "Strong"))
-
-            if partition_conf:
-                timeout = partition_conf.get("timeout", 1.0)
-                # Iterate over each partition configuration
-                for part in partition_conf["partitions"]:
-                    self._client.create_partition(collection_name=name, partition_name=part["name"], timeout=timeout)
+        raise NotImplementedError("create operation is not supported in FAISS")
 
     def create_from_dataframe(self,
                               name: str,
                               df: typing.Union[cudf.DataFrame, pd.DataFrame],
                               overwrite: bool = False,
-                              **kwargs: dict[str, typing.Any]) -> None:
+                              **kwargs) -> None:
         """
         Create collections in the vector database.
 
@@ -415,37 +384,15 @@ def create_from_dataframe(self,
             The dataframe to create the collection from.
         overwrite : bool, optional
             Whether to overwrite the collection if it already exists. Default is False.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
         """
 
-        fields = self._build_schema_conf(df=df)
-
-        create_kwargs = {
-            "schema_conf": {
-                "description": "Auto generated schema from DataFrame in Morpheus",
-                "schema_fields": fields,
-            }
-        }
-
-        if (kwargs.get("index_field", None) is not None):
-            # Check to make sure the column name exists in the fields
-            create_kwargs["index_conf"] = {
-                "field_name": kwargs.get("index_field"),  # Default index type
-                "metric_type": "L2",
-                "index_type": "HNSW",
-                "params": {
-                    "M": 8,
-                    "efConstruction": 64,
-                },
-            }
-
-        self.create(name=name, overwrite=overwrite, **create_kwargs)
+        raise NotImplementedError("create_from_dataframe operation is not supported in FAISS")
 
-    def insert(self, name: str, data: list[list] | list[dict], **kwargs: dict[str,
-                                                                               typing.Any]) -> dict[str, typing.Any]:
+    def insert(self, name: str, data: list[list] | list[dict], **kwargs) -> dict[str, typing.Any]:
         """
-        Insert a collection specific data in the Milvus vector database.
+        Insert a collection specific data in the vector database.
 
         Parameters
         ----------
         name : str
             Name of the collection to be inserted.
         data : list[list] | list[dict]
             Data to be inserted in the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Additional keyword arguments containing collection configuration.
 
         Returns
         -------
         dict
             Returns response content as a dictionary.
 
         Raises
         ------
         RuntimeError
             If the collection not exists exists.
         """
 
-        resource = self.load_resource(name)
-        return resource.insert(data, **kwargs)
+        raise NotImplementedError("insert operation is not supported in FAISS")
 
-    def insert_dataframe(self,
-                         name: str,
-                         df: typing.Union[cudf.DataFrame, pd.DataFrame],
-                         **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    def insert_dataframe(self, name: str, df: typing.Union[cudf.DataFrame, pd.DataFrame],
+                         **kwargs) -> dict[str, typing.Any]:
         """
-        Converts dataframe to rows and insert to a collection in the Milvus vector database.
+        Converts dataframe to rows and insert to the vector database.
 
         Parameters
         ----------
         name : str
             Name of the collection to be inserted.
         df : typing.Union[cudf.DataFrame, pd.DataFrame]
             Dataframe to be inserted in the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Additional keyword arguments containing collection configuration.
 
         Returns
@@ -496,15 +440,11 @@ def insert_dataframe(self,
         RuntimeError
             If the collection not exists exists.
         """
-        resource = self.load_resource(name)
+        raise NotImplementedError("insert_dataframe operation is not supported in FAISS")
 
-        return resource.insert_dataframe(df=df, **kwargs)
-
-    def query(self, name: str, query: str = None, **kwargs: dict[str, typing.Any]) -> typing.Any:
+    def query(self, name: str, query: str = None, **kwargs) -> typing.Any:
         """
-        Query data in a collection in the Milvus vector database.
-
-        This method performs a search operation in the specified collection/partition in the Milvus vector database.
+        Query data in a vector database.
 
         Parameters
         ----------
@@ -512,7 +452,7 @@ def query(self, name: str, query: str = None, **kwargs: dict[str, typing.Any]) -
             Name of the collection to search within.
         query : str
             The search query, which can be a filter expression.
-        **kwargs : dict
+        **kwargs
             Additional keyword arguments for the search operation.
 
         Returns
@@ -521,11 +461,9 @@ def query(self, name: str, query: str = None, **kwargs: dict[str, typing.Any]) -
         typing.Any
             The search result, which can vary depending on the query and options.
         """
 
-        resource = self.load_resource(name)
-
-        return resource.query(query, **kwargs)
+        raise NotImplementedError("query operation is not supported in FAISS")
 
-    async def similarity_search(self, name: str, **kwargs: dict[str, typing.Any]) -> list[dict]:
+    async def similarity_search(self, name: str, **kwargs) -> list[dict]:
         """
         Perform a similarity search within the collection.
 
@@ -533,7 +471,7 @@ async def similarity_search(self, name: str, **kwargs: dict[str, typing.Any]) ->
         ----------
         name : str
             Name of the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -542,11 +480,9 @@ async def similarity_search(self, name: str, **kwargs: dict[str, typing.Any]) ->
         list[dict]
             Returns a list of dictionaries representing the results of the similarity search.
         """
 
-        resource = self.load_resource(name)
+        raise NotImplementedError("similarity_search operation is not supported in FAISS")
 
-        return resource.similarity_search(**kwargs)
-
-    def update(self, name: str, data: list[typing.Any], **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    def update(self, name: str, data: list[typing.Any], **kwargs) -> dict[str, typing.Any]:
         """
         Update data in the vector database.
 
@@ -556,7 +492,7 @@ def update(self, name: str, data: list[typing.Any], **kwargs: dict[str, typing.A
             Name of the collection.
         data : list[typing.Any]
             Data to be updated in the collection.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to upsert operation.
 
         Returns
@@ -565,14 +501,9 @@ def update(self, name: str, data: list[typing.Any], **kwargs: dict[str, typing.A
             Returns result of the updated operation stats.
         """
 
-        if not isinstance(data, list):
-            raise RuntimeError("Data is not of type list.")
-
-        resource = self.load_resource(name)
-
-        return resource.update(data=data, **kwargs)
+        raise NotImplementedError("update operation is not supported in FAISS")
 
-    def delete_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> typing.Any:
+    def delete_by_keys(self, name: str, keys: int | str | list, **kwargs) -> typing.Any:
         """
         Delete vectors by keys from the collection.
 
@@ -582,7 +513,7 @@ def delete_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str,
             Name of the collection.
         keys : int | str | list
             Primary keys to delete vectors.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -591,11 +522,9 @@ def delete_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str,
             Returns result of the given keys that are delete from the collection.
         """
 
-        resource = self.load_resource(name)
+        raise NotImplementedError("delete_by_keys operation is not supported in FAISS")
 
-        return resource.delete_by_keys(keys=keys, **kwargs)
-
-    def delete(self, name: str, expr: str, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    def delete(self, name: str, expr: str, **kwargs) -> dict[str, typing.Any]:
         """
         Delete vectors from the collection using expressions.
 
@@ -605,7 +534,7 @@ def delete(self, name: str, expr: str, **kwargs: dict[str, typing.Any]) -> dict[
             Name of the collection.
         expr : str
             Delete expression.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Extra keyword arguments specific to the vector database implementation.
 
         Returns
@@ -614,12 +543,9 @@ def delete(self, name: str, expr: str, **kwargs: dict[str, typing.Any]) -> dict[
             Returns result of the given keys that are delete from the collection.
         """
 
-        resource = self.load_resource(name)
-        result = resource.delete(expr=expr, **kwargs)
-
-        return result
+        raise NotImplementedError("delete operation is not supported in FAISS")
 
-    def retrieve_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> list[typing.Any]:
+    def retrieve_by_keys(self, name: str, keys: int | str | list, **kwargs) -> list[typing.Any]:
         """
         Retrieve the inserted vectors using their primary keys from the Collection.
 
@@ -630,7 +556,7 @@
         keys : int | str | list
             Primary keys to get vectors for. Depending on pk_field type it can be int
             or str or a list of either.
-        **kwargs : dict[str, typing.Any]
+        **kwargs
             Additional keyword arguments for the retrieval operation.
 
         Returns
@@ -639,13 +565,9 @@
         list[typing.Any]
             Returns result rows of the given keys from the collection.
""" - resource = self.load_resource(name) - - result = resource.retrieve_by_keys(keys=keys, **kwargs) - - return result + raise NotImplementedError("retrieve_by_keys operation is not supported in FAISS") - def count(self, name: str, **kwargs: dict[str, typing.Any]) -> int: + def count(self, name: str, **kwargs) -> int: """ Returns number of rows/entities in the given collection. @@ -653,7 +575,7 @@ def count(self, name: str, **kwargs: dict[str, typing.Any]) -> int: ---------- name : str Name of the collection. - **kwargs : dict[str, typing.Any] + **kwargs Additional keyword arguments for the count operation. Returns @@ -661,68 +583,29 @@ def count(self, name: str, **kwargs: dict[str, typing.Any]) -> int: int Returns number of entities in the collection. """ - resource = self.load_resource(name) - return resource.count(**kwargs) + raise NotImplementedError("count operation is not supported in FAISS") - def drop(self, name: str, **kwargs: dict[str, typing.Any]) -> None: + def drop(self, name: str, **kwargs) -> None: """ - Drop a collection, index, or partition in the Milvus vector database. - - This method allows you to drop a collection, an index within a collection, - or a specific partition within a collection in the Milvus vector database. + Drop a collection. Parameters ---------- name : str Name of the collection, index, or partition to be dropped. - **kwargs : dict + **kwargs Additional keyword arguments for specifying the type and partition name (if applicable). - Notes on Expected Keyword Arguments: - ------------------------------------ - - 'collection' (str, optional): - Specifies the type of collection to drop. Possible values: 'collection' (default), 'index', 'partition'. - - - 'partition_name' (str, optional): - Required when dropping a specific partition within a collection. Specifies the partition name to be dropped. - - - 'field_name' (str, optional): - Required when dropping an index within a collection. Specifies the field name for which the index is created. - - - 'index_name' (str, optional): - Required when dropping an index within a collection. Specifies the name of the index to be dropped. - Raises ------ ValueError If mandatory arguments are missing or if the provided 'collection' value is invalid. """ - logger.debug("Dropping collection: %s, kwargs=%s", name, kwargs) - - if self.has_store_object(name): - resource = kwargs.get("resource", "collection") - if resource == "collection": - self._client.drop_collection(collection_name=name) - elif resource == "partition": - if "partition_name" not in kwargs: - raise ValueError("Mandatory argument 'partition_name' is required when resource='partition'") - partition_name = kwargs["partition_name"] - if self._client.has_partition(collection_name=name, partition_name=partition_name): - # Collection need to be released before dropping the partition. 
-                    self._client.release_collection(collection_name=name)
-                    self._client.drop_partition(collection_name=name, partition_name=partition_name)
-            elif resource == "index":
-                if "field_name" in kwargs and "index_name" in kwargs:
-                    self._client.drop_index(collection_name=name,
-                                            field_name=kwargs["field_name"],
-                                            index_name=kwargs["index_name"])
-                else:
-                    raise ValueError(
-                        "Mandatory arguments 'field_name' and 'index_name' are required when resource='index'")
-
-    def describe(self, name: str, **kwargs: dict[str, typing.Any]) -> dict:
+        raise NotImplementedError("drop operation is not supported in FAISS")
+
+    def describe(self, name: str, **kwargs) -> dict:
         """
         Describe the collection in the vector database.
 
@@ -730,8 +613,8 @@ def describe(self, name: str, **kwargs: dict[str, typing.Any]) -> dict:
         ----------
         name : str
             Name of the collection.
-        **kwargs : dict[str, typing.Any]
-            Additional keyword arguments specific to the Milvus vector database.
+        **kwargs
+            Additional keyword arguments specific to the vector database.
 
         Returns
@@ -739,9 +622,7 @@
         dict
             Returns collection information.
         """
 
-        resource = self.load_resource(name)
-
-        return resource.describe(**kwargs)
+        raise NotImplementedError("describe operation is not supported in FAISS")
 
     def release_resource(self, name: str) -> None:
         """
@@ -753,13 +634,10 @@ def release_resource(self, name: str) -> None:
             Name of the collection to release.
         """
 
-        self._client.release_collection(collection_name=name)
+        raise NotImplementedError("release_resource operation is not supported in FAISS")
 
     def close(self) -> None:
         """
-        Close the connection to the Milvus vector database.
-
-        This method disconnects from the Milvus vector database by removing the connection.
-
+        Close the vector database service and release all resources.
         """
-        self._client.close()
+        raise NotImplementedError("close operation is not supported in FAISS")
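After this rewrite the FAISS service is effectively read-only: `load_resource`, `has_store_object`, `describe`, `count`, and the resource-level `similarity_search` operate on an index previously saved with `FAISS.save_local`, while all mutating operations raise `NotImplementedError`. A usage sketch under those assumptions (`./faiss_index` holding such an index, and any LangChain `Embeddings` implementation; the embedding model below is illustrative):

```python
import asyncio

from langchain.embeddings import HuggingFaceEmbeddings

from morpheus.service.vdb.faiss_vdb_service import FaissVectorDBService

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
service = FaissVectorDBService(local_dir="./faiss_index", embeddings=embeddings)

if service.has_store_object("index"):
    resource = service.load_resource("index")
    print(resource.describe())  # {'index_name': 'index', 'folder_path': './faiss_index'}
    print(resource.count())     # ntotal vectors in the underlying FAISS index

    # similarity_search is async and takes a batch of query embeddings.
    query_vec = embeddings.embed_query("example query")
    hits = asyncio.run(resource.similarity_search(embeddings=[query_vec], k=4))
    print(hits[0])
```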
""" - self._client.close() + raise NotImplementedError("close operation is not supported in FAISS") diff --git a/tests/conftest.py b/tests/conftest.py index ad9158b7d6..ee5181d3bc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ import types import typing import warnings +from pathlib import Path from unittest import mock import pytest @@ -485,7 +486,7 @@ def seed_fn(seed=42): @pytest.fixture(scope="function") -def chdir_tmpdir(request: pytest.FixtureRequest, tmp_path): +def chdir_tmpdir(request: pytest.FixtureRequest, tmp_path: Path): """ Executes a test in the tmp_path directory """ @@ -1054,6 +1055,18 @@ def nemollm_fixture(fail_missing: bool): yield import_or_skip("nemollm", reason=skip_reason, fail_missing=fail_missing) +@pytest.fixture(name="nvfoundationllm", scope='session') +def nvfoundationllm_fixture(fail_missing: bool): + """ + Fixture to ensure nvfoundationllm is installed + """ + skip_reason = ( + "Tests for NVFoundation require the langchain-nvidia-ai-endpoints package to be installed, to install this " + "run:\n `conda env update --solver=libmamba -n morpheus " + "--file conda/environments/dev_cuda-121_arch-x86_64.yaml --prune`") + yield import_or_skip("langchain_nvidia_ai_endpoints", reason=skip_reason, fail_missing=fail_missing) + + @pytest.fixture(name="openai", scope='session') def openai_fixture(fail_missing: bool): """ diff --git a/tests/llm/services/test_llm_service_pipe.py b/tests/llm/services/test_llm_service_pipe.py index e6e2f8bbf3..13fb5f652e 100644 --- a/tests/llm/services/test_llm_service_pipe.py +++ b/tests/llm/services/test_llm_service_pipe.py @@ -18,12 +18,13 @@ import cudf from _utils import assert_results +from _utils.environment import set_env from _utils.llm import mk_mock_openai_response from morpheus.config import Config from morpheus.llm import LLMEngine from morpheus.llm.nodes.extracter_node import ExtracterNode from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode -from morpheus.llm.services.llm_service import LLMService +from morpheus.llm.services.llm_service import LLMClient from morpheus.llm.services.nemo_llm_service import NeMoLLMService from morpheus.llm.services.openai_chat_service import OpenAIChatService from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler @@ -35,22 +36,17 @@ from morpheus.stages.preprocess.deserialize_stage import DeserializeStage -def _build_engine(llm_service_cls: type[LLMService]): - llm_service = llm_service_cls() - llm_clinet = llm_service.get_client(model_name="test_model") +def _build_engine(llm_client: LLMClient): engine = LLMEngine() engine.add_node("extracter", node=ExtracterNode()) - engine.add_node("completion", inputs=["/extracter"], node=LLMGenerateNode(llm_client=llm_clinet)) + engine.add_node("completion", inputs=["/extracter"], node=LLMGenerateNode(llm_client=llm_client)) engine.add_task_handler(inputs=["/completion"], handler=SimpleTaskHandler()) return engine -def _run_pipeline(config: Config, - llm_service_cls: type[LLMService], - country_prompts: list[str], - capital_responses: list[str]): +def _run_pipeline(config: Config, llm_client: LLMClient, country_prompts: list[str], capital_responses: list[str]): """ Loosely patterned after `examples/llm/completion` """ @@ -66,7 +62,7 @@ def _run_pipeline(config: Config, pipe.add_stage( DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=completion_task)) - pipe.add_stage(LLMEngineStage(config, engine=_build_engine(llm_service_cls))) + 
diff --git a/tests/llm/services/test_llm_service_pipe.py b/tests/llm/services/test_llm_service_pipe.py
index e6e2f8bbf3..13fb5f652e 100644
--- a/tests/llm/services/test_llm_service_pipe.py
+++ b/tests/llm/services/test_llm_service_pipe.py
@@ -18,12 +18,13 @@
 import cudf
 
 from _utils import assert_results
+from _utils.environment import set_env
 from _utils.llm import mk_mock_openai_response
 from morpheus.config import Config
 from morpheus.llm import LLMEngine
 from morpheus.llm.nodes.extracter_node import ExtracterNode
 from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode
-from morpheus.llm.services.llm_service import LLMService
+from morpheus.llm.services.llm_service import LLMClient
 from morpheus.llm.services.nemo_llm_service import NeMoLLMService
 from morpheus.llm.services.openai_chat_service import OpenAIChatService
 from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler
@@ -35,22 +36,17 @@
 from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
 
 
-def _build_engine(llm_service_cls: type[LLMService]):
-    llm_service = llm_service_cls()
-    llm_clinet = llm_service.get_client(model_name="test_model")
+def _build_engine(llm_client: LLMClient):
 
     engine = LLMEngine()
 
     engine.add_node("extracter", node=ExtracterNode())
-    engine.add_node("completion", inputs=["/extracter"], node=LLMGenerateNode(llm_client=llm_clinet))
+    engine.add_node("completion", inputs=["/extracter"], node=LLMGenerateNode(llm_client=llm_client))
     engine.add_task_handler(inputs=["/completion"], handler=SimpleTaskHandler())
 
     return engine
 
 
-def _run_pipeline(config: Config,
-                  llm_service_cls: type[LLMService],
-                  country_prompts: list[str],
-                  capital_responses: list[str]):
+def _run_pipeline(config: Config, llm_client: LLMClient, country_prompts: list[str], capital_responses: list[str]):
     """
     Loosely patterned after `examples/llm/completion`
     """
@@ -66,7 +62,7 @@
     pipe.add_stage(
         DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=completion_task))
 
-    pipe.add_stage(LLMEngineStage(config, engine=_build_engine(llm_service_cls)))
+    pipe.add_stage(LLMEngineStage(config, engine=_build_engine(llm_client)))
     sink = pipe.add_stage(CompareDataFrameStage(config, compare_df=expected_df))
 
     pipe.run()
@@ -79,7 +75,13 @@ def test_completion_pipe_nemo(config: Config,
                               country_prompts: list[str],
                               capital_responses: list[str]):
     mock_nemollm.post_process_generate_response.side_effect = [{"text": response} for response in capital_responses]
-    _run_pipeline(config, NeMoLLMService, country_prompts, capital_responses)
+
+    # Set a dummy key to bypass the API key check
+    with set_env(NGC_API_KEY="test"):
+
+        llm_client = NeMoLLMService().get_client(model_name="test_model")
+
+        _run_pipeline(config, llm_client, country_prompts, capital_responses)
 
 
 def test_completion_pipe_openai(config: Config,
@@ -91,7 +93,10 @@
         mk_mock_openai_response([response]) for response in capital_responses
     ]
 
-    _run_pipeline(config, OpenAIChatService, country_prompts, capital_responses)
+    with set_env(OPENAI_API_KEY="test"):
+        llm_client = OpenAIChatService().get_client(model_name="test_model")
+
+        _run_pipeline(config, llm_client, country_prompts, capital_responses)
 
-    mock_client.chat.completions.create.assert_not_called()
-    mock_async_client.chat.completions.create.assert_called()
+        mock_client.chat.completions.create.assert_not_called()
+        mock_async_client.chat.completions.create.assert_called()
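These tests rely on `_utils.environment.set_env`, whose implementation is not part of this diff. A minimal sketch of what such a helper could look like (an assumption about its behavior, inferred from how the tests use it):

```python
import contextlib
import os
from unittest import mock


@contextlib.contextmanager
def set_env(**env_vars: str):
    """Temporarily set environment variables for the duration of the block."""
    with mock.patch.dict(os.environ, values=env_vars):
        yield
```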
diff --git a/tests/llm/services/test_nvfoundation_llm_service.py b/tests/llm/services/test_nvfoundation_llm_service.py
new file mode 100644
index 0000000000..dec76060e8
--- /dev/null
+++ b/tests/llm/services/test_nvfoundation_llm_service.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from unittest import mock
+
+import pytest
+from langchain_core.messages import ChatMessage
+from langchain_core.outputs import ChatGeneration
+from langchain_core.outputs import LLMResult
+
+from morpheus.llm.services.nvfoundation_llm_service import NVFoundationLLMClient
+from morpheus.llm.services.nvfoundation_llm_service import NVFoundationLLMService
+
+
+@pytest.fixture(name="set_default_nvidia_api_key", autouse=True, scope="function")
+def set_default_nvidia_api_key_fixture():
+    # Must have an API key set to create the client
+    with mock.patch.dict(os.environ, clear=True, values={"NVIDIA_API_KEY": "nvapi-testing_api_key"}):
+        yield
+
+
+@pytest.mark.parametrize("api_key", ["nvapi-12345", None])
+@pytest.mark.parametrize("base_url", ["http://test.nvidia.com/v1", None])
+def test_constructor(api_key: str, base_url: str):
+
+    service = NVFoundationLLMService(api_key=api_key, base_url=base_url)
+
+    if (api_key is None):
+        api_key = os.environ["NVIDIA_API_KEY"]
+
+    assert service.api_key == api_key
+    assert service.base_url == base_url
+
+
+def test_get_client():
+    service = NVFoundationLLMService(api_key="test_api_key")
+    client = service.get_client(model_name="test_model")
+
+    assert isinstance(client, NVFoundationLLMClient)
+
+
+def test_model_kwargs():
+    service = NVFoundationLLMService(arg1="default_value1", arg2="default_value2")
+
+    client = service.get_client(model_name="model_name", arg2="value2")
+
+    assert client.model_kwargs["arg1"] == "default_value1"
+    assert client.model_kwargs["arg2"] == "value2"
+
+
+def test_get_input_names():
+    client = NVFoundationLLMService().get_client(model_name="test_model", additional_arg="test_arg")
+
+    assert client.get_input_names() == ["prompt"]
+
+
+def test_generate():
+    with mock.patch("langchain_nvidia_ai_endpoints.ChatNVIDIA.generate_prompt", autospec=True) as mock_nvfoundationllm:
+
+        def mock_generation_side_effect(*_, **kwargs):
+            return LLMResult(generations=[[
+                ChatGeneration(message=ChatMessage(content=x.text, role="assistant")) for x in kwargs["prompts"]
+            ]])
+
+        mock_nvfoundationllm.side_effect = mock_generation_side_effect
+
+        client = NVFoundationLLMService().get_client(model_name="test_model")
+        assert client.generate(prompt="test_prompt") == "test_prompt"
+
+
+def test_generate_batch():
+
+    with mock.patch("langchain_nvidia_ai_endpoints.ChatNVIDIA.generate_prompt", autospec=True) as mock_nvfoundationllm:
+
+        def mock_generation_side_effect(*_, **kwargs):
+            return LLMResult(generations=[[ChatGeneration(message=ChatMessage(content=x.text, role="assistant"))]
+                                          for x in kwargs["prompts"]])
+
+        mock_nvfoundationllm.side_effect = mock_generation_side_effect
+
+        client = NVFoundationLLMService().get_client(model_name="test_model")
+
+        assert client.generate_batch({'prompt': ["prompt1", "prompt2"]}) == ["prompt1", "prompt2"]
+
+
+async def test_generate_async():
+
+    with mock.patch("langchain_nvidia_ai_endpoints.ChatNVIDIA.agenerate_prompt", autospec=True) as mock_nvfoundationllm:
+
+        def mock_generation_side_effect(*_, **kwargs):
+            return LLMResult(generations=[[ChatGeneration(message=ChatMessage(content=x.text, role="assistant"))]
+                                          for x in kwargs["prompts"]])
+
+        mock_nvfoundationllm.side_effect = mock_generation_side_effect
+
+        client = NVFoundationLLMService().get_client(model_name="test_model")
+
+        assert await client.generate_async(prompt="test_prompt") == "test_prompt"
+
+
+async def test_generate_batch_async():
+
+    with mock.patch("langchain_nvidia_ai_endpoints.ChatNVIDIA.agenerate_prompt",
+                    autospec=True) as mock_nvfoundationllm:
+
+        def mock_generation_side_effect(*_, **kwargs):
+            return LLMResult(generations=[[ChatGeneration(message=ChatMessage(content=x.text, role="assistant"))]
+                                          for x in kwargs["prompts"]])
+
+        mock_nvfoundationllm.side_effect = mock_generation_side_effect
+
+        client = NVFoundationLLMService().get_client(model_name="test_model")
+
+        assert await client.generate_batch_async({'prompt': ["prompt1", "prompt2"]})
+
+
+async def test_generate_batch_async_error():
+    with mock.patch("langchain_nvidia_ai_endpoints.ChatNVIDIA.agenerate_prompt", autospec=True) as mock_nvfoundationllm:
+
+        def mock_generation_side_effect(*_, **kwargs):
+            raise RuntimeError("unittest")
+
+        mock_nvfoundationllm.side_effect = mock_generation_side_effect
+
+        client = NVFoundationLLMService().get_client(model_name="test_model")
+
+        with pytest.raises(RuntimeError, match="unittest"):
+            await client.generate_batch_async({'prompt': ["prompt1", "prompt2"]}, return_exceptions=False)
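For reference, the two error-handling modes these tests exercise: with the default `return_exceptions=True` failures come back as values in the result list, while `return_exceptions=False` propagates the first failure. A sketch outside the test harness (assumes `NVIDIA_API_KEY` is set; the model name is illustrative):

```python
import asyncio

from morpheus.llm.services.nvfoundation_llm_service import NVFoundationLLMService


async def main() -> None:
    client = NVFoundationLLMService().get_client(model_name="test_model")

    # Failures are returned in-band with return_exceptions=True (the default).
    results = await client.generate_batch_async({"prompt": ["p1", "p2"]}, return_exceptions=True)
    for result in results:
        print(f"failed: {result}" if isinstance(result, BaseException) else result)

    # With return_exceptions=False the first failure is raised instead.
    try:
        await client.generate_batch_async({"prompt": ["p1"]}, return_exceptions=False)
    except Exception as exc:
        print(f"raised: {exc}")


asyncio.run(main())
```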
diff --git a/tests/llm/services/test_openai_chat_client.py b/tests/llm/services/test_openai_chat_client.py
deleted file mode 100644
index 628274f68b..0000000000
--- a/tests/llm/services/test_openai_chat_client.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-from unittest import mock
-
-import pytest
-
-from _utils.llm import mk_mock_openai_response
-from morpheus.llm.services.openai_chat_service import OpenAIChatService
-
-
-@pytest.mark.parametrize("api_key", ["12345", None])
-@pytest.mark.parametrize("base_url", ["http://test.openai.com/v1", None])
-@pytest.mark.parametrize("max_retries", [5, 10])
-def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
-                     api_key: str,
-                     base_url: str,
-                     max_retries: int):
-    OpenAIChatService(api_key=api_key, base_url=base_url).get_client(model_name="test_model", max_retries=max_retries)
-
-    for mock_client in mock_chat_completion:
-        mock_client.assert_called_once_with(api_key=api_key, base_url=base_url, max_retries=max_retries)
-
-
-@pytest.mark.parametrize("max_retries", [5, 10])
-def test_constructor_default_service_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
-                                                 max_retries: int):
-    OpenAIChatService().get_client(model_name="test_model", max_retries=max_retries)
-
-    for mock_client in mock_chat_completion:
-        mock_client.assert_called_once_with(max_retries=max_retries, organization=None, api_key=None, base_url=None)
-
-
-@pytest.mark.parametrize("use_async", [True, False])
-@pytest.mark.parametrize(
-    "input_dict, set_assistant, expected_messages",
-    [({
-        "prompt": "test_prompt", "assistant": "assistant_response"
-    },
-      True, [{
-          "role": "user", "content": "test_prompt"
-      }, {
-          "role": "assistant", "content": "assistant_response"
-      }]), ({
-        "prompt": "test_prompt"
-    }, False, [{
-        "role": "user", "content": "test_prompt"
-    }])])
-@pytest.mark.parametrize("temperature", [0, 1, 2])
-def test_generate(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
-                  use_async: bool,
-                  input_dict: dict[str, str],
-                  set_assistant: bool,
-                  expected_messages: list[dict],
-                  temperature: int):
-    (mock_client, mock_async_client) = mock_chat_completion
-    client = OpenAIChatService().get_client(model_name="test_model",
-                                            set_assistant=set_assistant,
-                                            temperature=temperature)
-
-    if use_async:
-        results = asyncio.run(client.generate_async(**input_dict))
-        mock_async_client.chat.completions.create.assert_called_once_with(model="test_model",
-                                                                          messages=expected_messages,
-                                                                          temperature=temperature)
-        mock_client.chat.completions.create.assert_not_called()
-
-    else:
-        results = client.generate(**input_dict)
-        mock_client.chat.completions.create.assert_called_once_with(model="test_model",
-                                                                    messages=expected_messages,
-                                                                    temperature=temperature)
-        mock_async_client.chat.completions.create.assert_not_called()
-
-    assert results == "test_output"
-
-
-@pytest.mark.parametrize("use_async", [True, False])
-@pytest.mark.parametrize("inputs, set_assistant, expected_messages",
-                         [({
-                             "prompt": ["prompt1", "prompt2"], "assistant": ["assistant1", "assistant2"]
-                         },
-                           True,
-                           [[{
-                               "role": "user", "content": "prompt1"
-                           }, {
-                               "role": "assistant", "content": "assistant1"
-                           }], [{
-                               "role": "user", "content": "prompt2"
-                           }, {
-                               "role": "assistant", "content": "assistant2"
-                           }]]),
-                          ({
-                              "prompt": ["prompt1", "prompt2"]
-                          },
-                           False, [[{
-                               "role": "user", "content": "prompt1"
-                           }], [{
-                               "role": "user", "content": "prompt2"
-                           }]])])
-@pytest.mark.parametrize("temperature", [0, 1, 2])
-def test_generate_batch(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
-                        use_async: bool,
-                        inputs: dict[str, list[str]],
-                        set_assistant: bool,
-                        expected_messages: list[list[dict]],
-                        temperature: int):
-    (mock_client, mock_async_client) = mock_chat_completion
-
-    client = OpenAIChatService().get_client(model_name="test_model",
-                                            set_assistant=set_assistant,
-                                            temperature=temperature)
-
-    expected_results = ["test_output" for _ in range(len(inputs["prompt"]))]
-    expected_calls = [
-        mock.call(model="test_model", messages=messages, temperature=temperature) for messages in expected_messages
-    ]
-
-    if use_async:
-        results = asyncio.run(client.generate_batch_async(inputs))
-        mock_async_client.chat.completions.create.assert_has_calls(expected_calls, any_order=False)
-        mock_client.chat.completions.create.assert_not_called()
-
-    else:
-        results = client.generate_batch(inputs)
-        mock_client.chat.completions.create.assert_has_calls(expected_calls, any_order=False)
-        mock_async_client.chat.completions.create.assert_not_called()
-
-    assert results == expected_results
-
-
-@pytest.mark.parametrize("completion", [[], [None]], ids=["no_choices", "no_content"])
-@pytest.mark.usefixtures("mock_chat_completion")
-def test_extract_completion_errors(completion: list):
-    client = OpenAIChatService().get_client(model_name="test_model")
-    mock_completion = mk_mock_openai_response(completion)
-
-    with pytest.raises(ValueError):
-        client._extract_completion(mock_completion)
diff --git a/tests/llm/services/test_openai_chat_service.py b/tests/llm/services/test_openai_chat_service.py
index f3adc1023a..54b4290ded 100644
--- a/tests/llm/services/test_openai_chat_service.py
+++ b/tests/llm/services/test_openai_chat_service.py
@@ -13,50 +13,230 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
+import os
 from unittest import mock
 
 import pytest
 
-from morpheus.llm.services.llm_service import LLMService
-from morpheus.llm.services.openai_chat_service import OpenAIChatClient
+from _utils.llm import mk_mock_openai_response
 from morpheus.llm.services.openai_chat_service import OpenAIChatService
 
 
-def test_constructor():
-    service = OpenAIChatService()
-    assert isinstance(service, LLMService)
+@pytest.fixture(name="set_default_openai_api_key", autouse=True, scope="function")
+def set_default_openai_api_key_fixture():
+    # Must have an API key set to create the openai client
+    with mock.patch.dict(os.environ, clear=True, values={"OPENAI_API_KEY": "testing_api_key"}):
+        yield
+
+
+def assert_called_once_with_relaxed(mock_obj, *args, **kwargs):
+
+    if (len(mock_obj.call_args_list) == 1):
+
+        recent_call = mock_obj.call_args_list[-1]
+
+        # Ensure that the number of arguments matches by adding ANY to the back of the args
+        if (len(args) < len(recent_call.args)):
+            args = tuple(list(args) + [mock.ANY] * (len(recent_call.args) - len(args)))
+
+        addl_kwargs = {key: mock.ANY for key in recent_call.kwargs.keys() if key not in kwargs}
+
+        kwargs.update(addl_kwargs)
+
+    mock_obj.assert_called_once_with(*args, **kwargs)
+
+
+@pytest.mark.parametrize("api_key", ["12345", None])
+@pytest.mark.parametrize("base_url", ["http://test.openai.com/v1", None])
+@pytest.mark.parametrize("org_id", ["my-org-124", None])
+def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
+                     api_key: str,
+                     base_url: str,
+                     org_id: str):
+
+    OpenAIChatService(api_key=api_key, base_url=base_url, org_id=org_id).get_client(model_name="test_model")
+
+    if (api_key is None):
+        api_key = os.environ["OPENAI_API_KEY"]
+
+    for mock_client in mock_chat_completion:
+        assert_called_once_with_relaxed(mock_client, organization=org_id, api_key=api_key, base_url=base_url)
+
+
+@pytest.mark.parametrize("max_retries", [5, 10, -1, None])
+def test_max_retries(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], max_retries: int):
+    OpenAIChatService().get_client(model_name="test_model", max_retries=max_retries)
+
+    for mock_client in mock_chat_completion:
+        assert_called_once_with_relaxed(mock_client, max_retries=max_retries)
+
+
+@pytest.mark.parametrize("use_json", [True, False])
+def test_client_json(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], use_json: bool):
+    client = OpenAIChatService().get_client(model_name="test_model", json=use_json)
+
+    # Perform a dummy generate call
+    client.generate(prompt="test_prompt")
+
+    if (use_json):
+        assert_called_once_with_relaxed(mock_chat_completion[0].chat.completions.create,
+                                        response_format={"type": "json_object"})
+    else:
+        assert mock_chat_completion[0].chat.completions.create.call_args_list[-1].kwargs.get("response_format") is None
+
+
+@pytest.mark.parametrize("set_assistant", [True, False])
+def test_client_set_assistant(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], set_assistant: bool):
+    client = OpenAIChatService().get_client(model_name="test_model", set_assistant=set_assistant)
+
+    # Perform a dummy generate call
+    client.generate(prompt="test_prompt", assistant="assistant_message")
+
+    messages = mock_chat_completion[0].chat.completions.create.call_args_list[-1].kwargs["messages"]
+
+    found_assistant = False
+
+    for message in messages:
+        if (message.get("role") == "assistant"):
+            found_assistant = True
+            break
+
+    assert found_assistant == set_assistant
+
+
+@pytest.mark.parametrize("use_async", [True, False])
+@pytest.mark.parametrize(
+    "input_dict, set_assistant, expected_messages",
+    [({
+        "prompt": "test_prompt", "assistant": "assistant_response"
+    },
+      True, [{
+          "role": "user", "content": "test_prompt"
+      }, {
+          "role": "assistant", "content": "assistant_response"
+      }]), ({
+          "prompt": "test_prompt"
+      }, False, [{
+          "role": "user", "content": "test_prompt"
+      }])])
+@pytest.mark.parametrize("temperature", [0, 1, 2])
+def test_generate(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
+                  use_async: bool,
+                  input_dict: dict[str, str],
+                  set_assistant: bool,
+                  expected_messages: list[dict],
+                  temperature: int):
+    (mock_client, mock_async_client) = mock_chat_completion
+    client = OpenAIChatService().get_client(model_name="test_model",
+                                            set_assistant=set_assistant,
+                                            temperature=temperature)
+
+    if use_async:
+        results = asyncio.run(client.generate_async(**input_dict))
+        mock_async_client.chat.completions.create.assert_called_once_with(model="test_model",
+                                                                          messages=expected_messages,
+                                                                          temperature=temperature)
+        mock_client.chat.completions.create.assert_not_called()
+
+    else:
+        results = client.generate(**input_dict)
+        mock_client.chat.completions.create.assert_called_once_with(model="test_model",
+                                                                    messages=expected_messages,
+                                                                    temperature=temperature)
+        mock_async_client.chat.completions.create.assert_not_called()
+
+    assert results == "test_output"
+
+
+@pytest.mark.parametrize("use_async", [True, False])
+@pytest.mark.parametrize("inputs, set_assistant, expected_messages",
+                         [({
+                             "prompt": ["prompt1", "prompt2"], "assistant": ["assistant1", "assistant2"]
+                         },
+                           True,
+                           [[{
+                               "role": "user", "content": "prompt1"
+                           }, {
+                               "role": "assistant", "content": "assistant1"
+                           }], [{
+                               "role": "user", "content": "prompt2"
+                           }, {
+                               "role": "assistant", "content": "assistant2"
+                           }]]),
+                          ({
+                              "prompt": ["prompt1", "prompt2"]
+                          },
+                           False, [[{
+                               "role": "user", "content": "prompt1"
+                           }], [{
+                               "role": "user", "content": "prompt2"
+                           }]])])
+@pytest.mark.parametrize("temperature", [0, 1, 2])
+def test_generate_batch(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
+                        use_async: bool,
+                        inputs: dict[str, list[str]],
+                        set_assistant: bool,
+                        expected_messages: list[list[dict]],
+                        temperature: int):
+    (mock_client, mock_async_client) = mock_chat_completion
+    client = OpenAIChatService().get_client(model_name="test_model",
+                                            set_assistant=set_assistant,
+                                            temperature=temperature)
+
+    expected_results = ["test_output" for _ in range(len(inputs["prompt"]))]
+    expected_calls = [
+        mock.call(model="test_model", messages=messages, temperature=temperature) for messages in expected_messages
+    ]
+
+    if use_async:
+        results = asyncio.run(client.generate_batch_async(inputs))
+        mock_async_client.chat.completions.create.assert_has_calls(expected_calls, any_order=False)
+        mock_client.chat.completions.create.assert_not_called()
+
+    else:
+        results = client.generate_batch(inputs)
+        mock_client.chat.completions.create.assert_has_calls(expected_calls, any_order=False)
+        mock_async_client.chat.completions.create.assert_not_called()
+
+    assert results == expected_results
+
+
+@pytest.mark.parametrize("completion", [[], [None]], ids=["no_choices", "no_content"])
+@pytest.mark.usefixtures("mock_chat_completion")
+def test_extract_completion_errors(completion: list):
+    client = OpenAIChatService().get_client(model_name="test_model")
+    mock_completion = mk_mock_openai_response(completion)
+
+    with pytest.raises(ValueError):
+        client._extract_completion(mock_completion)
 
 
 def test_get_client():
     service = OpenAIChatService()
     client = service.get_client(model_name="test_model")
-    assert isinstance(client, OpenAIChatClient)
+    assert client.model_name == "test_model"
+
+    client = service.get_client(model_name="test_model2", extra_arg="test_arg")
+
+    assert client.model_name == "test_model2"
+    assert client.model_kwargs == {"extra_arg": "test_arg"}
 
 
-@pytest.mark.parametrize("use_json", [True, False])
-@pytest.mark.parametrize("set_assistant", [True, False])
 @pytest.mark.parametrize("temperature", [0, 1, 2])
 @pytest.mark.parametrize("max_retries", [5, 10])
-@mock.patch("morpheus.llm.services.openai_chat_service.OpenAIChatClient")
-def test_get_client_passed_args(mock_client: mock.MagicMock,
-                                set_assistant: bool,
-                                use_json: bool,
+def test_get_client_passed_args(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
                                 temperature: int,
                                 max_retries: int):
     service = OpenAIChatService()
-    service.get_client(model_name="test_model",
-                       set_assistant=set_assistant,
-                       json=use_json,
-                       temperature=temperature,
-                       test='this',
-                       max_retries=max_retries)
+    client = service.get_client(model_name="test_model", temperature=temperature, test='this', max_retries=max_retries)
+
+    # Perform a dummy generate call
+    client.generate(prompt="test_prompt")
 
-    # Ensure the get_client method passed on the set_assistant and model kwargs
-    mock_client.assert_called_once_with(service,
-                                        model_name="test_model",
-                                        set_assistant=set_assistant,
-                                        json=use_json,
-                                        temperature=temperature,
-                                        test='this',
-                                        max_retries=max_retries)
+    # Ensure the model kwargs given to get_client were passed through to the underlying create call
+    assert_called_once_with_relaxed(mock_chat_completion[0].chat.completions.create,
+                                    model="test_model",
+                                    temperature=temperature,
+                                    test='this')
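Taken together, these service tests pin down the public flow: construct a service, get a client, generate. A minimal usage sketch, assuming only the API surface exercised above (`OpenAIChatService`, `get_client`, `generate`, `generate_batch_async`); the key value and prompts are placeholders, not part of the patch:

    import asyncio
    import os

    from morpheus.llm.services.openai_chat_service import OpenAIChatService

    os.environ["OPENAI_API_KEY"] = "my-key"  # placeholder; a key must be set to construct the client

    # Model kwargs given to get_client() (e.g. temperature) are forwarded to every
    # chat.completions.create call the client makes
    client = OpenAIChatService().get_client(model_name="test_model", temperature=0)

    result = client.generate(prompt="What is the capital of France?")
    batch = asyncio.run(client.generate_batch_async({'prompt': ["prompt1", "prompt2"]}))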
diff --git a/tests/llm/test_completion_pipe.py b/tests/llm/test_completion_pipe.py
index 106eb39586..e57e36f09f 100644
--- a/tests/llm/test_completion_pipe.py
+++ b/tests/llm/test_completion_pipe.py
@@ -28,7 +28,7 @@
 from morpheus.llm.nodes.extracter_node import ExtracterNode
 from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode
 from morpheus.llm.nodes.prompt_template_node import PromptTemplateNode
-from morpheus.llm.services.llm_service import LLMService
+from morpheus.llm.services.llm_service import LLMClient
 from morpheus.llm.services.nemo_llm_service import NeMoLLMService
 from morpheus.llm.services.openai_chat_service import OpenAIChatService
 from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler
@@ -42,9 +42,7 @@
 logger = logging.getLogger(__name__)
 
 
-def _build_engine(llm_service_cls: type[LLMService], model_name: str = "test_model"):
-    llm_service = llm_service_cls()
-    llm_client = llm_service.get_client(model_name=model_name)
+def _build_engine(llm_client: LLMClient):
     engine = LLMEngine()
 
     engine.add_node("extracter", node=ExtracterNode())
@@ -57,11 +55,7 @@ def _build_engine(llm_service_cls: type[LLMService], model_name: str = "test_mod
     return engine
 
 
-def _run_pipeline(config: Config,
-                  llm_service_cls: type[LLMService],
-                  countries: list[str],
-                  capital_responses: list[str],
-                  model_name: str = "test_model") -> dict:
+def _run_pipeline(config: Config, llm_client: LLMClient, countries: list[str], capital_responses: list[str]) -> dict:
     """
     Loosely patterned after `examples/llm/completion`
     """
@@ -81,7 +75,7 @@ def _run_pipeline(config: Config,
                                       task_type="llm_engine",
                                       task_payload=completion_task))
 
-    pipe.add_stage(LLMEngineStage(config, engine=_build_engine(llm_service_cls, model_name=model_name)))
+    pipe.add_stage(LLMEngineStage(config, engine=_build_engine(llm_client)))
 
     sink = pipe.add_stage(CompareDataFrameStage(config, compare_df=expected_df))
 
@@ -99,8 +93,10 @@ def test_completion_pipe_nemo(config: Config,
     # Set a dummy key to bypass the API key check
     with set_env(NGC_API_KEY="test"):
+        llm_client = NeMoLLMService().get_client(model_name="test_model")
+
         mock_nemollm.post_process_generate_response.side_effect = [{"text": response} for response in capital_responses]
-        results = _run_pipeline(config, NeMoLLMService, countries=countries, capital_responses=capital_responses)
+        results = _run_pipeline(config, llm_client, countries=countries, capital_responses=capital_responses)
 
         assert_results(results)
@@ -114,20 +110,21 @@ def test_completion_pipe_openai(config: Config,
         mk_mock_openai_response([response]) for response in capital_responses
     ]
 
-    results = _run_pipeline(config, OpenAIChatService, countries=countries, capital_responses=capital_responses)
-    assert_results(results)
-    mock_client.chat.completions.create.assert_not_called()
-    mock_async_client.chat.completions.create.assert_called()
+    with set_env(OPENAI_API_KEY="test"):
+        llm_client = OpenAIChatService().get_client(model_name="test_model")
+
+        results = _run_pipeline(config, llm_client, countries=countries, capital_responses=capital_responses)
+        assert_results(results)
+        mock_client.chat.completions.create.assert_not_called()
+        mock_async_client.chat.completions.create.assert_called()
 
 
 @pytest.mark.usefixtures("nemollm")
 @pytest.mark.usefixtures("ngc_api_key")
 def test_completion_pipe_integration_nemo(config: Config, countries: list[str], capital_responses: list[str]):
-    results = _run_pipeline(config,
-                            NeMoLLMService,
-                            countries=countries,
-                            capital_responses=capital_responses,
-                            model_name="gpt-43b-002")
+    llm_client = NeMoLLMService().get_client(model_name="gpt-43b-002")
+
+    results = _run_pipeline(config, llm_client, countries=countries, capital_responses=capital_responses)
     assert results['diff_cols'] == 0
     assert results['total_rows'] == len(countries)
     assert results['matching_rows'] + results['diff_rows'] == len(countries)
@@ -136,11 +133,9 @@ def test_completion_pipe_integration_nemo(config: Config, countries: list[str],
 @pytest.mark.usefixtures("openai")
 @pytest.mark.usefixtures("openai_api_key")
 def test_completion_pipe_integration_openai(config: Config, countries: list[str], capital_responses: list[str]):
-    results = _run_pipeline(config,
-                            OpenAIChatService,
-                            countries=countries,
-                            capital_responses=capital_responses,
-                            model_name="gpt-3.5-turbo")
+    llm_client = OpenAIChatService().get_client(model_name="gpt-3.5-turbo")
+
+    results = _run_pipeline(config, llm_client, countries=countries, capital_responses=capital_responses)
     assert results['diff_cols'] == 0
     assert results['total_rows'] == len(countries)
     assert results['matching_rows'] + results['diff_rows'] == len(countries)
diff --git a/tests/test_faiss_vector_db_service.py b/tests/test_faiss_vector_db_service.py
new file mode 100644
index 0000000000..98a428bbe3
--- /dev/null
+++ b/tests/test_faiss_vector_db_service.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import typing
+from pathlib import Path
+
+import pytest
+
+from morpheus.service.vdb.faiss_vdb_service import FaissVectorDBResourceService
+from morpheus.service.vdb.faiss_vdb_service import FaissVectorDBService
+
+if (typing.TYPE_CHECKING):
+    from langchain_core.embeddings import Embeddings
+else:
+    lc_core_embeddings = pytest.importorskip("langchain_core.embeddings", reason="langchain_core not installed")
+    Embeddings = lc_core_embeddings.Embeddings
+
+
+class FakeEmbedder(Embeddings):
+
+    def embed_query(self, text: str) -> list[float]:
+        # One-hot encoding based on the length of the text; texts of distinct lengths get
+        # orthogonal vectors, which keeps nearest-neighbor results deterministic in the tests
+        vec = [0.0] * 1024
+
+        vec[len(text) % 1024] = 1.0
+
+        return vec
+
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        return [self.embed_query(text) for text in texts]
+
+    async def aembed_query(self, text: str) -> list[float]:
+        return self.embed_query(text)
+
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
+        return self.embed_documents(texts)
+
+
+@pytest.fixture(scope="function", name="faiss_simple_store_dir")
+def faiss_simple_store_dir_fixture(tmpdir_path: Path):
+
+    from langchain_community.vectorstores.faiss import FAISS
+
+    embeddings = FakeEmbedder()
+
+    # Create a FAISS docstore for testing
+    index_store = FAISS.from_texts([str(x) * x for x in range(3)], embeddings, ids=[chr(x + 97) for x in range(3)])
+
+    index_store.save_local(str(tmpdir_path), index_name="index")
+
+    # Create a second index for testing
+    other_store = FAISS.from_texts([str(x) * x for x in range(3, 8)],
+                                   embeddings,
+                                   ids=[chr(x + 97) for x in range(3, 8)])
+    other_store.save_local(str(tmpdir_path), index_name="other_index")
+
+    return str(tmpdir_path)
+
+
+@pytest.fixture(scope="function", name="faiss_service")
+def faiss_service_fixture(faiss_simple_store_dir: str):
+    # Fixture for the FAISS service. The docstore itself is built in the faiss_simple_store_dir fixture;
+    # edit that fixture to change the embedding model, the documents, etc.
+    service = FaissVectorDBService(local_dir=faiss_simple_store_dir, embeddings=FakeEmbedder())
+    yield service
+
+
+def test_load_resource(faiss_service: FaissVectorDBService):
+
+    # Check the default implementation
+    resource = faiss_service.load_resource()
+    assert isinstance(resource, FaissVectorDBResourceService)
+
+    # Check specifying a name
+    resource = faiss_service.load_resource("index")
+    assert resource.describe()["index_name"] == "index"
+
+    # Check another name
+    resource = faiss_service.load_resource("other_index")
+    assert resource.describe()["index_name"] == "other_index"
+
+
+def test_describe(faiss_service: FaissVectorDBService):
+    desc_dict = faiss_service.load_resource().describe()
+
+    assert desc_dict["index_name"] == "index"
+    assert os.path.exists(desc_dict["folder_path"])
+    # The description may contain additional properties beyond these
+
+
+def test_count(faiss_service: FaissVectorDBService):
+
+    count = faiss_service.load_resource().count()
+    assert count == 3
+
+
+async def test_similarity_search(faiss_service: FaissVectorDBService):
+
+    vdb = faiss_service.load_resource()
+
+    query_vec = await faiss_service.embeddings.aembed_query("22")
+
+    k_1 = await vdb.similarity_search(embeddings=[query_vec], k=1)
+
+    assert len(k_1[0]) == 1
+    assert k_1[0][0]["page_content"] == "22"
+
+    k_3 = await vdb.similarity_search(embeddings=[query_vec], k=3)
+
+    assert len(k_3[0]) == 3
+    assert k_3[0][0]["page_content"] == "22"
+
+    # Exceed the number of documents in the docstore
+    k_5 = await vdb.similarity_search(embeddings=[query_vec], k=vdb.count() + 2)
+
+    assert len(k_5[0]) == vdb.count()
+    assert k_5[0][0]["page_content"] == "22"
+
+
+def test_has_store_object(faiss_service: FaissVectorDBService):
+    assert faiss_service.has_store_object("index")
+
+    assert faiss_service.has_store_object("other_index")
+
+    assert not faiss_service.has_store_object("not_an_index")
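The FAISS tests above imply the following end-to-end pattern. A minimal sketch, assuming only the API they exercise (`FaissVectorDBService`, `load_resource`, `count`, `similarity_search`); the directory path is a placeholder and `FakeEmbedder` stands in for a real langchain-core `Embeddings` implementation:

    import asyncio

    from morpheus.service.vdb.faiss_vdb_service import FaissVectorDBService

    async def query_index():
        # local_dir must contain a FAISS index previously written with save_local();
        # FakeEmbedder (from the test file above) is a stand-in for a real embedder
        service = FaissVectorDBService(local_dir="/path/to/faiss_dir", embeddings=FakeEmbedder())
        resource = service.load_resource("index")

        query_vec = await service.embeddings.aembed_query("22")
        hits = await resource.similarity_search(embeddings=[query_vec], k=min(3, resource.count()))
        return hits[0]  # list of documents, each with a "page_content" key

    asyncio.run(query_index())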