Skip to content

Commit

Permalink
Merge pull request #1327 from dagardner-nv/david-fea-sherlock-lang-chain-agent
Browse files Browse the repository at this point in the history

Docstrings and tests for LangChainAgentNode
  • Loading branch information
dagardner-nv authored Oct 31, 2023
2 parents 3c4802e + b6168e5 commit 4692529
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 5 deletions.
20 changes: 15 additions & 5 deletions morpheus/llm/nodes/langchain_agent_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,28 @@

import asyncio
import logging

from langchain.agents import AgentExecutor
import typing

from morpheus.llm import LLMContext
from morpheus.llm import LLMNodeBase

logger = logging.getLogger(__name__)

if typing.TYPE_CHECKING:
from langchain.agents import AgentExecutor


class LangChainAgentNode(LLMNodeBase):
"""
Executes a LangChain agent in an LLMEngine
Parameters
----------
agent_executor : AgentExecutor
The agent executor to use to execute.
"""

def __init__(self, agent_executor: AgentExecutor):
def __init__(self, agent_executor: "AgentExecutor"):
super().__init__()

self._agent_executor = agent_executor
Expand All @@ -35,7 +45,7 @@ def __init__(self, agent_executor: AgentExecutor):
def get_input_names(self):
return self._input_names

async def _run_single(self, **kwargs):
async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:

all_lists = all(isinstance(v, list) for v in kwargs.values())

Expand All @@ -58,7 +68,7 @@ async def _run_single(self, **kwargs):
# We are not dealing with a list, so run single
return await self._agent_executor.arun(**kwargs)

async def execute(self, context: LLMContext):
async def execute(self, context: LLMContext) -> LLMContext:

input_dict = context.get_inputs()

Expand Down
9 changes: 9 additions & 0 deletions tests/llm/nodes/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ def mock_llm_client_fixture():
mock_client.return_value = mock_client
mock_client.generate_batch_async = mock.AsyncMock()
return mock_client


@pytest.fixture(name="mock_agent_executor")
def mock_agent_executor_fixture():
    """
    Provide a mock object that stands in for an agent executor.

    The mock exposes ``input_keys`` as ``["prompt"]``, an awaitable ``arun`` (an ``AsyncMock``) whose calls can be
    inspected by the tests, and returns itself when called.
    """
    executor = mock.MagicMock(input_keys=["prompt"], arun=mock.AsyncMock())

    # Calling the mock hands back the very same mock instance.
    executor.return_value = executor
    return executor
57 changes: 57 additions & 0 deletions tests/llm/nodes/test_langchain_agent_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from _utils.llm import execute_node
from morpheus.llm import LLMNodeBase
from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode


def test_constructor(mock_agent_executor: mock.MagicMock):
    """A freshly constructed `LangChainAgentNode` must be an `LLMNodeBase` instance."""
    agent_node = LangChainAgentNode(agent_executor=mock_agent_executor)

    assert isinstance(agent_node, LLMNodeBase)


def test_get_input_names(mock_agent_executor: mock.MagicMock):
    """The node must report ``["prompt"]`` as its input names when built from the mock executor."""
    agent_node = LangChainAgentNode(agent_executor=mock_agent_executor)
    input_names = agent_node.get_input_names()

    assert input_names == ["prompt"]


@pytest.mark.parametrize(
    "values,arun_return,expected_output,expected_calls",
    [
        # Plain (non-list) input values: a single `arun` call is expected.
        (
            {'prompt': "prompt1"},
            list(range(3)),
            list(range(3)),
            [mock.call(prompt="prompt1")],
        ),
        # Every input value is a list: one `arun` call is expected per position across the lists.
        (
            {'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f'], 'e': ['f', 'g', 'h']},
            list(range(3)),
            [list(range(3))] * 3,
            [
                mock.call(a='b', c='d', e='f'),
                mock.call(a='c', c='e', e='g'),
                mock.call(a='d', c='f', e='h'),
            ],
        ),
    ],
    ids=["not-lists", "all-lists"])
def test_execute(
    mock_agent_executor: mock.MagicMock,
    values: dict,
    arun_return: list,
    expected_output: list,
    expected_calls: list[mock.call],
):
    """
    Execute the node with the given inputs and check both the returned output and the `arun` calls.

    Parameters
    ----------
    mock_agent_executor : mock.MagicMock
        Mock executor whose `arun` is configured to return `arun_return`.
    values : dict
        Keyword inputs forwarded to `execute_node`.
    arun_return : list
        Value every `arun` invocation returns.
    expected_output : list
        Expected result of `execute_node`.
    expected_calls : list[mock.call]
        Calls that `arun` must have received, in order.
    """
    mock_agent_executor.arun.return_value = arun_return
    agent_node = LangChainAgentNode(agent_executor=mock_agent_executor)

    actual_output = execute_node(agent_node, **values)

    assert actual_output == expected_output
    mock_agent_executor.arun.assert_has_calls(expected_calls)
65 changes: 65 additions & 0 deletions tests/llm/nodes/test_langchain_agent_node_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from _utils import assert_results
from _utils.dataset_manager import DatasetManager
from morpheus.config import Config
from morpheus.llm import LLMEngine
from morpheus.llm.nodes.extracter_node import ExtracterNode
from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode
from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler
from morpheus.messages import ControlMessage
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.llm.llm_engine_stage import LLMEngineStage
from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage


def _build_engine(mock_agent_executor: mock.MagicMock) -> LLMEngine:
    """
    Assemble an `LLMEngine` with an extracter node feeding a `LangChainAgentNode`.

    Parameters
    ----------
    mock_agent_executor : mock.MagicMock
        Mock executor handed to the `LangChainAgentNode`.

    Returns
    -------
    LLMEngine
        Engine with nodes "extracter" and "chain" and a `SimpleTaskHandler` reading from "/chain".
    """
    llm_engine = LLMEngine()
    llm_engine.add_node("extracter", node=ExtracterNode())

    agent_node = LangChainAgentNode(agent_executor=mock_agent_executor)
    llm_engine.add_node("chain", inputs=["/extracter"], node=agent_node)

    llm_engine.add_task_handler(inputs=["/chain"], handler=SimpleTaskHandler())
    return llm_engine


@pytest.mark.use_python
def test_pipeline(config: Config, dataset_cudf: DatasetManager, mock_agent_executor: mock.MagicMock):
    """
    Run the `LangChainAgentNode` inside a full linear pipeline using the mock executor.

    The mocked `arun` always answers 'frogs', so the expected output is the input dataframe plus a 'response' column
    filled with 'frogs'. `arun` must have been called with each value of the 'v3' column as its prompt.
    """
    source_df = dataset_cudf["filter_probs.csv"]
    compare_df = source_df.copy(deep=True)

    mock_agent_executor.arun.return_value = 'frogs'
    compare_df['response'] = 'frogs'

    # One `arun` call is expected per value of the 'v3' column.
    expected_calls = [mock.call(prompt=prompt) for prompt in compare_df['v3'].values_host]

    task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": ['v3']}}

    pipe = LinearPipeline(config)

    source_stage = InMemorySourceStage(config, dataframes=[source_df])
    pipe.set_source(source_stage)

    deserialize_stage = DeserializeStage(config,
                                         message_type=ControlMessage,
                                         task_type="llm_engine",
                                         task_payload=task_payload)
    pipe.add_stage(deserialize_stage)

    engine_stage = LLMEngineStage(config, engine=_build_engine(mock_agent_executor))
    pipe.add_stage(engine_stage)

    compare_stage = CompareDataFrameStage(config, compare_df=compare_df)
    sink = pipe.add_stage(compare_stage)

    pipe.run()

    assert_results(sink.get_results())
    mock_agent_executor.arun.assert_has_calls(expected_calls)

0 comments on commit 4692529

Please sign in to comment.