Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RetrieverNode Tests and Docs #1340

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions morpheus/llm/nodes/retriever_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@


class RetrieverNode(LLMNodeBase):
"""
Node for retrieving data from a vector database based on embeddings.

Parameters
----------
embedding : typing.Callable[[list[str]], typing.Coroutine[typing.Any, typing.Any, list[list[float]]]] | None
Callable function for generating vector embeddings. Default is None.
service : VectorDBResourceService
Vector database resource service for executing similarity searches.
similarity_search_kwargs : dict
Additional keyword arguments for the similarity search.
"""

def __init__(
self,
Expand All @@ -38,12 +50,33 @@ def __init__(
self._similarity_search_kwargs = similarity_search_kwargs

def get_input_names(self) -> list[str]:
"""
Get the input names for the RetrieverNode.

Returns
-------
list[str]
List of input names for the RetrieverNode.
"""
if (self._embedding is None):
return ["embedding"]

return ["query"]

async def execute(self, context: LLMContext):
"""
Execute the retrieval process based on the provided context.

Parameters
----------
context : LLMContext
Context object containing necessary information for execution.

Returns
-------
LLMContext
Updated context object after the execution.
"""

if (self._embedding is not None):
# Get the keys from the task
Expand Down
49 changes: 49 additions & 0 deletions tests/llm/nodes/test_llm_retriever_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
from unittest import mock

import pytest

from _utils.llm import execute_node
from morpheus.llm import LLMNodeBase
from morpheus.llm.nodes.retriever_node import RetrieverNode


@pytest.mark.parametrize("embedding", [None, mock.AsyncMock()])
def test_constructor(embedding: typing.Callable | None):
mock_vdb_service = mock.MagicMock()
node = RetrieverNode(embedding=embedding, service=mock_vdb_service)
assert isinstance(node, LLMNodeBase)


@pytest.mark.parametrize("embedding, expected", [(None, ["embedding"]), (mock.AsyncMock(), ["query"])])
def test_get_input_names(embedding: typing.Callable | None, expected: list[str]):
mock_vdb_service = mock.MagicMock()
node = RetrieverNode(embedding=embedding, service=mock_vdb_service)
assert node.get_input_names() == expected


@pytest.mark.parametrize("embedding", [None, mock.AsyncMock(return_value=[[1.2, 2.3, 3.4], [4.5, 5.6, 6.7]])])
def test_execute(embedding: mock.AsyncMock | None):
mock_vdb_service = mock.MagicMock()
mock_vdb_service.similarity_search = mock.AsyncMock(return_value=[[1, 2, 3], [4, 5, 6]])

node = RetrieverNode(embedding=embedding, service=mock_vdb_service)

expected_output = [[1, 2, 3], [4, 5, 6]]

assert execute_node(node, query=["query"]) == expected_output
125 changes: 125 additions & 0 deletions tests/llm/nodes/test_llm_retriever_node_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

import cudf

from morpheus.config import Config
from morpheus.llm import LLMEngine
from morpheus.llm.nodes.extracter_node import ExtracterNode
from morpheus.llm.nodes.retriever_node import RetrieverNode
from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler
from morpheus.messages import ControlMessage
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.service.vdb.milvus_vector_db_service import MilvusVectorDBResourceService
from morpheus.service.vdb.milvus_vector_db_service import MilvusVectorDBService
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.llm.llm_engine_stage import LLMEngineStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage


@pytest.fixture(scope="module", name="milvus_service")
def milvus_service_fixture(milvus_server_uri: str):
service = MilvusVectorDBService(uri=milvus_server_uri)
yield service


def _build_engine(vdb_service, **similarity_search_kwargs) -> LLMEngine:
mock_embedding = mock.AsyncMock(return_value=[[1.2, 2.3, 3.4], [4.5, 5.6, 6.7]])
engine = LLMEngine()
engine.add_node("extracter", node=ExtracterNode())
engine.add_node("retriever",
inputs=["/extracter"],
node=RetrieverNode(service=vdb_service, embedding=mock_embedding, **similarity_search_kwargs))
engine.add_task_handler(inputs=["/retriever"], handler=SimpleTaskHandler())

return engine


@pytest.mark.use_python
def test_pipeline(config: Config):
expected_output = [[1, 2, 3], [4, 5, 6]]

values = {'prompt': ["prompt1", "prompt2"]}
input_df = cudf.DataFrame(values)
expected_df = input_df.copy(deep=True)
expected_df["response"] = expected_output

task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": sorted(values.keys())}}

mock_vdb_service = mock.MagicMock()
mock_vdb_service.similarity_search = mock.AsyncMock(return_value=[[1, 2, 3], [4, 5, 6]])

pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, dataframes=[input_df]))
pipe.add_stage(
DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=task_payload))
pipe.add_stage(LLMEngineStage(config, engine=_build_engine(vdb_service=mock_vdb_service)))
sink = pipe.add_stage(InMemorySinkStage(config))

pipe.run()

message = sink.get_messages()[0]
assert isinstance(message, ControlMessage)
actual_df = message.payload().df

# Using equals, as CompareDataFrameStage fails to compare.
assert actual_df.to_pandas().equals(expected_df.to_pandas())


@pytest.mark.use_python
def test_pipeline_with_milvus(config: Config,
milvus_service: MilvusVectorDBService,
idx_part_collection_config: dict,
milvus_data: list[dict]):

collection_name = "test_search_with_data_collection"
# Make sure to drop any existing collection from previous runs.
milvus_service.drop(collection_name)
# Create a collection.
milvus_service.create(collection_name, **idx_part_collection_config)
resource_service: MilvusVectorDBResourceService = milvus_service.load_resource(name=collection_name)
# Insert data into collection
resource_service.insert(milvus_data)

# Define a similarity_search filter.
expr = "age==26 or age==27"

values = {'prompt': ["prompt1", "prompt2"]}
input_df = cudf.DataFrame(values)
expected_df = input_df.copy(deep=True)
expected_df["response"] = [[{'0': 27, '1': 2}, {'0': 26, '1': 1}], [{'0': 27, '1': 2}, {'0': 26, '1': 1}]]

task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": sorted(values.keys())}}

pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, dataframes=[input_df]))
pipe.add_stage(
DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=task_payload))
pipe.add_stage(LLMEngineStage(config, engine=_build_engine(vdb_service=resource_service, expr=expr)))
sink = pipe.add_stage(InMemorySinkStage(config))

pipe.run()

message = sink.get_messages()[0]
assert isinstance(message, ControlMessage)
actual_df = message.payload().df

# Using equals, as CompareDataFrameStage fails to compare.
assert actual_df.to_pandas().equals(expected_df.to_pandas())
Loading