Skip to content

Commit

Permalink
Added multi agent tests (#13)
Browse files Browse the repository at this point in the history
* Added multi agent tests

* Added changelog

* Added change to make tests backwards compatible with new httpx_mock version

* Update test_simple_agent.py

* Update test_simple_chat_agent.py

* Update test_tools.py

* Added changelog

* Reformatted with proper ruff version

* code review

* new ruff version

* Code review

* Code review

* Code review
  • Loading branch information
cszsol authored Oct 15, 2024
1 parent a98b9c0 commit 335d222
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Update readme
- Extra multi agent unit tests
- Extra unit tests for dependencies.py

### Removed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ filterwarnings = [
"ignore:Degrees of freedom:RuntimeWarning",
"ignore:Exception ignored in:pytest.PytestUnraisableExceptionWarning",
"ignore:This API is in beta:langchain_core._api.beta_decorator.LangChainBetaWarning",
"ignore:The configuration option 'asyncio_default_fixture_loop_scope' is unset."
]

addopts = "--cov=src/ --cov=tests/ -v --cov-report=term-missing --durations=20 --no-cov-on-fail"
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class FakeFuntionChatModel(GenericFakeChatModel):
def bind_tools(self, functions: list):
return self

def bind_functions(self, **kwargs):
return self

# If you need another fake response to use different tools,
# you can do in your test
# ```python
Expand Down
Empty file added tests/multi_agents/__init__.py
Empty file.
82 changes: 82 additions & 0 deletions tests/multi_agents/test_supervisor_multi_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from langchain_core.messages import HumanMessage, SystemMessage

from neuroagent.multi_agents.supervisor_multi_agent import AgentState
from src.neuroagent.multi_agents import SupervisorMultiAgent


def test_create_main_agent_initialization():
mock_llm = Mock()
bind_function_result = MagicMock()
bind_function_result.__ror__.return_value = {}
mock_llm.bind_functions.return_value = bind_function_result
data = {"llm": mock_llm, "agents": [("agent1",)]}

result = SupervisorMultiAgent.create_main_agent(data)
assert "main_agent" in result
assert "summarizer" in result


@pytest.mark.asyncio
async def test_agent_node():
mock_message = HumanMessage(
content="hello",
name="test_agent",
)

async def mock_ainvoke(_):
return {"messages": [mock_message]}

agent_state = Mock()
agent = Mock()
agent.ainvoke = mock_ainvoke

agent_node_test = await SupervisorMultiAgent.agent_node(
agent_state, agent, "test_agent"
)

assert isinstance(agent_node_test, dict)
assert "messages" in agent_node_test
assert len(agent_node_test["messages"]) == 1
assert agent_node_test["messages"][0].content == "hello"
assert agent_node_test["messages"][0].name == "test_agent"


@pytest.mark.asyncio
async def test_summarizer_node(fake_llm_with_tools):
fake_state = AgentState(
messages=[
HumanMessage(
content="What is the airspeed velocity of an unladen swallow?"
),
SystemMessage(content="11 m/s"),
]
)

mock_llm, _, _ = await anext(fake_llm_with_tools)
agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm)

mock_message = SystemMessage(
content="hello",
name="test_agent",
)

mock_summarizer = Mock()
mock_summarizer.ainvoke = AsyncMock()
mock_summarizer.ainvoke.return_value = mock_message
agent.summarizer = mock_summarizer
result = await agent.summarizer_node(fake_state)
assert result["messages"][0].content == "hello"


@pytest.mark.asyncio
async def test_create_graph(fake_llm_with_tools):
mock_llm, _, _ = await anext(fake_llm_with_tools)
agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm)
result = agent.create_graph()
nodes = result.nodes
assert "agent1" in nodes
assert "Supervisor" in nodes
assert "Summarizer" in nodes

0 comments on commit 335d222

Please sign in to comment.