Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added multi agent tests #13

Merged
merged 17 commits into from
Oct 15, 2024
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Update readme
- Extra multi agent unit tests
- Extra unit tests for dependencies.py

### Removed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ filterwarnings = [
"ignore:Degrees of freedom:RuntimeWarning",
"ignore:Exception ignored in:pytest.PytestUnraisableExceptionWarning",
"ignore:This API is in beta:langchain_core._api.beta_decorator.LangChainBetaWarning",
"ignore:The configuration option 'asyncio_default_fixture_loop_scope' is unset."
]

addopts = "--cov=src/ --cov=tests/ -v --cov-report=term-missing --durations=20 --no-cov-on-fail"
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class FakeFuntionChatModel(GenericFakeChatModel):
def bind_tools(self, functions: list):
return self

def bind_functions(self, **kwargs):
return self

# If you need another fake response to use different tools,
# you can do in your test
# ```python
Expand Down
Empty file added tests/multi_agents/__init__.py
Empty file.
82 changes: 82 additions & 0 deletions tests/multi_agents/test_supervisor_multi_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from langchain_core.messages import HumanMessage, SystemMessage

from neuroagent.multi_agents.supervisor_multi_agent import AgentState
from src.neuroagent.multi_agents import SupervisorMultiAgent


def test_create_main_agent_initialization():
mock_llm = Mock()
bind_function_result = MagicMock()
bind_function_result.__ror__.return_value = {}
mock_llm.bind_functions.return_value = bind_function_result
data = {"llm": mock_llm, "agents": [("agent1",)]}

result = SupervisorMultiAgent.create_main_agent(data)
assert "main_agent" in result
assert "summarizer" in result

cszsol marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.asyncio
async def test_agent_node():
mock_message = HumanMessage(
content="hello",
name="test_agent",
)

async def mock_ainvoke(_):
return {"messages": [mock_message]}

agent_state = Mock()
agent = Mock()
agent.ainvoke = mock_ainvoke

agent_node_test = await SupervisorMultiAgent.agent_node(
agent_state, agent, "test_agent"
)

assert isinstance(agent_node_test, dict)
assert "messages" in agent_node_test
assert len(agent_node_test["messages"]) == 1
assert agent_node_test["messages"][0].content == "hello"
assert agent_node_test["messages"][0].name == "test_agent"


@pytest.mark.asyncio
async def test_summarizer_node(fake_llm_with_tools):
fake_state = AgentState(
messages=[
HumanMessage(
content="What is the airspeed velocity of an unladen swallow?"
),
SystemMessage(content="11 m/s"),
]
)

mock_llm, _, _ = await anext(fake_llm_with_tools)
agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm)

mock_message = SystemMessage(
content="hello",
name="test_agent",
)

mock_summarizer = Mock()
mock_summarizer.ainvoke = AsyncMock()
mock_summarizer.ainvoke.return_value = mock_message
agent.summarizer = mock_summarizer
result = await agent.summarizer_node(fake_state)
assert result["messages"][0].content == "hello"


@pytest.mark.asyncio
async def test_create_graph(fake_llm_with_tools):
mock_llm, _, _ = await anext(fake_llm_with_tools)
agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm)
result = agent.create_graph()
nodes = result.nodes
assert "agent1" in nodes
assert "Supervisor" in nodes
assert "Summarizer" in nodes
Loading