Allow passing metadata to LangChainAgentNode._run_single #1710

Merged
27 changes: 22 additions & 5 deletions morpheus/llm/nodes/langchain_agent_node.py
@@ -45,18 +45,35 @@ def __init__(self, agent_executor: "AgentExecutor"):
def get_input_names(self):
return self._input_names

async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
@staticmethod
def _is_all_lists(data: dict[str, typing.Any]) -> bool:
return all(isinstance(v, list) for v in data.values())

all_lists = all(isinstance(v, list) for v in kwargs.values())
@staticmethod
def _transform_dict_of_lists(data: dict[str, typing.Any]) -> list[dict[str, typing.Any]]:
return [dict(zip(data, t)) for t in zip(*data.values())]

async def _run_single(self, metadata: dict[str, typing.Any] = None, **kwargs) -> dict[str, typing.Any]:

all_lists = self._is_all_lists(kwargs)

# Check if all values are a list
if all_lists:

# Transform from dict[str, list[Any]] to list[dict[str, Any]]
input_list = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
input_list = self._transform_dict_of_lists(kwargs)

# If all metadata values are lists of the same length and the same length as the input list
# then transform them the same way as the input list
if (metadata is not None and self._is_all_lists(metadata)
and all(len(v) == len(input_list) for v in metadata.values())):
metadata_list = self._transform_dict_of_lists(metadata)

else:
metadata_list = [metadata] * len(input_list)

# Run the agent once for each input
results_async = [self._run_single(**x) for x in input_list]
results_async = [self._run_single(metadata=metadata_list[i], **x) for (i, x) in enumerate(input_list)]

results = await asyncio.gather(*results_async, return_exceptions=True)

@@ -67,7 +84,7 @@ async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing

# We are not dealing with a list, so run single
try:
return await self._agent_executor.arun(**kwargs)
return await self._agent_executor.arun(metadata=metadata, **kwargs)
except Exception as e:
logger.exception("Error running agent: %s", e)
return e
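
For illustration, the broadcast logic added above can be exercised on its own. Below is a minimal standalone sketch (the helper names mirror the new static methods; the example input and metadata values are arbitrary):

import typing

def _is_all_lists(data: dict[str, typing.Any]) -> bool:
    # True only when every value in the dict is a list
    return all(isinstance(v, list) for v in data.values())

def _transform_dict_of_lists(data: dict[str, typing.Any]) -> list[dict[str, typing.Any]]:
    # {'a': [1, 2], 'b': [3, 4]} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
    return [dict(zip(data, t)) for t in zip(*data.values())]

kwargs = {'input': ['q0', 'q1', 'q2']}
metadata = {'morpheus': ['meta_0', 'meta_1', 'meta_2']}

input_list = _transform_dict_of_lists(kwargs)
# [{'input': 'q0'}, {'input': 'q1'}, {'input': 'q2'}]

if metadata is not None and _is_all_lists(metadata) and all(len(v) == len(input_list) for v in metadata.values()):
    # Per-input metadata: split it the same way as the inputs
    metadata_list = _transform_dict_of_lists(metadata)
else:
    # Scalar (or missing) metadata: broadcast the same dict to every input
    metadata_list = [metadata] * len(input_list)

print(metadata_list)
# [{'morpheus': 'meta_0'}, {'morpheus': 'meta_1'}, {'morpheus': 'meta_2'}]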
127 changes: 124 additions & 3 deletions tests/llm/nodes/test_langchain_agent_node.py
@@ -13,13 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import typing
from operator import itemgetter
from unittest import mock

import pytest
from langchain.agents import AgentType
from langchain.agents import Tool
from langchain.agents import initialize_agent
from langchain.chat_models import ChatOpenAI # pylint: disable=no-name-in-module
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain_community.chat_models import ChatOpenAI
from langchain_core.tools import BaseTool

from _utils.llm import execute_node
from _utils.llm import mk_mock_langchain_tool
@@ -42,12 +48,16 @@ def test_get_input_names(mock_agent_executor: mock.MagicMock):
"values,arun_return,expected_output,expected_calls",
[({
'prompt': "prompt1"
}, list(range(3)), list(range(3)), [mock.call(prompt="prompt1")]),
}, list(range(3)), list(range(3)), [mock.call(prompt="prompt1", metadata=None)]),
({
'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f'], 'e': ['f', 'g', 'h']
},
list(range(3)), [list(range(3))] * 3,
[mock.call(a='b', c='d', e='f'), mock.call(a='c', c='e', e='g'), mock.call(a='d', c='f', e='h')])],
[
mock.call(a='b', c='d', e='f', metadata=None),
mock.call(a='c', c='e', e='g', metadata=None),
mock.call(a='d', c='f', e='h', metadata=None)
])],
ids=["not-lists", "all-lists"])
def test_execute(
mock_agent_executor: mock.MagicMock,
@@ -143,3 +153,114 @@ def test_execute_error(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMoc

node = LangChainAgentNode(agent_executor=agent)
assert isinstance(execute_node(node, input="input1"), RuntimeError)


class MetadataSaverTool(BaseTool):
# The base class defines *args and **kwargs in the signature for _run and _arun, requiring the arguments-differ override
# pylint: disable=arguments-differ
name: str = "MetadataSaverTool"
description: str = "useful for when you need to know the name of a reptile"

saved_metadata: list[dict] = []

def _run(
self,
query: str,
run_manager: typing.Optional[CallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("This tool only supports async")

async def _arun(
self,
query: str,
run_manager: typing.Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
assert query is not None # avoiding unused-argument
assert run_manager is not None
self.saved_metadata.append(run_manager.metadata.copy())
return "frog"


@pytest.mark.parametrize("metadata",
[{
"morpheus": "unittest"
}, {
"morpheus": ["unittest"]
}, {
"morpheus": [f"unittest_{i}" for i in range(3)]
}],
ids=["single-metadata", "single-metadata-list", "multiple-metadata-list"])
def test_metadata(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], metadata: dict):
if isinstance(metadata['morpheus'], list):
num_meta = len(metadata['morpheus'])
input_data = [f"input_{i}" for i in range(num_meta)]
expected_result = [f"{input_val}: Yes!" for input_val in input_data]
expected_saved_metadata = [{"morpheus": meta} for meta in metadata['morpheus']]
response_per_input_counter = {input_val: 0 for input_val in input_data}
else:
num_meta = 1
input_data = "input_0"
expected_result = "input_0: Yes!"
expected_saved_metadata = [metadata.copy()]
response_per_input_counter = {input_data: 0}

check_tool_response = 'I should check Tool1\nAction: MetadataSaverTool\nAction Input: "name a reptile"'
final_response = 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: {}: Yes!'

# Tests the execute method of the LangChainAgentNode with mocked tools and chat completion
(_, mock_async_client) = mock_chat_completion

# Regex to find the actual prompt from the input which includes the REACT and tool description boilerplate
input_re = re.compile(r'^Question: (input_\d+)$', re.MULTILINE)

def mock_llm_chat(*_, messages, **__):
"""
This method avoids a race condition when running in async mode over multiple inputs, ensuring that the final
response for an input is only given after the initial check-tool response.
"""

query = None
for msg in messages:
if msg['role'] == 'user':
query = msg['content']

assert query is not None

match = input_re.search(query)
assert match is not None

input_key = match.group(1)

call_count = response_per_input_counter[input_key]

if call_count == 0:
response = check_tool_response
else:
response = final_response.format(input_key)

response_per_input_counter[input_key] += 1

return mk_mock_openai_response([response])

mock_async_client.chat.completions.create.side_effect = mock_llm_chat

llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key")

metadata_saver_tool = MetadataSaverTool()

tools = [metadata_saver_tool]

agent = initialize_agent(tools,
llm_chat,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
handle_parsing_errors=True,
early_stopping_method="generate",
return_intermediate_steps=False)

node = LangChainAgentNode(agent_executor=agent)

assert execute_node(node, input=input_data, metadata=metadata) == expected_result

# Since we are running in async mode, we will need to sort saved metadata
assert sorted(metadata_saver_tool.saved_metadata, key=itemgetter('morpheus')) == expected_saved_metadata
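
End to end, the new test drives the node the way a caller would. A short usage sketch, assuming the execute_node test helper from _utils.llm and an already-constructed agent executor (both shown in the test above):

from _utils.llm import execute_node
from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode

# One metadata entry per input row, so each agent invocation receives its own
# {'morpheus': ...} dict via the per-input metadata broadcast shown earlier.
node = LangChainAgentNode(agent_executor=agent)  # `agent` built e.g. via initialize_agent(...)

results = execute_node(node,
                       input=["input_0", "input_1", "input_2"],
                       metadata={"morpheus": [f"unittest_{i}" for i in range(3)]})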
2 changes: 1 addition & 1 deletion tests/llm/nodes/test_langchain_agent_node_pipe.py
@@ -45,7 +45,7 @@ def test_pipeline(config: Config, dataset_cudf: DatasetManager, mock_agent_execu

mock_agent_executor.arun.return_value = 'frogs'
expected_df['response'] = 'frogs'
expected_calls = [mock.call(prompt=x) for x in expected_df['v3'].values_host]
expected_calls = [mock.call(prompt=x, metadata=None) for x in expected_df['v3'].values_host]

task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": ['v3']}}

11 changes: 6 additions & 5 deletions tests/llm/test_agents_simple_pipe.py
@@ -17,13 +17,13 @@
import re
from unittest import mock

import langchain
import pytest
from langchain.agents import AgentType
from langchain.agents import initialize_agent
from langchain.agents import load_tools
from langchain.agents.tools import Tool
from langchain.utilities import serpapi
from langchain_community.llms import OpenAI # pylint: disable=no-name-in-module
from langchain_community.utilities import serpapi

import cudf

@@ -50,7 +50,7 @@ def questions_fixture():

def _build_agent_executor(model_name: str):

llm = langchain.OpenAI(model=model_name, temperature=0, cache=False)
llm = OpenAI(model=model_name, temperature=0, cache=False)

# Explicitly construct the serpapi tool, loading it via load_tools makes it too difficult to mock
tools = [
@@ -125,8 +125,9 @@ def test_agents_simple_pipe_integration_openai(config: Config, questions: list[s


@pytest.mark.usefixtures("openai", "restore_environ")
@mock.patch("langchain.utilities.serpapi.SerpAPIWrapper.aresults")
@mock.patch("langchain.OpenAI._agenerate", autospec=True) # autospec is needed as langchain will inspect the function
@mock.patch("langchain_community.utilities.serpapi.SerpAPIWrapper.aresults")
@mock.patch("langchain_community.llms.OpenAI._agenerate",
autospec=True) # autospec is needed as langchain will inspect the function
def test_agents_simple_pipe(mock_openai_agenerate: mock.AsyncMock,
mock_serpapi_aresults: mock.AsyncMock,
config: Config,