Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added multi agent tests #13

Merged
merged 17 commits into from
Oct 15, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Fixed a bug that prevented AsyncSqlite checkpoint to access the DB in streamed endpoints.
- Fixed a bug that caused some unit tests to fail due to a change in how httpx_mock works in version 0.32

## [0.1.0] - 19.09.2024

### Added
- Update readme
- Extra multi agent unit tests

### Removed
- Github action to create the docs.
Expand Down
2 changes: 2 additions & 0 deletions tests/agents/test_simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from neuroagent.agents import AgentOutput, AgentStep, SimpleChatAgent


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
cszsol marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.asyncio
async def test_arun(fake_llm_with_tools, httpx_mock):
llm, tools, fake_responses = await anext(fake_llm_with_tools)
Expand Down Expand Up @@ -64,6 +65,7 @@ async def test_arun(fake_llm_with_tools, httpx_mock):
assert len(messages_list) == 10


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_astream(fake_llm_with_tools, httpx_mock):
llm, tools, fake_responses = await anext(fake_llm_with_tools)
Expand Down
Empty file added tests/multi_agents/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions tests/multi_agents/test_supervisor_multi_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from neuroagent.multi_agents.supervisor_multi_agent import AgentState

from src.neuroagent.multi_agents import SupervisorMultiAgent


def test_create_main_agent():
mock_llm = MagicMock()
bind_function_result = MagicMock()
bind_function_result.__ror__.return_value = {}
mock_llm.bind_functions.return_value = bind_function_result
data = {"llm": mock_llm, "agents": [("agent1",)]}
from src.neuroagent.multi_agents import SupervisorMultiAgent
cszsol marked this conversation as resolved.
Show resolved Hide resolved

result = SupervisorMultiAgent.create_main_agent(data)
assert "main_agent" in result
assert "summarizer" in result

cszsol marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.asyncio
async def test_agent_node():
mock_message = HumanMessage(
content="hello",
name="test_agent",
)

async def mock_ainvoke(_):
return {"messages": [mock_message]}

agent_state = MagicMock()
agent = MagicMock()
cszsol marked this conversation as resolved.
Show resolved Hide resolved
agent.ainvoke = mock_ainvoke
from src.neuroagent.multi_agents import SupervisorMultiAgent
cszsol marked this conversation as resolved.
Show resolved Hide resolved

agent_node_test = await SupervisorMultiAgent.agent_node(
agent_state, agent, "test_agent"
)

assert isinstance(agent_node_test, dict)
assert "messages" in agent_node_test
assert len(agent_node_test["messages"]) == 1
assert agent_node_test["messages"][0].content == "hello"
assert agent_node_test["messages"][0].name == "test_agent"


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
cszsol marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.asyncio
async def test_summarizer_node():
class FakeChatModel(GenericFakeChatModel):
def bind_tools(self, functions: list):
return self

def bind_functions(self, **kwargs):
return self

fake_state = AgentState(
messages=[
HumanMessage(
content="What is the airspeed velocity of an unladen swallow?"
),
SystemMessage(content="11 m/s"),
]
)

mock_llm = FakeChatModel(messages=iter([]))
agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm)

mock_message = SystemMessage(
content="hello",
name="test_agent",
)

mock_summarizer = MagicMock()
cszsol marked this conversation as resolved.
Show resolved Hide resolved
mock_summarizer.ainvoke = AsyncMock()
mock_summarizer.ainvoke.return_value = mock_message
agent.summarizer = mock_summarizer
result = await agent.summarizer_node(fake_state)
assert result["messages"][0].content == "hello"


def test_create_graph():
class FakeChatModel(GenericFakeChatModel):
def bind_tools(self, functions: list):
return self

def bind_functions(self, **kwargs):
return self
cszsol marked this conversation as resolved.
Show resolved Hide resolved

mock_llm = FakeChatModel(messages=iter([]))
agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm)
result = agent.create_graph()
nodes = result.nodes
assert "agent1" in nodes
assert "Supervisor" in nodes
assert "Summarizer" in nodes
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ async def test_get_kg_data_errors(httpx_mock):
)


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_kg_data(httpx_mock):
url = "http://fake_url"
Expand Down
1 change: 1 addition & 0 deletions tests/tools/test_electrophys_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


class TestElectrophysTool:
@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_arun(self, httpx_mock):
url = "http://fake_url"
Expand Down
1 change: 1 addition & 0 deletions tests/tools/test_traces_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


class TestTracesTool:
@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_arun(self, httpx_mock, brain_region_json_path):
url = "http://fake_url"
Expand Down
Loading