From 2149c2a3f0e4bce45400f7294e7c7ee6809aad30 Mon Sep 17 00:00:00 2001 From: Nicolas Frank <58003267+WonderPG@users.noreply.github.com> Date: Tue, 17 Dec 2024 15:10:19 +0100 Subject: [PATCH 1/2] Add tests of AgentsRoutine (#58) * Add tests of AgentsRoutine * Small changes * Fix comments --------- Co-authored-by: Nicolas Frank --- CHANGELOG.md | 1 + pyproject.toml | 1 + swarm_copy/{run.py => agent_routine.py} | 4 +- swarm_copy/app/dependencies.py | 2 +- swarm_copy/app/routers/qa.py | 2 +- swarm_copy/stream.py | 2 +- swarm_copy_tests/__init__.py | 1 + swarm_copy_tests/conftest.py | 89 ++- swarm_copy_tests/mock_client.py | 68 ++ swarm_copy_tests/test_agent_routine.py | 585 ++++++++++++++++++ tests/agents/test_simple_agent.py | 4 +- tests/agents/test_simple_chat_agent.py | 4 +- tests/app/database/test_threads.py | 4 +- tests/app/database/test_tools.py | 4 +- tests/conftest.py | 4 +- .../test_supervisor_multi_agent.py | 4 +- 16 files changed, 740 insertions(+), 39 deletions(-) rename swarm_copy/{run.py => agent_routine.py} (98%) create mode 100644 swarm_copy_tests/__init__.py create mode 100644 swarm_copy_tests/mock_client.py create mode 100644 swarm_copy_tests/test_agent_routine.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bceac7e..34dc6cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tool implementations without langchain or langgraph dependencies - CRUDs. - BlueNaas CRUD tools +- Tests of AgentsRoutine. - Unit tests for database ### Fixed diff --git a/pyproject.toml b/pyproject.toml index d0730c8..c27af47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ convention = "numpy" [tool.ruff.lint.per-file-ignores] "tests/*" = ["D"] +"swarm_copy_tests/*" = ["D"] [tool.mypy] mypy_path = "src" diff --git a/swarm_copy/run.py b/swarm_copy/agent_routine.py similarity index 98% rename from swarm_copy/run.py rename to swarm_copy/agent_routine.py index 17fb9cc..4797545 100644 --- a/swarm_copy/run.py +++ b/swarm_copy/agent_routine.py @@ -222,7 +222,7 @@ async def astream( ) -> AsyncIterator[str | Response]: """Stream the agent response.""" active_agent = agent - context_variables = copy.deepcopy(context_variables) + history = copy.deepcopy(messages) init_len = len(messages) is_streaming = False @@ -251,7 +251,7 @@ async def astream( stream=True, ) async for chunk in completion: # type: ignore - delta = json.loads(chunk.choices[0].delta.json()) + delta = json.loads(chunk.choices[0].delta.model_dump_json()) # Check for tool calls if delta["tool_calls"]: diff --git a/swarm_copy/app/dependencies.py b/swarm_copy/app/dependencies.py index f021c3e..c64223a 100644 --- a/swarm_copy/app/dependencies.py +++ b/swarm_copy/app/dependencies.py @@ -13,12 +13,12 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from starlette.status import HTTP_401_UNAUTHORIZED +from swarm_copy.agent_routine import AgentsRoutine from swarm_copy.app.app_utils import validate_project from swarm_copy.app.config import Settings from swarm_copy.app.database.sql_schemas import Threads from swarm_copy.cell_types import CellTypesMeta from swarm_copy.new_types import Agent -from swarm_copy.run import AgentsRoutine from swarm_copy.tools import ( ElectrophysFeatureTool, GetMorphoTool, diff --git a/swarm_copy/app/routers/qa.py b/swarm_copy/app/routers/qa.py index f7b2d3a..c04962c 100644 --- a/swarm_copy/app/routers/qa.py +++ b/swarm_copy/app/routers/qa.py @@ -7,6 +7,7 @@ from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession +from swarm_copy.agent_routine import AgentsRoutine from swarm_copy.app.database.db_utils import get_history, get_thread, save_history from swarm_copy.app.database.sql_schemas import Threads from swarm_copy.app.dependencies import ( @@ -17,7 +18,6 @@ get_user_id, ) from swarm_copy.new_types import Agent, AgentRequest, AgentResponse -from swarm_copy.run import AgentsRoutine from swarm_copy.stream import stream_agent_response router = APIRouter(prefix="/qa", tags=["Run the agent"]) diff --git a/swarm_copy/stream.py b/swarm_copy/stream.py index 75b10fe..62fd638 100644 --- a/swarm_copy/stream.py +++ b/swarm_copy/stream.py @@ -5,9 +5,9 @@ from openai import AsyncOpenAI from sqlalchemy.ext.asyncio import AsyncSession +from swarm_copy.agent_routine import AgentsRoutine from swarm_copy.app.database.db_utils import save_history from swarm_copy.new_types import Agent, Response -from swarm_copy.run import AgentsRoutine async def stream_agent_response( diff --git a/swarm_copy_tests/__init__.py b/swarm_copy_tests/__init__.py new file mode 100644 index 0000000..480d94b --- /dev/null +++ b/swarm_copy_tests/__init__.py @@ -0,0 +1 @@ +"""Sarm copy tests.""" diff --git a/swarm_copy_tests/conftest.py b/swarm_copy_tests/conftest.py index 54ee423..31deefb 100644 --- a/swarm_copy_tests/conftest.py +++ b/swarm_copy_tests/conftest.py @@ -2,37 +2,20 @@ import json from pathlib import Path +from typing import ClassVar import pytest import pytest_asyncio from fastapi.testclient import TestClient +from pydantic import BaseModel, ConfigDict from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from swarm_copy.app.config import Settings -from swarm_copy.app.dependencies import get_kg_token, get_settings +from swarm_copy.app.dependencies import Agent, get_kg_token, get_settings from swarm_copy.app.main import app - - -@pytest.fixture(name="settings") -def settings(): - return Settings( - tools={ - "literature": { - "url": "fake_literature_url", - }, - }, - knowledge_graph={ - "base_url": "https://fake_url/api/nexus/v1", - }, - openai={ - "token": "fake_token", - }, - keycloak={ - "username": "fake_username", - "password": "fake_password", - }, - ) +from swarm_copy.tools.base_tool import BaseTool +from swarm_copy_tests.mock_client import MockOpenAIClient, create_mock_response @pytest.fixture(name="app_client") @@ -62,6 +45,68 @@ def client_fixture(): yield app_client app.dependency_overrides.clear() +@pytest.fixture +def mock_openai_client(): + """Fake openai client.""" + m = MockOpenAIClient() + m.set_response( + create_mock_response( + {"role": "assistant", "content": "sample response content"} + ) + ) + return m + + +@pytest.fixture(name="get_weather_tool") +def fake_tool(): + """Fake get weather tool.""" + + class FakeToolInput(BaseModel): + location: str + + class FakeToolMetadata( + BaseModel + ): # Should be a BaseMetadata but we don't want httpx client here + model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True) + planet: str | None = None + + class FakeTool(BaseTool): + name: ClassVar[str] = "get_weather" + description: ClassVar[str] = "Great description" + metadata: FakeToolMetadata + input_schema: FakeToolInput + + async def arun(self): + if self.metadata.planet: + return f"It's sunny today in {self.input_schema.location} from planet {self.metadata.planet}." + return "It's sunny today." + + return FakeTool + + +@pytest.fixture +def agent_handoff_tool(): + """Fake agent handoff tool.""" + + class HandoffToolInput(BaseModel): + pass + + class HandoffToolMetadata( + BaseModel + ): # Should be a BaseMetadata but we don't want httpx client here + to_agent: Agent + model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True) + + class HandoffTool(BaseTool): + name: ClassVar[str] = "agent_handoff_tool" + description: ClassVar[str] = "Handoff to another agent." + metadata: HandoffToolMetadata + input_schema: HandoffToolInput + + async def arun(self): + return self.metadata.to_agent + + return HandoffTool @pytest.fixture(autouse=True, scope="session") def dont_look_at_env_file(): diff --git a/swarm_copy_tests/mock_client.py b/swarm_copy_tests/mock_client.py new file mode 100644 index 0000000..69d9bb6 --- /dev/null +++ b/swarm_copy_tests/mock_client.py @@ -0,0 +1,68 @@ +import json +from unittest.mock import AsyncMock + +from openai.types.chat import ChatCompletionMessage +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) + + +def create_mock_response(message, function_calls=[], model="gpt-4o-mini"): + role = message.get("role", "assistant") + content = message.get("content", "") + tool_calls = ( + [ + ChatCompletionMessageToolCall( + id="mock_tc_id", + type="function", + function=Function( + name=call.get("name", ""), + arguments=json.dumps(call.get("args", {})), + ), + ) + for call in function_calls + ] + if function_calls + else None + ) + + return ChatCompletion( + id="mock_cc_id", + created=1234567890, + model=model, + object="chat.completion", + choices=[ + Choice( + message=ChatCompletionMessage( + role=role, content=content, tool_calls=tool_calls + ), + finish_reason="stop", + index=0, + ) + ], + ) + + +class MockOpenAIClient: + def __init__(self): + self.chat = AsyncMock() + self.chat.completions = AsyncMock() + + def set_response(self, response: ChatCompletion): + """ + Set the mock to return a specific response. + :param response: A ChatCompletion response to return. + """ + self.chat.completions.create.return_value = response + + def set_sequential_responses(self, responses: list[ChatCompletion]): + """ + Set the mock to return different responses sequentially. + :param responses: A list of ChatCompletion responses to return in order. + """ + self.chat.completions.create.side_effect = responses + + def assert_create_called_with(self, **kwargs): + self.chat.completions.create.assert_called_with(**kwargs) diff --git a/swarm_copy_tests/test_agent_routine.py b/swarm_copy_tests/test_agent_routine.py new file mode 100644 index 0000000..37eb119 --- /dev/null +++ b/swarm_copy_tests/test_agent_routine.py @@ -0,0 +1,585 @@ +import json +from typing import AsyncIterator +from unittest.mock import patch + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) + +from swarm_copy.agent_routine import AgentsRoutine +from swarm_copy.new_types import Agent, Response, Result +from swarm_copy_tests.mock_client import create_mock_response + + +class TestAgentsRoutine: + @pytest.mark.asyncio + async def test_get_chat_completion_simple_message(self, mock_openai_client): + routine = AgentsRoutine(client=mock_openai_client) + + agent = Agent() + response = await routine.get_chat_completion( + agent=agent, + history=[{"role": "user", "content": "Hello !"}], + context_variables={}, + model_override=None, + ) + mock_openai_client.assert_create_called_with( + **{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "system", "content": "You are a helpful agent."}, + {"role": "user", "content": "Hello !"}, + ], + "tools": None, + "tool_choice": None, + "stream": False, + } + ) + + assert response.choices[0].message.role == "assistant" + assert response.choices[0].message.content == "sample response content" + + @pytest.mark.asyncio + async def test_get_chat_completion_callable_sys_prompt(self, mock_openai_client): + routine = AgentsRoutine(client=mock_openai_client) + + def agent_instruction(context_variables): + twng = context_variables.get("twng") + mrt = context_variables.get("mrt") + return f"This is your new instructions with {twng} and {mrt}." + + agent = Agent(instructions=agent_instruction) + response = await routine.get_chat_completion( + agent=agent, + history=[{"role": "user", "content": "Hello !"}], + context_variables={"mrt": "Great mrt", "twng": "Bad twng"}, + model_override=None, + ) + mock_openai_client.assert_create_called_with( + **{ + "model": "gpt-4o-mini", + "messages": [ + { + "role": "system", + "content": "This is your new instructions with Bad twng and Great mrt.", + }, + {"role": "user", "content": "Hello !"}, + ], + "tools": None, + "tool_choice": None, + "stream": False, + } + ) + + assert response.choices[0].message.role == "assistant" + assert response.choices[0].message.content == "sample response content" + + @pytest.mark.asyncio + async def test_get_chat_completion_tools( + self, mock_openai_client, get_weather_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + agent = Agent(tools=[get_weather_tool]) + response = await routine.get_chat_completion( + agent=agent, + history=[{"role": "user", "content": "Hello !"}], + context_variables={}, + model_override=None, + ) + mock_openai_client.assert_create_called_with( + **{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "system", "content": "You are a helpful agent."}, + {"role": "user", "content": "Hello !"}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Great description", + "strict": False, + "parameters": { + "properties": { + "location": {"title": "Location", "type": "string"} + }, + "required": ["location"], + "title": "FakeToolInput", + "type": "object", + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": None, + "stream": False, + "parallel_tool_calls": True, + } + ) + + assert response.choices[0].message.role == "assistant" + assert response.choices[0].message.content == "sample response content" + + def test_handle_function_result(self, mock_openai_client): + routine = AgentsRoutine(client=mock_openai_client) + + # Raw result is already a result + raw_result = Result(value="Nice weather") + result = routine.handle_function_result(raw_result) + assert result == raw_result + + # Raw result is an agent for handoff + raw_result = Agent(name="Test agent 2") + result = routine.handle_function_result(raw_result) + assert result == Result( + value=json.dumps({"assistant": raw_result.name}), agent=raw_result + ) + + # Raw result is a tool output (Typically dict/list dict) + raw_result = [{"result_1": "Great result", "result_2": "Bad result"}] + result = routine.handle_function_result(raw_result) + assert result == Result(value=str(raw_result)) + + @pytest.mark.asyncio + async def test_execute_tool_calls_simple( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_response( + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Geneva"}} + ], + ), + ) + agent = Agent(tools=[get_weather_tool, agent_handoff_tool]) + context_variables = {} + + tool_call_message = await routine.get_chat_completion( + agent, + history=[{"role": "user", "content": "Hello"}], + context_variables=context_variables, + model_override=None, + ) + tool_calls = tool_call_message.choices[0].message.tool_calls + tool_calls_result = await routine.execute_tool_calls( + tool_calls=tool_calls, + tools=agent.tools, + context_variables=context_variables, + ) + assert isinstance(tool_calls_result, Response) + assert tool_calls_result.messages == [ + { + "role": "tool", + "tool_call_id": tool_calls[0].id, + "tool_name": "get_weather", + "content": "It's sunny today.", + } + ] + assert tool_calls_result.agent is None + assert tool_calls_result.context_variables == context_variables + + @pytest.mark.asyncio + async def test_execute_multiple_tool_calls( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_response( + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Geneva"}}, + {"name": "get_weather", "args": {"location": "Lausanne"}}, + ], + ), + ) + agent = Agent(tools=[get_weather_tool, agent_handoff_tool]) + context_variables = {"planet": "Earth"} + + tool_call_message = await routine.get_chat_completion( + agent, + history=[{"role": "user", "content": "Hello"}], + context_variables=context_variables, + model_override=None, + ) + tool_calls = tool_call_message.choices[0].message.tool_calls + tool_calls_result = await routine.execute_tool_calls( + tool_calls=tool_calls, + tools=agent.tools, + context_variables=context_variables, + ) + + assert isinstance(tool_calls_result, Response) + assert tool_calls_result.messages == [ + { + "role": "tool", + "tool_call_id": tool_calls[0].id, + "tool_name": "get_weather", + "content": "It's sunny today in Geneva from planet Earth.", + }, + { + "role": "tool", + "tool_call_id": tool_calls[1].id, + "tool_name": "get_weather", + "content": "It's sunny today in Lausanne from planet Earth.", + }, + ] + assert tool_calls_result.agent is None + assert tool_calls_result.context_variables == context_variables + + @pytest.mark.asyncio + async def test_execute_tool_calls_handoff( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_response( + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[{"name": "agent_handoff_tool", "args": {}}], + ), + ) + agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool]) + agent_2 = Agent( + name="Test agent 2", tools=[get_weather_tool, agent_handoff_tool] + ) + context_variables = {"to_agent": agent_2} + + tool_call_message = await routine.get_chat_completion( + agent_1, + history=[{"role": "user", "content": "Hello"}], + context_variables=context_variables, + model_override=None, + ) + tool_calls = tool_call_message.choices[0].message.tool_calls + tool_calls_result = await routine.execute_tool_calls( + tool_calls=tool_calls, + tools=agent_1.tools, + context_variables=context_variables, + ) + + assert isinstance(tool_calls_result, Response) + assert tool_calls_result.messages == [ + { + "role": "tool", + "tool_call_id": tool_calls[0].id, + "tool_name": "agent_handoff_tool", + "content": json.dumps({"assistant": agent_2.name}), + } + ] + assert tool_calls_result.agent == agent_2 + assert tool_calls_result.context_variables == context_variables + + @pytest.mark.asyncio + async def test_handle_tool_call_simple( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_response( + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Geneva"}} + ], + ), + ) + agent = Agent(tools=[get_weather_tool, agent_handoff_tool]) + context_variables = {} + + tool_call_message = await routine.get_chat_completion( + agent, + history=[{"role": "user", "content": "Hello"}], + context_variables=context_variables, + model_override=None, + ) + tool_call = tool_call_message.choices[0].message.tool_calls[0] + tool_call_result = await routine.handle_tool_call( + tool_call=tool_call, tools=agent.tools, context_variables=context_variables + ) + + assert tool_call_result == ( + { + "role": "tool", + "tool_call_id": tool_call.id, + "tool_name": "get_weather", + "content": "It's sunny today.", + }, + None, + ) + + @pytest.mark.asyncio + async def test_handle_tool_call_context_var( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_response( + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Geneva"}}, + ], + ), + ) + agent = Agent(tools=[get_weather_tool, agent_handoff_tool]) + context_variables = {"planet": "Earth"} + + tool_call_message = await routine.get_chat_completion( + agent, + history=[{"role": "user", "content": "Hello"}], + context_variables=context_variables, + model_override=None, + ) + tool_call = tool_call_message.choices[0].message.tool_calls[0] + tool_calls_result = await routine.handle_tool_call( + tool_call=tool_call, tools=agent.tools, context_variables=context_variables + ) + + assert tool_calls_result == ( + { + "role": "tool", + "tool_call_id": tool_call.id, + "tool_name": "get_weather", + "content": "It's sunny today in Geneva from planet Earth.", + }, + None, + ) + + @pytest.mark.asyncio + async def test_handle_tool_call_handoff( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_response( + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[{"name": "agent_handoff_tool", "args": {}}], + ), + ) + agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool]) + agent_2 = Agent( + name="Test agent 2", tools=[get_weather_tool, agent_handoff_tool] + ) + context_variables = {"to_agent": agent_2} + + tool_call_message = await routine.get_chat_completion( + agent_1, + history=[{"role": "user", "content": "Hello"}], + context_variables=context_variables, + model_override=None, + ) + tool_call = tool_call_message.choices[0].message.tool_calls[0] + tool_calls_result = await routine.handle_tool_call( + tool_call=tool_call, + tools=agent_1.tools, + context_variables=context_variables, + ) + + assert tool_calls_result == ( + { + "role": "tool", + "tool_call_id": tool_call.id, + "tool_name": "agent_handoff_tool", + "content": json.dumps({"assistant": agent_2.name}), + }, + agent_2, + ) + + @pytest.mark.asyncio + async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_tool): + agent_1 = Agent(name="Test Agent", tools=[agent_handoff_tool]) + agent_2 = Agent(name="Test Agent", tools=[get_weather_tool]) + messages = [ + {"role": "user", "content": "What's the weather like in San Francisco?"} + ] + context_variables = {"to_agent": agent_2, "planet": "Mars"} + # set mock to return a response that triggers function call + mock_openai_client.set_sequential_responses( + [ + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[{"name": "agent_handoff_tool", "args": {}}], + ), + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Montreux"}} + ], + ), + create_mock_response( + {"role": "assistant", "content": "sample response content"} + ), + ] + ) + + # set up client and run + client = AgentsRoutine(client=mock_openai_client) + response = await client.arun( + agent=agent_1, messages=messages, context_variables=context_variables + ) + + assert response.messages[2]["role"] == "tool" + assert response.messages[2]["content"] == json.dumps( + {"assistant": agent_1.name} + ) + assert response.messages[-2]["role"] == "tool" + assert ( + response.messages[-2]["content"] + == "It's sunny today in Montreux from planet Mars." + ) + assert response.messages[-1]["role"] == "assistant" + assert response.messages[-1]["content"] == "sample response content" + assert response.agent == agent_2 + assert response.context_variables == context_variables + + @pytest.mark.asyncio + async def test_astream( + self, mock_openai_client, get_weather_tool, agent_handoff_tool + ): + agent_1 = Agent(name="Test Agent", tools=[agent_handoff_tool]) + agent_2 = Agent(name="Test Agent", tools=[get_weather_tool]) + messages = [ + {"role": "user", "content": "What's the weather like in San Francisco?"} + ] + context_variables = {"to_agent": agent_2, "planet": "Mars"} + routine = AgentsRoutine(client=mock_openai_client) + + async def return_iterator(*args, **kwargs): + async def mock_openai_streaming_response( + history, + ) -> AsyncIterator[ChatCompletionChunk]: + """ + Simulates streaming chunks of a response for patching. + + Yields + ------ + AsyncIterator[ChatCompletionChunk]: Streaming chunks of the response. + """ + responses = [ + { + "message": {"role": "assistant", "content": ""}, + "function_call": [{"name": "agent_handoff_tool", "args": {}}], + }, + { + "message": {"role": "assistant", "content": ""}, + "function_call": [ + {"name": "get_weather", "args": {"location": "Montreux"}} + ], + }, + { + "message": { + "role": "assistant", + "content": "sample response content", + }, + }, + ] + response_to_call = ( + len([hist for hist in history if hist["role"] != "tool"]) - 1 + ) + response = responses[response_to_call] + + if "message" in response and "content" in response["message"]: + content = response["message"]["content"] + for i in range( + 0, len(content), 10 + ): # Stream content in chunks of 10 chars + yield ChatCompletionChunk( + id="chatcmpl-AdfVmbjxczsgRAADk9pXkmKPFsikY", + choices=[ + Choice( + delta=ChoiceDelta(content=content[i : i + 10]), + finish_reason=None, + index=0, + ) + ], + created=1734017726, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + system_fingerprint="fp_bba3c8e70b", + ) + + if "function_call" in response: + for function_call in response["function_call"]: + yield ChatCompletionChunk( + id="chatcmpl-AdfVmbjxczsgRAADk9pXkmKPFsikY", + choices=[ + Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction( + arguments=json.dumps( + function_call["args"] + ), + name=function_call["name"], + ), + type="function", + ) + ] + ), + finish_reason=None, + index=0, + ) + ], + created=1734017726, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + system_fingerprint="fp_bba3c8e70b", + ) + + yield ChatCompletionChunk( + id="chatcmpl-AdfVmbjxczsgRAADk9pXkmKPFsikY", + choices=[ + Choice(delta=ChoiceDelta(), finish_reason="stop", index=0) + ], + created=1734017726, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + system_fingerprint="fp_bba3c8e70b", + ) + + return mock_openai_streaming_response(kwargs["history"]) + + tokens = [] + with patch( + "swarm_copy.agent_routine.AgentsRoutine.get_chat_completion", + new=return_iterator, + ): + async for token in routine.astream( + agent=agent_1, messages=messages, context_variables=context_variables + ): + if isinstance(token, str): + tokens.append(token) + else: + response = token + + assert ( + "".join(tokens) + == '\nCalling tool : agent_handoff_tool with arguments : {}\nCalling tool : get_weather with arguments : {"location": "Montreux"}\n\nsample response content' + ) + assert response.messages[2]["role"] == "tool" + assert response.messages[2]["content"] == json.dumps( + {"assistant": agent_1.name} + ) + assert response.messages[-2]["role"] == "tool" + assert ( + response.messages[-2]["content"] + == "It's sunny today in Montreux from planet Mars." + ) + assert response.messages[-1]["role"] == "assistant" + assert response.messages[-1]["content"] == "sample response content" + assert response.agent == agent_2 + assert response.context_variables == context_variables diff --git a/tests/agents/test_simple_agent.py b/tests/agents/test_simple_agent.py index 0f0ca1d..9bdd8ba 100644 --- a/tests/agents/test_simple_agent.py +++ b/tests/agents/test_simple_agent.py @@ -19,7 +19,7 @@ async def test_simple_agent_arun(fake_llm_with_tools, httpx_mock): json=knowledge_graph_response, ) - llm, tools, _ = await anext(fake_llm_with_tools) + llm, tools, _ = fake_llm_with_tools simple_agent = SimpleAgent(llm=llm, tools=tools) response = await simple_agent.arun(query="Call get_morpho with thalamus.") @@ -44,7 +44,7 @@ async def test_simple_agent_astream(fake_llm_with_tools, httpx_mock): json=knowledge_graph_response, ) - llm, tools, _ = await anext(fake_llm_with_tools) + llm, tools, _ = fake_llm_with_tools simple_agent = SimpleAgent(llm=llm, tools=tools) response_chunks = simple_agent.astream("Call get_morpho with thalamus.") diff --git a/tests/agents/test_simple_chat_agent.py b/tests/agents/test_simple_chat_agent.py index e0f5dea..86e105c 100644 --- a/tests/agents/test_simple_chat_agent.py +++ b/tests/agents/test_simple_chat_agent.py @@ -13,7 +13,7 @@ @pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_arun(fake_llm_with_tools, httpx_mock): - llm, tools, fake_responses = await anext(fake_llm_with_tools) + llm, tools, fake_responses = fake_llm_with_tools json_path = Path(__file__).resolve().parent.parent / "data" / "knowledge_graph.json" with open(json_path) as f: knowledge_graph_response = json.load(f) @@ -69,7 +69,7 @@ async def test_arun(fake_llm_with_tools, httpx_mock): @pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_astream(fake_llm_with_tools, httpx_mock): - llm, tools, fake_responses = await anext(fake_llm_with_tools) + llm, tools, fake_responses = fake_llm_with_tools json_path = Path(__file__).resolve().parent.parent / "data" / "knowledge_graph.json" with open(json_path) as f: knowledge_graph_response = json.load(f) diff --git a/tests/app/database/test_threads.py b/tests/app/database/test_threads.py index 246509b..be37e66 100644 --- a/tests/app/database/test_threads.py +++ b/tests/app/database/test_threads.py @@ -64,7 +64,7 @@ async def test_get_thread( patch_required_env, fake_llm_with_tools, httpx_mock, app_client, db_connection ): # Put data in the db - llm, _, _ = await anext(fake_llm_with_tools) + llm, _, _ = fake_llm_with_tools app.dependency_overrides[get_language_model] = lambda: llm test_settings = Settings( @@ -149,7 +149,7 @@ async def test_delete_thread( patch_required_env, fake_llm_with_tools, httpx_mock, app_client, db_connection ): # Put data in the db - llm, _, _ = await anext(fake_llm_with_tools) + llm, _, _ = fake_llm_with_tools app.dependency_overrides[get_language_model] = lambda: llm test_settings = Settings( diff --git a/tests/app/database/test_tools.py b/tests/app/database/test_tools.py index 3ce9750..31aedbd 100644 --- a/tests/app/database/test_tools.py +++ b/tests/app/database/test_tools.py @@ -14,7 +14,7 @@ async def test_get_tool_calls( patch_required_env, fake_llm_with_tools, httpx_mock, app_client, db_connection ): # Put data in the db - llm, _, _ = await anext(fake_llm_with_tools) + llm, _, _ = fake_llm_with_tools app.dependency_overrides[get_language_model] = lambda: llm test_settings = Settings( db={"prefix": db_connection}, @@ -75,7 +75,7 @@ async def test_get_tool_output( db_connection, ): # Put data in the db - llm, _, _ = await anext(fake_llm_with_tools) + llm, _, _ = fake_llm_with_tools app.dependency_overrides[get_language_model] = lambda: llm test_settings = Settings( diff --git a/tests/conftest.py b/tests/conftest.py index 88ee7ea..3f26988 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,7 +107,7 @@ def brain_region_json_path(): return br_path -@pytest.fixture +@pytest_asyncio.fixture async def fake_llm_with_tools(brain_region_json_path): class FakeFuntionChatModel(GenericFakeChatModel): def bind_tools(self, functions: list): @@ -119,7 +119,7 @@ def bind_functions(self, **kwargs): # If you need another fake response to use different tools, # you can do in your test # ```python - # llm, _ = await anext(fake_llm_with_tools) + # llm, _ = fake_llm_with_tools # llm.responses = my_fake_responses # ``` # and simply bind the corresponding tools diff --git a/tests/multi_agents/test_supervisor_multi_agent.py b/tests/multi_agents/test_supervisor_multi_agent.py index 7ae1618..630fa4f 100644 --- a/tests/multi_agents/test_supervisor_multi_agent.py +++ b/tests/multi_agents/test_supervisor_multi_agent.py @@ -55,7 +55,7 @@ async def test_summarizer_node(fake_llm_with_tools): ] ) - mock_llm, _, _ = await anext(fake_llm_with_tools) + mock_llm, _, _ = fake_llm_with_tools agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm) mock_message = SystemMessage( @@ -73,7 +73,7 @@ async def test_summarizer_node(fake_llm_with_tools): @pytest.mark.asyncio async def test_create_graph(fake_llm_with_tools): - mock_llm, _, _ = await anext(fake_llm_with_tools) + mock_llm, _, _ = fake_llm_with_tools agent = SupervisorMultiAgent(agents=[("agent1", [])], llm=mock_llm) result = agent.create_graph() nodes = result.nodes From 8f549507774649e610d9cc11259785d89c0c5dc0 Mon Sep 17 00:00:00 2001 From: cszsol Date: Tue, 17 Dec 2024 17:02:17 +0100 Subject: [PATCH 2/2] Added app_utils tests (#55) * Added app_utils tests * Added test_dependencies * Update test_dependencies.py * Conflict resolution * Update test_dependencies.py * run swarm copy tests * Added test_dependencies * Fixed breaking changes * Fixed settings test * lint * Remove unnecessary dependencies * Added get_vlab_and_project tests * Added test for get starting agent * Added test for get_kg_token * Added test lifespan * lint * unit tests * Revert conftest.py * Review comments * Fixed lifespan test * Fixed fixture * Fixed test --------- Co-authored-by: kanesoban --- CHANGELOG.md | 1 + swarm_copy/app/dependencies.py | 4 +- swarm_copy_tests/app/test_app_utils.py | 75 +++++ swarm_copy_tests/app/test_config.py | 71 ++++ swarm_copy_tests/app/test_dependencies.py | 387 ++++++++++++++++++++++ swarm_copy_tests/app/test_main.py | 75 +++++ swarm_copy_tests/conftest.py | 21 ++ 7 files changed, 631 insertions(+), 3 deletions(-) create mode 100644 swarm_copy_tests/app/test_app_utils.py create mode 100644 swarm_copy_tests/app/test_config.py create mode 100644 swarm_copy_tests/app/test_dependencies.py create mode 100644 swarm_copy_tests/app/test_main.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 34dc6cf..0e33b9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tool implementations without langchain or langgraph dependencies - CRUDs. - BlueNaas CRUD tools +- app unit tests - Tests of AgentsRoutine. - Unit tests for database diff --git a/swarm_copy/app/dependencies.py b/swarm_copy/app/dependencies.py index c64223a..087f9a1 100644 --- a/swarm_copy/app/dependencies.py +++ b/swarm_copy/app/dependencies.py @@ -204,9 +204,7 @@ async def get_vlab_and_project( if not thread: raise HTTPException( status_code=404, - detail={ - "detail": "Thread not found.", - }, + detail="Thread not found.", ) if thread and thread.vlab_id and thread.project_id: vlab_and_project = { diff --git a/swarm_copy_tests/app/test_app_utils.py b/swarm_copy_tests/app/test_app_utils.py new file mode 100644 index 0000000..70018f2 --- /dev/null +++ b/swarm_copy_tests/app/test_app_utils.py @@ -0,0 +1,75 @@ +"""Test app utils.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.exceptions import HTTPException +from httpx import AsyncClient + +from swarm_copy.app.app_utils import setup_engine, validate_project +from swarm_copy.app.config import Settings + + +@pytest.mark.asyncio +async def test_validate_project(patch_required_env, httpx_mock, monkeypatch): + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + httpx_client = AsyncClient() + token = "fake_token" + test_vp = {"vlab_id": "test_vlab_DB", "project_id": "project_id_DB"} + vlab_url = "https://openbluebrain.com/api/virtual-lab-manager/virtual-labs" + + # test with bad config + httpx_mock.add_response( + url=f'{vlab_url}/{test_vp["vlab_id"]}/projects/{test_vp["project_id"]}', + status_code=404, + ) + with pytest.raises(HTTPException) as error: + await validate_project( + httpx_client=httpx_client, + vlab_id=test_vp["vlab_id"], + project_id=test_vp["project_id"], + token=token, + vlab_project_url=vlab_url, + ) + assert error.value.status_code == 401 + + # test with good config + httpx_mock.add_response( + url=f'{vlab_url}/{test_vp["vlab_id"]}/projects/{test_vp["project_id"]}', + json="test_project_ID", + ) + await validate_project( + httpx_client=httpx_client, + vlab_id=test_vp["vlab_id"], + project_id=test_vp["project_id"], + token=token, + vlab_project_url=vlab_url, + ) + # we jsut want to assert that the httpx_mock was called. + + +@patch("neuroagent.app.app_utils.create_async_engine") +def test_setup_engine(create_engine_mock, monkeypatch, patch_required_env): + create_engine_mock.return_value = AsyncMock() + + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + + settings = Settings() + + connection_string = "postgresql+asyncpg://user:password@localhost/dbname" + retval = setup_engine(settings=settings, connection_string=connection_string) + assert retval is not None + + +@patch("neuroagent.app.app_utils.create_async_engine") +def test_setup_engine_no_connection_string( + create_engine_mock, monkeypatch, patch_required_env +): + create_engine_mock.return_value = AsyncMock() + + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + + settings = Settings() + + retval = setup_engine(settings=settings, connection_string=None) + assert retval is None diff --git a/swarm_copy_tests/app/test_config.py b/swarm_copy_tests/app/test_config.py new file mode 100644 index 0000000..5274b9c --- /dev/null +++ b/swarm_copy_tests/app/test_config.py @@ -0,0 +1,71 @@ +"""Test config""" + +import pytest +from pydantic import ValidationError + +from swarm_copy.app.config import Settings + + +def test_required(monkeypatch, patch_required_env): + settings = Settings() + + assert settings.tools.literature.url == "https://fake_url" + assert settings.knowledge_graph.base_url == "https://fake_url/api/nexus/v1" + assert settings.openai.token.get_secret_value() == "dummy" + + # make sure not case sensitive + monkeypatch.delenv("NEUROAGENT_TOOLS__LITERATURE__URL") + monkeypatch.setenv("neuroagent_tools__literature__URL", "https://new_fake_url") + + settings = Settings() + assert settings.tools.literature.url == "https://new_fake_url" + + +def test_no_settings(): + # We get an error when no custom variables provided + with pytest.raises(ValidationError): + Settings() + + +def test_setup_tools(monkeypatch, patch_required_env): + monkeypatch.setenv("NEUROAGENT_TOOLS__TRACE__SEARCH_SIZE", "20") + monkeypatch.setenv("NEUROAGENT_TOOLS__MORPHO__SEARCH_SIZE", "20") + monkeypatch.setenv("NEUROAGENT_TOOLS__KG_MORPHO_FEATURES__SEARCH_SIZE", "20") + + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "user") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "pass") + + settings = Settings() + + assert settings.tools.morpho.search_size == 20 + assert settings.tools.trace.search_size == 20 + assert settings.tools.kg_morpho_features.search_size == 20 + assert settings.keycloak.username == "user" + assert settings.keycloak.password.get_secret_value() == "pass" + + +def test_check_consistency(monkeypatch): + # We get an error when no custom variables provided + url = "https://fake_url" + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", url) + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__URL", url) + + with pytest.raises(ValueError): + Settings() + + monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__TOKEN", "dummy") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + + with pytest.raises(ValueError): + Settings() + + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "false") + + with pytest.raises(ValueError): + Settings() + + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://fake_nexus.com") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "Hello") + + Settings() diff --git a/swarm_copy_tests/app/test_dependencies.py b/swarm_copy_tests/app/test_dependencies.py new file mode 100644 index 0000000..ad8b8f2 --- /dev/null +++ b/swarm_copy_tests/app/test_dependencies.py @@ -0,0 +1,387 @@ +"""Test dependencies.""" + +import json +import os +from pathlib import Path +from typing import AsyncIterator +from unittest.mock import Mock, patch + +import pytest +from httpx import AsyncClient +from fastapi import Request, HTTPException + +from swarm_copy.app.app_utils import setup_engine +from swarm_copy.app.database.sql_schemas import Base, Threads +from swarm_copy.app.dependencies import ( + Settings, + get_cell_types_kg_hierarchy, + get_connection_string, + get_httpx_client, + get_settings, + get_update_kg_hierarchy, + get_user_id, get_session, get_vlab_and_project, get_starting_agent, get_kg_token, +) +from swarm_copy.new_types import Agent + + +def test_get_settings(patch_required_env): + settings = get_settings() + assert settings.tools.literature.url == "https://fake_url" + assert settings.knowledge_graph.url == "https://fake_url/api/nexus/v1/search/query/" + + +@pytest.mark.asyncio +async def test_get_httpx_client(): + request = Mock() + request.headers = {"x-request-id": "greatid"} + httpx_client_iterator = get_httpx_client(request=request) + assert isinstance(httpx_client_iterator, AsyncIterator) + async for httpx_client in httpx_client_iterator: + assert isinstance(httpx_client, AsyncClient) + assert httpx_client.headers["x-request-id"] == "greatid" + + +@pytest.mark.asyncio +async def test_get_user(httpx_mock, monkeypatch, patch_required_env): + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__ISSUER", "https://great_issuer.com") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + + fake_response = { + "sub": "12345", + "email_verified": False, + "name": "Machine Learning Test User", + "groups": [], + "preferred_username": "sbo-ml", + "given_name": "Machine Learning", + "family_name": "Test User", + "email": "email@epfl.ch", + } + httpx_mock.add_response( + url="https://great_issuer.com/protocol/openid-connect/userinfo", + json=fake_response, + ) + + settings = Settings() + client = AsyncClient() + token = "eyJgreattoken" + user_id = await get_user_id(token=token, settings=settings, httpx_client=client) + + assert user_id == fake_response["sub"] + + +@pytest.mark.asyncio +async def test_get_update_kg_hierarchy( + tmp_path, httpx_mock, monkeypatch, patch_required_env +): + token = "fake_token" + file_name = "fake_file" + client = AsyncClient() + + file_url = "https://fake_file_url" + + monkeypatch.setenv( + "NEUROAGENT_KNOWLEDGE_GRAPH__HIERARCHY_URL", "http://fake_hierarchy_url.com" + ) + + settings = Settings( + knowledge_graph={"br_saving_path": tmp_path / "test_brain_region.json"} + ) + + json_response_url = { + "head": {"vars": ["file_url"]}, + "results": {"bindings": [{"file_url": {"type": "uri", "value": file_url}}]}, + } + with open( + Path(__file__).parent.parent.parent + / "tests" + / "data" + / "KG_brain_regions_hierarchy_test.json" + ) as fh: + json_response_file = json.load(fh) + + httpx_mock.add_response( + url=settings.knowledge_graph.sparql_url, json=json_response_url + ) + httpx_mock.add_response(url=file_url, json=json_response_file) + + await get_update_kg_hierarchy( + token, + client, + settings, + file_name, + ) + + assert os.path.exists(settings.knowledge_graph.br_saving_path) + + +@pytest.mark.asyncio +async def test_get_cell_types_kg_hierarchy( + tmp_path, httpx_mock, monkeypatch, patch_required_env +): + token = "fake_token" + file_name = "fake_file" + client = AsyncClient() + + file_url = "https://fake_file_url" + monkeypatch.setenv( + "NEUROAGENT_KNOWLEDGE_GRAPH__HIERARCHY_URL", "http://fake_hierarchy_url.com" + ) + + settings = Settings( + knowledge_graph={"ct_saving_path": tmp_path / "test_cell_types_region.json"} + ) + + json_response_url = { + "head": {"vars": ["file_url"]}, + "results": {"bindings": [{"file_url": {"type": "uri", "value": file_url}}]}, + } + with open( + Path(__file__).parent.parent.parent + / "tests" + / "data" + / "kg_cell_types_hierarchy_test.json" + ) as fh: + json_response_file = json.load(fh) + + httpx_mock.add_response( + url=settings.knowledge_graph.sparql_url, json=json_response_url + ) + httpx_mock.add_response(url=file_url, json=json_response_file) + + await get_cell_types_kg_hierarchy( + token, + client, + settings, + file_name, + ) + + assert os.path.exists(settings.knowledge_graph.ct_saving_path) + + +def test_get_connection_string_full(monkeypatch, patch_required_env): + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "http://") + monkeypatch.setenv("NEUROAGENT_DB__USER", "John") + monkeypatch.setenv("NEUROAGENT_DB__PASSWORD", "Doe") + monkeypatch.setenv("NEUROAGENT_DB__HOST", "localhost") + monkeypatch.setenv("NEUROAGENT_DB__PORT", "5000") + monkeypatch.setenv("NEUROAGENT_DB__NAME", "test") + + settings = Settings() + result = get_connection_string(settings) + assert ( + result == "http://John:Doe@localhost:5000/test" + ), "must return fully formed connection string" + + + +@pytest.mark.asyncio +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +async def test_get_vlab_and_project( + patch_required_env, httpx_mock, db_connection, monkeypatch +): + # Setup DB with one thread to do the tests + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + test_settings = Settings( + db={"prefix": db_connection}, + ) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "Super_user" + token = "fake_token" + httpx_client = AsyncClient() + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project", + json="test_project_ID", + ) + + # create test thread table + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + # Test with info in headers. + good_request_headers = Request( + scope={ + "type": "http", + "method": "Get", + "url": "http://fake_url/thread_id", + "headers": [ + (b"x-virtual-lab-id", b"test_vlab"), + (b"x-project-id", b"test_project"), + ], + }, + ) + ids = await get_vlab_and_project( + user_id=user_id, + session=session, + request=good_request_headers, + settings=test_settings, + token=token, + httpx_client=httpx_client, + ) + assert ids == {"vlab_id": "test_vlab", "project_id": "test_project"} + finally: + # don't forget to close the session, otherwise the tests hangs. + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_get_vlab_and_project_no_info_in_headers( + patch_required_env, db_connection, monkeypatch +): + # Setup DB with one thread to do the tests + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + test_settings = Settings( + db={"prefix": db_connection}, + ) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "Super_user" + token = "fake_token" + httpx_client = AsyncClient() + + # create test thread table + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + # Test with no infos in headers. + bad_request = Request( + scope={ + "type": "http", + "method": "GET", + "scheme": "http", + "server": ("example.com", 80), + "path_params": {"dummy_patram": "fake_thread_id"}, + "headers": [ + (b"wong_header", b"wrong value"), + ], + } + ) + with pytest.raises(HTTPException) as error: + await get_vlab_and_project( + user_id=user_id, + session=session, + request=bad_request, + settings=test_settings, + token=token, + httpx_client=httpx_client, + ) + assert ( + error.value.detail == "Thread not found." + ) + finally: + # don't forget to close the session, otherwise the tests hangs. + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +async def test_get_vlab_and_project_valid_thread_id( + patch_required_env, httpx_mock, db_connection, monkeypatch +): + # Setup DB with one thread to do the tests + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true") + test_settings = Settings( + db={"prefix": db_connection}, + ) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "Super_user" + token = "fake_token" + httpx_client = AsyncClient() + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab_DB/projects/project_id_DB", + json="test_project_ID", + ) + + + # create test thread table + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + await session.refresh(new_thread) + + try: + # Test with no infos in headers, but valid thread_ID. + good_request_DB = Request( + scope={ + "type": "http", + "method": "GET", + "scheme": "http", + "server": ("example.com", 80), + "path_params": {"thread_id": new_thread.thread_id}, + "headers": [ + (b"wong_header", b"wrong value"), + ], + } + ) + ids_from_DB = await get_vlab_and_project( + user_id=user_id, + session=session, + request=good_request_DB, + settings=test_settings, + token=token, + httpx_client=httpx_client, + ) + assert ids_from_DB == {"vlab_id": "test_vlab_DB", "project_id": "project_id_DB"} + + finally: + # don't forget to close the session, otherwise the tests hangs. + await session.close() + await engine.dispose() + + +def test_get_starting_agent(patch_required_env): + settings = Settings() + agent = get_starting_agent(None, settings) + + assert isinstance(agent, Agent) + + +@pytest.mark.parametrize( + "input_token, expected_token", + [ + ("existing_token", "existing_token"), + (None, "new_token"), + ], +) +def test_get_kg_token(patch_required_env, input_token, expected_token): + settings = Settings() + mock = Mock() + mock.token.return_value = {"access_token": expected_token} + with ( + patch("swarm_copy.app.dependencies.KeycloakOpenID", return_value=mock), + ): + result = get_kg_token(settings, input_token) + assert result == expected_token diff --git a/swarm_copy_tests/app/test_main.py b/swarm_copy_tests/app/test_main.py new file mode 100644 index 0000000..23f4299 --- /dev/null +++ b/swarm_copy_tests/app/test_main.py @@ -0,0 +1,75 @@ +import logging +from unittest.mock import patch + +from fastapi.testclient import TestClient + +from swarm_copy.app.dependencies import get_settings +from swarm_copy.app.main import app + + +def test_settings_endpoint(app_client, dont_look_at_env_file, settings): + response = app_client.get("/settings") + + replace_secretstr = settings.model_dump() + replace_secretstr["keycloak"]["password"] = "**********" + replace_secretstr["openai"]["token"] = "**********" + assert response.json() == replace_secretstr + + +def test_readyz(app_client): + response = app_client.get( + "/", + ) + + body = response.json() + assert isinstance(body, dict) + assert body["status"] == "ok" + + +def test_lifespan(caplog, monkeypatch, tmp_path, patch_required_env, db_connection): + get_settings.cache_clear() + caplog.set_level(logging.INFO) + + monkeypatch.setenv("NEUROAGENT_LOGGING__LEVEL", "info") + monkeypatch.setenv("NEUROAGENT_LOGGING__EXTERNAL_PACKAGES", "warning") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__DOWNLOAD_HIERARCHY", "true") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", db_connection) + + save_path_brainregion = tmp_path / "fake.json" + + async def save_dummy(*args, **kwargs): + with open(save_path_brainregion, "w") as f: + f.write("test_text") + + with ( + patch("swarm_copy.app.main.get_update_kg_hierarchy", new=save_dummy), + patch("swarm_copy.app.main.get_cell_types_kg_hierarchy", new=save_dummy), + patch("swarm_copy.app.main.get_kg_token", new=lambda *args, **kwargs: "dev"), + ): + # The with statement triggers the startup. + with TestClient(app) as test_client: + test_client.get("/healthz") + # check if the brain region dummy file was created. + assert save_path_brainregion.exists() + + assert caplog.record_tuples[0][::2] == ( + "swarm_copy.app.dependencies", + "Reading the environment and instantiating settings", + ) + + assert ( + logging.getLevelName(logging.getLogger("swarm_copy").getEffectiveLevel()) + == "INFO" + ) + assert ( + logging.getLevelName(logging.getLogger("httpx").getEffectiveLevel()) + == "WARNING" + ) + assert ( + logging.getLevelName(logging.getLogger("fastapi").getEffectiveLevel()) + == "WARNING" + ) + assert ( + logging.getLevelName(logging.getLogger("bluepyefe").getEffectiveLevel()) + == "CRITICAL" + ) diff --git a/swarm_copy_tests/conftest.py b/swarm_copy_tests/conftest.py index 31deefb..48c5a59 100644 --- a/swarm_copy_tests/conftest.py +++ b/swarm_copy_tests/conftest.py @@ -167,3 +167,24 @@ def get_resolve_query_output(): def brain_region_json_path(): br_path = Path(__file__).parent / "data" / "brainregion_hierarchy.json" return br_path + + +@pytest.fixture(name="settings") +def settings(): + return Settings( + tools={ + "literature": { + "url": "fake_literature_url", + }, + }, + knowledge_graph={ + "base_url": "https://fake_url/api/nexus/v1", + }, + openai={ + "token": "fake_token", + }, + keycloak={ + "username": "fake_username", + "password": "fake_password", + }, + )