Skip to content

Commit

Permalink
Add tests of AgentsRoutine (#58)
Browse files Browse the repository at this point in the history
* Add tests of AgentsRoutine

* Small changes

* Fix comments

---------

Co-authored-by: Nicolas Frank <[email protected]>
  • Loading branch information
WonderPG and Nicolas Frank authored Dec 17, 2024
1 parent 29f8e8a commit 2149c2a
Show file tree
Hide file tree
Showing 16 changed files with 740 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tool implementations without langchain or langgraph dependencies
- CRUDs.
- BlueNaas CRUD tools
- Tests of AgentsRoutine.
- Unit tests for database

### Fixed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ convention = "numpy"

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"]
"swarm_copy_tests/*" = ["D"]

[tool.mypy]
mypy_path = "src"
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/run.py → swarm_copy/agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def astream(
) -> AsyncIterator[str | Response]:
"""Stream the agent response."""
active_agent = agent
context_variables = copy.deepcopy(context_variables)

history = copy.deepcopy(messages)
init_len = len(messages)
is_streaming = False
Expand Down Expand Up @@ -251,7 +251,7 @@ async def astream(
stream=True,
)
async for chunk in completion: # type: ignore
delta = json.loads(chunk.choices[0].delta.json())
delta = json.loads(chunk.choices[0].delta.model_dump_json())

# Check for tool calls
if delta["tool_calls"]:
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.status import HTTP_401_UNAUTHORIZED

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.app_utils import validate_project
from swarm_copy.app.config import Settings
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.cell_types import CellTypesMeta
from swarm_copy.new_types import Agent
from swarm_copy.run import AgentsRoutine
from swarm_copy.tools import (
ElectrophysFeatureTool,
GetMorphoTool,
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.database.db_utils import get_history, get_thread, save_history
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.app.dependencies import (
Expand All @@ -17,7 +18,6 @@
get_user_id,
)
from swarm_copy.new_types import Agent, AgentRequest, AgentResponse
from swarm_copy.run import AgentsRoutine
from swarm_copy.stream import stream_agent_response

router = APIRouter(prefix="/qa", tags=["Run the agent"])
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from openai import AsyncOpenAI
from sqlalchemy.ext.asyncio import AsyncSession

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.database.db_utils import save_history
from swarm_copy.new_types import Agent, Response
from swarm_copy.run import AgentsRoutine


async def stream_agent_response(
Expand Down
1 change: 1 addition & 0 deletions swarm_copy_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Sarm copy tests."""
89 changes: 67 additions & 22 deletions swarm_copy_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,20 @@

import json
from pathlib import Path
from typing import ClassVar

import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from pydantic import BaseModel, ConfigDict
from sqlalchemy import MetaData
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from swarm_copy.app.config import Settings
from swarm_copy.app.dependencies import get_kg_token, get_settings
from swarm_copy.app.dependencies import Agent, get_kg_token, get_settings
from swarm_copy.app.main import app


@pytest.fixture(name="settings")
def settings():
return Settings(
tools={
"literature": {
"url": "fake_literature_url",
},
},
knowledge_graph={
"base_url": "https://fake_url/api/nexus/v1",
},
openai={
"token": "fake_token",
},
keycloak={
"username": "fake_username",
"password": "fake_password",
},
)
from swarm_copy.tools.base_tool import BaseTool
from swarm_copy_tests.mock_client import MockOpenAIClient, create_mock_response


@pytest.fixture(name="app_client")
Expand Down Expand Up @@ -62,6 +45,68 @@ def client_fixture():
yield app_client
app.dependency_overrides.clear()

@pytest.fixture
def mock_openai_client():
"""Fake openai client."""
m = MockOpenAIClient()
m.set_response(
create_mock_response(
{"role": "assistant", "content": "sample response content"}
)
)
return m


@pytest.fixture(name="get_weather_tool")
def fake_tool():
"""Fake get weather tool."""

class FakeToolInput(BaseModel):
location: str

class FakeToolMetadata(
BaseModel
): # Should be a BaseMetadata but we don't want httpx client here
model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True)
planet: str | None = None

class FakeTool(BaseTool):
name: ClassVar[str] = "get_weather"
description: ClassVar[str] = "Great description"
metadata: FakeToolMetadata
input_schema: FakeToolInput

async def arun(self):
if self.metadata.planet:
return f"It's sunny today in {self.input_schema.location} from planet {self.metadata.planet}."
return "It's sunny today."

return FakeTool


@pytest.fixture
def agent_handoff_tool():
"""Fake agent handoff tool."""

class HandoffToolInput(BaseModel):
pass

class HandoffToolMetadata(
BaseModel
): # Should be a BaseMetadata but we don't want httpx client here
to_agent: Agent
model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True)

class HandoffTool(BaseTool):
name: ClassVar[str] = "agent_handoff_tool"
description: ClassVar[str] = "Handoff to another agent."
metadata: HandoffToolMetadata
input_schema: HandoffToolInput

async def arun(self):
return self.metadata.to_agent

return HandoffTool

@pytest.fixture(autouse=True, scope="session")
def dont_look_at_env_file():
Expand Down
68 changes: 68 additions & 0 deletions swarm_copy_tests/mock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json
from unittest.mock import AsyncMock

from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)


def create_mock_response(message, function_calls=[], model="gpt-4o-mini"):
role = message.get("role", "assistant")
content = message.get("content", "")
tool_calls = (
[
ChatCompletionMessageToolCall(
id="mock_tc_id",
type="function",
function=Function(
name=call.get("name", ""),
arguments=json.dumps(call.get("args", {})),
),
)
for call in function_calls
]
if function_calls
else None
)

return ChatCompletion(
id="mock_cc_id",
created=1234567890,
model=model,
object="chat.completion",
choices=[
Choice(
message=ChatCompletionMessage(
role=role, content=content, tool_calls=tool_calls
),
finish_reason="stop",
index=0,
)
],
)


class MockOpenAIClient:
def __init__(self):
self.chat = AsyncMock()
self.chat.completions = AsyncMock()

def set_response(self, response: ChatCompletion):
"""
Set the mock to return a specific response.
:param response: A ChatCompletion response to return.
"""
self.chat.completions.create.return_value = response

def set_sequential_responses(self, responses: list[ChatCompletion]):
"""
Set the mock to return different responses sequentially.
:param responses: A list of ChatCompletion responses to return in order.
"""
self.chat.completions.create.side_effect = responses

def assert_create_called_with(self, **kwargs):
self.chat.completions.create.assert_called_with(**kwargs)
Loading

0 comments on commit 2149c2a

Please sign in to comment.