# Allow passing metadata to LangChainAgentNode._run_single (#1710)
* Allows passing arbitrary `metadata` into the agent.
* Updates a few imports to reduce the number of deprecation warnings.

Closes #1706
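
As a rough sketch of the new call pattern (mirroring the mocked-executor style of the unit tests below; `_run_single` is an internal method, so this is illustrative only, not a public-API example):

```python
# Minimal sketch only: a mocked AgentExecutor stands in for a real LangChain agent.
import asyncio
from unittest import mock

from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode

mock_agent_executor = mock.MagicMock()
mock_agent_executor.arun = mock.AsyncMock(return_value="frogs")

node = LangChainAgentNode(agent_executor=mock_agent_executor)

# The metadata dict is forwarded to AgentExecutor.arun alongside the regular inputs.
result = asyncio.run(node._run_single(prompt="name a reptile", metadata={"morpheus": "unittest"}))

assert result == "frogs"
mock_agent_executor.arun.assert_called_once_with(prompt="name a reptile", metadata={"morpheus": "unittest"})
```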

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1710
Authored by dagardner-nv on Jun 1, 2024 · commit bb51e61 (parent 580be43)
Showing 4 changed files with 153 additions and 14 deletions.

### morpheus/llm/nodes/langchain_agent_node.py (22 additions, 5 deletions)

```diff
@@ -45,18 +45,35 @@ def __init__(self, agent_executor: "AgentExecutor"):
     def get_input_names(self):
         return self._input_names
 
-    async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    @staticmethod
+    def _is_all_lists(data: dict[str, typing.Any]) -> bool:
+        return all(isinstance(v, list) for v in data.values())
 
-        all_lists = all(isinstance(v, list) for v in kwargs.values())
+    @staticmethod
+    def _transform_dict_of_lists(data: dict[str, typing.Any]) -> list[dict[str, typing.Any]]:
+        return [dict(zip(data, t)) for t in zip(*data.values())]
+
+    async def _run_single(self, metadata: dict[str, typing.Any] = None, **kwargs) -> dict[str, typing.Any]:
+
+        all_lists = self._is_all_lists(kwargs)
 
         # Check if all values are a list
         if all_lists:
 
             # Transform from dict[str, list[Any]] to list[dict[str, Any]]
-            input_list = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
+            input_list = self._transform_dict_of_lists(kwargs)
+
+            # If all metadata values are lists of the same length and the same length as the input list
+            # then transform them the same way as the input list
+            if (metadata is not None and self._is_all_lists(metadata)
+                    and all(len(v) == len(input_list) for v in metadata.values())):
+                metadata_list = self._transform_dict_of_lists(metadata)
+
+            else:
+                metadata_list = [metadata] * len(input_list)
 
             # Run multiple again
-            results_async = [self._run_single(**x) for x in input_list]
+            results_async = [self._run_single(metadata=metadata_list[i], **x) for (i, x) in enumerate(input_list)]
 
             results = await asyncio.gather(*results_async, return_exceptions=True)
 
@@ -67,7 +84,7 @@ async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing
 
         # We are not dealing with a list, so run single
         try:
-            return await self._agent_executor.arun(**kwargs)
+            return await self._agent_executor.arun(metadata=metadata, **kwargs)
         except Exception as e:
             logger.exception("Error running agent: %s", e)
             return e
```
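
The broadcast rule above can be seen in isolation with a small standalone sketch; `transform_dict_of_lists` below is a local re-implementation for illustration, not an import from Morpheus:

```python
# Standalone illustration of the dict-of-lists transform and the metadata broadcast rule.
inputs = {'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f']}
metadata = {'morpheus': ['unittest_0', 'unittest_1', 'unittest_2']}


def transform_dict_of_lists(data: dict) -> list[dict]:
    # dict[str, list] -> list[dict], one dict per row
    return [dict(zip(data, t)) for t in zip(*data.values())]


input_list = transform_dict_of_lists(inputs)
assert input_list == [{'a': 'b', 'c': 'd'}, {'a': 'c', 'c': 'e'}, {'a': 'd', 'c': 'f'}]

# Metadata is split per row only when every metadata value is a list matching the number of rows;
# otherwise the same metadata dict is reused for every row.
if all(isinstance(v, list) and len(v) == len(input_list) for v in metadata.values()):
    metadata_list = transform_dict_of_lists(metadata)
else:
    metadata_list = [metadata] * len(input_list)

assert metadata_list == [{'morpheus': 'unittest_0'}, {'morpheus': 'unittest_1'}, {'morpheus': 'unittest_2'}]
```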

### tests/llm/nodes/test_langchain_agent_node.py (124 additions, 3 deletions)

```diff
@@ -13,13 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
+import typing
+from operator import itemgetter
 from unittest import mock
 
 import pytest
 from langchain.agents import AgentType
 from langchain.agents import Tool
 from langchain.agents import initialize_agent
-from langchain.chat_models import ChatOpenAI  # pylint: disable=no-name-in-module
+from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
+from langchain.callbacks.manager import CallbackManagerForToolRun
+from langchain_community.chat_models import ChatOpenAI
+from langchain_core.tools import BaseTool
 
 from _utils.llm import execute_node
 from _utils.llm import mk_mock_langchain_tool
@@ -42,12 +48,16 @@ def test_get_input_names(mock_agent_executor: mock.MagicMock):
     "values,arun_return,expected_output,expected_calls",
     [({
         'prompt': "prompt1"
-    }, list(range(3)), list(range(3)), [mock.call(prompt="prompt1")]),
+    }, list(range(3)), list(range(3)), [mock.call(prompt="prompt1", metadata=None)]),
     ({
         'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f'], 'e': ['f', 'g', 'h']
     },
      list(range(3)), [list(range(3))] * 3,
-     [mock.call(a='b', c='d', e='f'), mock.call(a='c', c='e', e='g'), mock.call(a='d', c='f', e='h')])],
+     [
+         mock.call(a='b', c='d', e='f', metadata=None),
+         mock.call(a='c', c='e', e='g', metadata=None),
+         mock.call(a='d', c='f', e='h', metadata=None)
+     ])],
     ids=["not-lists", "all-lists"])
 def test_execute(
     mock_agent_executor: mock.MagicMock,
@@ -143,3 +153,114 @@ def test_execute_error(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMoc
 
     node = LangChainAgentNode(agent_executor=agent)
     assert isinstance(execute_node(node, input="input1"), RuntimeError)
+
+
+class MetadataSaverTool(BaseTool):
+    # The base class defines *args and **kwargs in the signature for _run and _arun requiring the arguments-differ
+    # pylint: disable=arguments-differ
+    name: str = "MetadataSaverTool"
+    description: str = "useful for when you need to know the name of a reptile"
+
+    saved_metadata: list[dict] = []
+
+    def _run(
+        self,
+        query: str,
+        run_manager: typing.Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        raise NotImplementedError("This tool only supports async")
+
+    async def _arun(
+        self,
+        query: str,
+        run_manager: typing.Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        assert query is not None  # avoiding unused-argument
+        assert run_manager is not None
+        self.saved_metadata.append(run_manager.metadata.copy())
+        return "frog"
+
+
+@pytest.mark.parametrize("metadata",
+                         [{
+                             "morpheus": "unittest"
+                         }, {
+                             "morpheus": ["unittest"]
+                         }, {
+                             "morpheus": [f"unittest_{i}" for i in range(3)]
+                         }],
+                         ids=["single-metadata", "single-metadata-list", "multiple-metadata-list"])
+def test_metadata(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], metadata: dict):
+    if isinstance(metadata['morpheus'], list):
+        num_meta = len(metadata['morpheus'])
+        input_data = [f"input_{i}" for i in range(num_meta)]
+        expected_result = [f"{input_val}: Yes!" for input_val in input_data]
+        expected_saved_metadata = [{"morpheus": meta} for meta in metadata['morpheus']]
+        response_per_input_counter = {input_val: 0 for input_val in input_data}
+    else:
+        num_meta = 1
+        input_data = "input_0"
+        expected_result = "input_0: Yes!"
+        expected_saved_metadata = [metadata.copy()]
+        response_per_input_counter = {input_data: 0}
+
+    check_tool_response = 'I should check Tool1\nAction: MetadataSaverTool\nAction Input: "name a reptile"'
+    final_response = 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: {}: Yes!'
+
+    # Tests the execute method of the LangChainAgentNode with mocked tools and chat completion
+    (_, mock_async_client) = mock_chat_completion
+
+    # Regex to find the actual prompt from the input which includes the REACT and tool description boilerplate
+    input_re = re.compile(r'^Question: (input_\d+)$', re.MULTILINE)
+
+    def mock_llm_chat(*_, messages, **__):
+        """
+        This method avoids a race condition when running in async mode over multiple inputs, ensuring that the
+        final response is only given for an input after the initial check tool response.
+        """
+
+        query = None
+        for msg in messages:
+            if msg['role'] == 'user':
+                query = msg['content']
+
+        assert query is not None
+
+        match = input_re.search(query)
+        assert match is not None
+
+        input_key = match.group(1)
+
+        call_count = response_per_input_counter[input_key]
+
+        if call_count == 0:
+            response = check_tool_response
+        else:
+            response = final_response.format(input_key)
+
+        response_per_input_counter[input_key] += 1
+
+        return mk_mock_openai_response([response])
+
+    mock_async_client.chat.completions.create.side_effect = mock_llm_chat
+
+    llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key")
+
+    metadata_saver_tool = MetadataSaverTool()
+
+    tools = [metadata_saver_tool]
+
+    agent = initialize_agent(tools,
+                             llm_chat,
+                             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                             verbose=True,
+                             handle_parsing_errors=True,
+                             early_stopping_method="generate",
+                             return_intermediate_steps=False)
+
+    node = LangChainAgentNode(agent_executor=agent)
+
+    assert execute_node(node, input=input_data, metadata=metadata) == expected_result
+
+    # Since we are running in async mode, we will need to sort saved metadata
+    assert sorted(metadata_saver_tool.saved_metadata, key=itemgetter('morpheus')) == expected_saved_metadata
```
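
For orientation, the mocked exchange that `test_metadata` drives looks roughly like this (an illustrative transcript of the two canned LLM responses, not executable test code): LangChain surfaces the `metadata` passed to `AgentExecutor.arun` on the tool's `run_manager`, which is what `MetadataSaverTool` records.

```python
# First mocked LLM response: the ReAct agent decides to call the tool.
turn_1 = 'I should check Tool1\nAction: MetadataSaverTool\nAction Input: "name a reptile"'
# MetadataSaverTool._arun then runs, appends run_manager.metadata (e.g. {"morpheus": "unittest"})
# to saved_metadata, and returns "frog".

# Second mocked LLM response: the agent produces the final answer for that input.
turn_2 = 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: input_0: Yes!'
# AgentExecutor.arun therefore resolves to "input_0: Yes!", matching expected_result.
```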

### tests/llm/nodes/test_langchain_agent_node_pipe.py (1 addition, 1 deletion)

```diff
@@ -45,7 +45,7 @@ def test_pipeline(config: Config, dataset_cudf: DatasetManager, mock_agent_execu
 
     mock_agent_executor.arun.return_value = 'frogs'
     expected_df['response'] = 'frogs'
-    expected_calls = [mock.call(prompt=x) for x in expected_df['v3'].values_host]
+    expected_calls = [mock.call(prompt=x, metadata=None) for x in expected_df['v3'].values_host]
 
     task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": ['v3']}}
```

### tests/llm/test_agents_simple_pipe.py (6 additions, 5 deletions)

```diff
@@ -17,13 +17,13 @@
 import re
 from unittest import mock
 
-import langchain
 import pytest
 from langchain.agents import AgentType
 from langchain.agents import initialize_agent
 from langchain.agents import load_tools
 from langchain.agents.tools import Tool
-from langchain.utilities import serpapi
+from langchain_community.llms import OpenAI  # pylint: disable=no-name-in-module
+from langchain_community.utilities import serpapi
 
 import cudf
 
@@ -50,7 +50,7 @@ def questions_fixture():
 
 def _build_agent_executor(model_name: str):
 
-    llm = langchain.OpenAI(model=model_name, temperature=0, cache=False)
+    llm = OpenAI(model=model_name, temperature=0, cache=False)
 
     # Explicitly construct the serpapi tool, loading it via load_tools makes it too difficult to mock
     tools = [
@@ -125,8 +125,9 @@ def test_agents_simple_pipe_integration_openai(config: Config, questions: list[s
 
 
 @pytest.mark.usefixtures("openai", "restore_environ")
-@mock.patch("langchain.utilities.serpapi.SerpAPIWrapper.aresults")
-@mock.patch("langchain.OpenAI._agenerate", autospec=True)  # autospec is needed as langchain will inspect the function
+@mock.patch("langchain_community.utilities.serpapi.SerpAPIWrapper.aresults")
+@mock.patch("langchain_community.llms.OpenAI._agenerate",
+            autospec=True)  # autospec is needed as langchain will inspect the function
 def test_agents_simple_pipe(mock_openai_agenerate: mock.AsyncMock,
                             mock_serpapi_aresults: mock.AsyncMock,
                             config: Config,
```
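
The import churn in these test files follows LangChain's split of third-party integrations into `langchain_community`; roughly, the old paths map to the new ones as below (a sketch of the mapping used in this diff, not an exhaustive migration guide):

```python
# Old paths (emit deprecation warnings, as noted in the commit message):
# from langchain.chat_models import ChatOpenAI
# from langchain.utilities import serpapi
# import langchain; llm = langchain.OpenAI(...)

# New paths used in this PR:
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI
from langchain_community.utilities import serpapi
```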
