From bbc7f2b801db446609b38ce7a350b27595ec4691 Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Wed, 20 Nov 2024 08:48:04 +0100 Subject: [PATCH 1/9] Added tests --- CHANGELOG.md | 1 + swarm_copy_tests/app/routers/__init__.py | 1 + swarm_copy_tests/app/routers/test_threads.py | 122 +++++++++++++++++++ 3 files changed, 124 insertions(+) create mode 100644 swarm_copy_tests/app/routers/__init__.py create mode 100644 swarm_copy_tests/app/routers/test_threads.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 52955bc..534d166 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tool implementations without langchain or langgraph dependencies - CRUDs. - BlueNaas CRUD tools +- Tests for threads module ### Fixed - Migrate LLM Evaluation logic to scripts and add tests diff --git a/swarm_copy_tests/app/routers/__init__.py b/swarm_copy_tests/app/routers/__init__.py new file mode 100644 index 0000000..ecc485b --- /dev/null +++ b/swarm_copy_tests/app/routers/__init__.py @@ -0,0 +1 @@ +"""Tests for threads module.""" diff --git a/swarm_copy_tests/app/routers/test_threads.py b/swarm_copy_tests/app/routers/test_threads.py new file mode 100644 index 0000000..0926bfc --- /dev/null +++ b/swarm_copy_tests/app/routers/test_threads.py @@ -0,0 +1,122 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from swarm_copy.app.database.sql_schemas import utc_now +from swarm_copy.app.routers.threads import get_threads, update_thread_title, delete_thread + + +@pytest.mark.asyncio +async def test_create_thread(app_client, settings): + mock_validate_project = AsyncMock() + mock_session = AsyncMock() + user_id = "user_id" + token = "token" + title = "title" + thread_id = "uuid" + project_id = "project_id" + virtual_lab_id = "virtual_lab_id" + creation_date = utc_now() + update_date = utc_now() + with patch("swarm_copy.app.app_utils.validate_project", mock_validate_project): + with patch('swarm_copy.app.database.sql_schemas.Threads', autospec=True) as mock_threads: + from swarm_copy.app.routers.threads import create_thread + mock_thread_instance = Mock(user_id=user_id, + title=title, + vlab_id=virtual_lab_id, + project_id=project_id, + thread_id=thread_id, + creation_date=creation_date, + update_date=update_date) + mock_threads.return_value = mock_thread_instance + await create_thread(app_client, settings, + token, + virtual_lab_id, + project_id, + mock_session, + user_id, + title) + assert mock_session.add.called + assert mock_session.commit.called + assert mock_session.refresh.called + + +@pytest.mark.asyncio +async def test_get_threads(): + user_id = "user_id" + title = "title" + thread_id = "uuid" + project_id = "project_id" + virtual_lab_id = "virtual_lab_id" + creation_date = utc_now() + update_date = utc_now() + mock_threads = [ + Mock(user_id=user_id, + title=title, + vlab_id=virtual_lab_id, + project_id=project_id, + thread_id=thread_id, + creation_date=creation_date, + update_date=update_date) + ] + mock_session = AsyncMock() + scalars_mock = Mock() + scalars_mock.all.return_value = mock_threads + mock_thread_result = Mock() + mock_thread_result.scalars.return_value = scalars_mock + mock_session.execute.return_value = mock_thread_result + thread_reads = await get_threads(mock_session, user_id) + thread_read = thread_reads[0] + assert thread_read.thread_id == thread_id + assert thread_read.user_id == user_id + assert thread_read.vlab_id == virtual_lab_id + assert thread_read.project_id == project_id + assert thread_read.title == title + assert thread_read.creation_date == creation_date + assert thread_read.update_date == update_date + + +@pytest.mark.asyncio +async def test_update_thread_title(): + user_id = "user_id" + title = "title" + thread_id = "uuid" + project_id = "project_id" + virtual_lab_id = "virtual_lab_id" + creation_date = utc_now() + update_date = utc_now() + mock_session = AsyncMock() + mock_thread_result = Mock() + mock_session.execute.return_value = mock_thread_result + mock_update_thread = Mock() + mock_update_thread.model_dump.return_value = { + "user_id": user_id, + "title": title, + "vlab_id": virtual_lab_id, + "project_id": project_id, + "thread_id": thread_id, + "creation_date": creation_date, + "update_date": update_date + } + mock_thread = Mock() + thread_read = await update_thread_title(mock_session, mock_update_thread, mock_thread) + assert mock_session.commit.called + assert mock_session.refresh.called + assert thread_read.thread_id == thread_id + assert thread_read.user_id == user_id + assert thread_read.vlab_id == virtual_lab_id + assert thread_read.project_id == project_id + assert thread_read.title == title + assert thread_read.creation_date == creation_date + assert thread_read.update_date == update_date + + +@pytest.mark.asyncio +async def test_delete_thread(): + mock_session = AsyncMock() + mock_thread_result = Mock() + mock_session.execute.return_value = mock_thread_result + mock_thread = Mock() + await delete_thread(mock_session, mock_thread) + assert mock_session.delete.called + assert mock_session.commit.called From da32d22456ca4d7ed869425f026f411d89d8f0e2 Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Wed, 20 Nov 2024 10:40:17 +0100 Subject: [PATCH 2/9] Run new tests --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d8becc5..0347c9b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -90,4 +90,4 @@ jobs: mypy src/ swarm_copy/ # Include src/ directory in Python path to prioritize local files in pytest export PYTHONPATH=$(pwd)/src:$PYTHONPATH - pytest --color=yes + pytest --color=yes tests/ swarm_copy_tests/ From 28335cb7b71b79d1123caf93293ad001e222f8ba Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Wed, 20 Nov 2024 14:55:17 +0100 Subject: [PATCH 3/9] Fixed thread unittest issues --- swarm_copy/tools/bluenaas_memodel_getall.py | 4 +- swarm_copy/tools/bluenaas_memodel_getone.py | 4 +- swarm_copy_tests/app/routers/test_threads.py | 47 +++++++++++--------- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index 8bda00e..cdf55ff 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -29,7 +29,7 @@ class InputMEModelGetAll(BaseModel): page_size: int = Field( default=20, description="Number of results returned by the API." ) - model_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( + simulation_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( default="single-neuron-simulation", description="Type of simulation to retrieve.", ) @@ -55,7 +55,7 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo response = await self.metadata.httpx_client.get( url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/me-models", params={ - "simulation_type": self.input_schema.model_type, + "simulation_type": self.input_schema.simulation_type, "offset": self.input_schema.offset, "page_size": self.input_schema.page_size, }, diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py index 4f4a3b3..70774b0 100644 --- a/swarm_copy/tools/bluenaas_memodel_getone.py +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -24,7 +24,7 @@ class MEModelGetOneMetadata(BaseMetadata): class InputMEModelGetOne(BaseModel): """Inputs for the BlueNaaS single-neuron simulation.""" - model_id: str = Field( + simulation_id: str = Field( description="ID of the model to retrieve. Should be an https link." ) @@ -45,7 +45,7 @@ async def arun(self) -> MEModelResponse: ) response = await self.metadata.httpx_client.get( - url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.model_id)}", + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.simulation_id)}", headers={"Authorization": f"Bearer {self.metadata.token}"}, ) diff --git a/swarm_copy_tests/app/routers/test_threads.py b/swarm_copy_tests/app/routers/test_threads.py index 0926bfc..a2b3d67 100644 --- a/swarm_copy_tests/app/routers/test_threads.py +++ b/swarm_copy_tests/app/routers/test_threads.py @@ -3,40 +3,34 @@ import pytest from swarm_copy.app.database.sql_schemas import utc_now -from swarm_copy.app.routers.threads import get_threads, update_thread_title, delete_thread @pytest.mark.asyncio async def test_create_thread(app_client, settings): mock_validate_project = AsyncMock() mock_session = AsyncMock() + + def add_thread(thread): + thread.creation_date = utc_now() + thread.update_date = utc_now() + thread.thread_id = "thread_id" + + mock_session.add = add_thread user_id = "user_id" token = "token" title = "title" - thread_id = "uuid" project_id = "project_id" virtual_lab_id = "virtual_lab_id" - creation_date = utc_now() - update_date = utc_now() with patch("swarm_copy.app.app_utils.validate_project", mock_validate_project): - with patch('swarm_copy.app.database.sql_schemas.Threads', autospec=True) as mock_threads: - from swarm_copy.app.routers.threads import create_thread - mock_thread_instance = Mock(user_id=user_id, - title=title, - vlab_id=virtual_lab_id, - project_id=project_id, - thread_id=thread_id, - creation_date=creation_date, - update_date=update_date) - mock_threads.return_value = mock_thread_instance - await create_thread(app_client, settings, - token, - virtual_lab_id, - project_id, - mock_session, - user_id, - title) - assert mock_session.add.called + from swarm_copy.app.routers.threads import create_thread + await create_thread(app_client, settings, + token, + virtual_lab_id, + project_id, + mock_session, + user_id, + title) + assert mock_session.commit.called assert mock_session.refresh.called @@ -65,6 +59,7 @@ async def test_get_threads(): mock_thread_result = Mock() mock_thread_result.scalars.return_value = scalars_mock mock_session.execute.return_value = mock_thread_result + from swarm_copy.app.routers.threads import get_threads thread_reads = await get_threads(mock_session, user_id) thread_read = thread_reads[0] assert thread_read.thread_id == thread_id @@ -99,6 +94,7 @@ async def test_update_thread_title(): "update_date": update_date } mock_thread = Mock() + from swarm_copy.app.routers.threads import update_thread_title thread_read = await update_thread_title(mock_session, mock_update_thread, mock_thread) assert mock_session.commit.called assert mock_session.refresh.called @@ -117,6 +113,13 @@ async def test_delete_thread(): mock_thread_result = Mock() mock_session.execute.return_value = mock_thread_result mock_thread = Mock() + from swarm_copy.app.routers.threads import delete_thread await delete_thread(mock_session, mock_thread) assert mock_session.delete.called assert mock_session.commit.called + + +@pytest.fixture(autouse=True) +def stop_patches(): + yield + patch.stopall() From d362c2e1c7f8480263492ed5f7a57481a66348be Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Wed, 20 Nov 2024 14:56:26 +0100 Subject: [PATCH 4/9] lint --- swarm_copy/tools/bluenaas_memodel_getall.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index cdf55ff..db2501e 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -29,9 +29,11 @@ class InputMEModelGetAll(BaseModel): page_size: int = Field( default=20, description="Number of results returned by the API." ) - simulation_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( - default="single-neuron-simulation", - description="Type of simulation to retrieve.", + simulation_type: Literal["single-neuron-simulation", "synaptome-simulation"] = ( + Field( + default="single-neuron-simulation", + description="Type of simulation to retrieve.", + ) ) From 24e305f008c0482b7521ac5fe0307699ade8cf5a Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Wed, 20 Nov 2024 14:59:44 +0100 Subject: [PATCH 5/9] Added conftest --- swarm_copy_tests/conftest.py | 210 +++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 swarm_copy_tests/conftest.py diff --git a/swarm_copy_tests/conftest.py b/swarm_copy_tests/conftest.py new file mode 100644 index 0000000..8acd252 --- /dev/null +++ b/swarm_copy_tests/conftest.py @@ -0,0 +1,210 @@ +"""Test configuration.""" + +import json +from pathlib import Path + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient +from httpx import AsyncClient +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage +from sqlalchemy import MetaData +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +from neuroagent.app.config import Settings +from neuroagent.app.dependencies import get_kg_token, get_settings +from neuroagent.app.main import app +from neuroagent.tools import GetMorphoTool + + +@pytest.fixture(name="settings") +def settings(): + return Settings( + tools={ + "literature": { + "url": "fake_literature_url", + }, + }, + knowledge_graph={ + "base_url": "https://fake_url/api/nexus/v1", + }, + openai={ + "token": "fake_token", + }, + keycloak={ + "username": "fake_username", + "password": "fake_password", + }, + ) + + +@pytest.fixture(name="app_client") +def client_fixture(): + """Get client and clear app dependency_overrides.""" + app_client = TestClient(app) + test_settings = Settings( + tools={ + "literature": { + "url": "fake_literature_url", + }, + }, + knowledge_graph={ + "base_url": "https://fake_url/api/nexus/v1", + }, + openai={ + "token": "fake_token", + }, + keycloak={ + "username": "fake_username", + "password": "fake_password", + }, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + # mock keycloak authentication + app.dependency_overrides[get_kg_token] = lambda: "fake_token" + yield app_client + app.dependency_overrides.clear() + + +@pytest.fixture(autouse=True, scope="session") +def dont_look_at_env_file(): + """Never look inside of the .env when running unit tests.""" + Settings.model_config["env_file"] = None + + +@pytest.fixture() +def patch_required_env(monkeypatch): + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "https://fake_url") + monkeypatch.setenv( + "NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "https://fake_url/api/nexus/v1" + ) + monkeypatch.setenv("NEUROAGENT_OPENAI__TOKEN", "dummy") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "False") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "password") + + +@pytest_asyncio.fixture(params=["sqlite", "postgresql"], name="db_connection") +async def setup_sql_db(request, tmp_path): + db_type = request.param + + # To start the postgresql database: + # docker run -it --rm -p 5432:5432 -e POSTGRES_USER=test -e POSTGRES_PASSWORD=password postgres:latest + path = ( + f"sqlite+aiosqlite:///{tmp_path / 'test_db.db'}" + if db_type == "sqlite" + else "postgresql+asyncpg://test:password@localhost:5432" + ) + if db_type == "postgresql": + try: + async with create_async_engine(path).connect() as conn: + pass + except Exception: + pytest.skip("Postgres database not connected") + yield path + if db_type == "postgresql": + metadata = MetaData() + engine = create_async_engine(path) + session = AsyncSession(bind=engine) + async with engine.begin() as conn: + await conn.run_sync(metadata.reflect) + await conn.run_sync(metadata.drop_all) + + await session.commit() + await engine.dispose() + await session.aclose() + + +@pytest.fixture +def get_resolve_query_output(): + with open("tests/data/resolve_query.json") as f: + outputs = json.loads(f.read()) + return outputs + + +@pytest.fixture +def brain_region_json_path(): + br_path = Path(__file__).parent / "data" / "brainregion_hierarchy.json" + return br_path + + +@pytest.fixture +async def fake_llm_with_tools(brain_region_json_path): + class FakeFuntionChatModel(GenericFakeChatModel): + def bind_tools(self, functions: list): + return self + + def bind_functions(self, **kwargs): + return self + + # If you need another fake response to use different tools, + # you can do in your test + # ```python + # llm, _ = await anext(fake_llm_with_tools) + # llm.responses = my_fake_responses + # ``` + # and simply bind the corresponding tools + fake_responses = [ + AIMessage( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": "call_zHhwfNLSvGGHXMoILdIYtDVI", + "function": { + "arguments": '{"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}', + "name": "get-morpho-tool", + }, + "type": "function", + } + ] + }, + response_metadata={"finish_reason": "tool_calls"}, + id="run-3828644d-197b-401b-8634-e6ecf01c2e7c-0", + tool_calls=[ + { + "name": "get-morpho-tool", + "args": { + "brain_region_id": ( + "http://api.brain-map.org/api/v2/data/Structure/549" + ) + }, + "id": "call_zHhwfNLSvGGHXMoILdIYtDVI", + } + ], + ), + AIMessage( + content="Great answer", + response_metadata={"finish_reason": "stop"}, + id="run-42768b30-044a-4263-8c5c-da61429aa9da-0", + ), + ] + + # If you use this tool in your test, DO NOT FORGET to mock the url response with the following snippet: + # + # ```python + # json_path = Path(__file__).resolve().parent.parent / "data" / "knowledge_graph.json" + # with open(json_path) as f: + # knowledge_graph_response = json.load(f) + + # httpx_mock.add_response( + # url="http://fake_url", + # json=knowledge_graph_response, + # ) + # ``` + # The http call is not mocked here because one might want to change the responses + # and the tools used. + async_client = AsyncClient() + tool = GetMorphoTool( + metadata={ + "url": "http://fake_url", + "search_size": 2, + "httpx_client": async_client, + "token": "fake_token", + "brainregion_path": brain_region_json_path, + } + ) + + yield FakeFuntionChatModel(messages=iter(fake_responses)), [tool], fake_responses + await async_client.aclose() From 72a90a022cf04fcd26522e5bc17aa67ef8b51378 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 10:19:30 +0100 Subject: [PATCH 6/9] Review fixes --- swarm_copy/tools/bluenaas_memodel_getall.py | 10 +- swarm_copy/tools/bluenaas_memodel_getone.py | 4 +- swarm_copy_tests/app/routers/test_threads.py | 220 +++++++++---------- swarm_copy_tests/conftest.py | 92 +------- 4 files changed, 114 insertions(+), 212 deletions(-) diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index db2501e..a7cee19 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -29,11 +29,9 @@ class InputMEModelGetAll(BaseModel): page_size: int = Field( default=20, description="Number of results returned by the API." ) - simulation_type: Literal["single-neuron-simulation", "synaptome-simulation"] = ( - Field( - default="single-neuron-simulation", - description="Type of simulation to retrieve.", - ) + memodel_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( + default="single-neuron-simulation", + description="Type of simulation to retrieve.", ) @@ -57,7 +55,7 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo response = await self.metadata.httpx_client.get( url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/me-models", params={ - "simulation_type": self.input_schema.simulation_type, + "simulation_type": self.input_schema.memodel_type, "offset": self.input_schema.offset, "page_size": self.input_schema.page_size, }, diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py index 70774b0..f84acfa 100644 --- a/swarm_copy/tools/bluenaas_memodel_getone.py +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -24,7 +24,7 @@ class MEModelGetOneMetadata(BaseMetadata): class InputMEModelGetOne(BaseModel): """Inputs for the BlueNaaS single-neuron simulation.""" - simulation_id: str = Field( + memodel_id: str = Field( description="ID of the model to retrieve. Should be an https link." ) @@ -45,7 +45,7 @@ async def arun(self) -> MEModelResponse: ) response = await self.metadata.httpx_client.get( - url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.simulation_id)}", + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.memodel_id)}", headers={"Authorization": f"Bearer {self.metadata.token}"}, ) diff --git a/swarm_copy_tests/app/routers/test_threads.py b/swarm_copy_tests/app/routers/test_threads.py index a2b3d67..3d54f02 100644 --- a/swarm_copy_tests/app/routers/test_threads.py +++ b/swarm_copy_tests/app/routers/test_threads.py @@ -1,122 +1,112 @@ +import logging from unittest.mock import AsyncMock, Mock, patch import pytest -from swarm_copy.app.database.sql_schemas import utc_now - - -@pytest.mark.asyncio -async def test_create_thread(app_client, settings): - mock_validate_project = AsyncMock() - mock_session = AsyncMock() - - def add_thread(thread): - thread.creation_date = utc_now() - thread.update_date = utc_now() - thread.thread_id = "thread_id" - - mock_session.add = add_thread - user_id = "user_id" - token = "token" - title = "title" - project_id = "project_id" - virtual_lab_id = "virtual_lab_id" - with patch("swarm_copy.app.app_utils.validate_project", mock_validate_project): - from swarm_copy.app.routers.threads import create_thread - await create_thread(app_client, settings, - token, - virtual_lab_id, - project_id, - mock_session, - user_id, - title) - - assert mock_session.commit.called - assert mock_session.refresh.called - - -@pytest.mark.asyncio -async def test_get_threads(): - user_id = "user_id" - title = "title" - thread_id = "uuid" - project_id = "project_id" - virtual_lab_id = "virtual_lab_id" - creation_date = utc_now() - update_date = utc_now() - mock_threads = [ - Mock(user_id=user_id, - title=title, - vlab_id=virtual_lab_id, - project_id=project_id, - thread_id=thread_id, - creation_date=creation_date, - update_date=update_date) - ] - mock_session = AsyncMock() - scalars_mock = Mock() - scalars_mock.all.return_value = mock_threads - mock_thread_result = Mock() - mock_thread_result.scalars.return_value = scalars_mock - mock_session.execute.return_value = mock_thread_result - from swarm_copy.app.routers.threads import get_threads - thread_reads = await get_threads(mock_session, user_id) - thread_read = thread_reads[0] - assert thread_read.thread_id == thread_id - assert thread_read.user_id == user_id - assert thread_read.vlab_id == virtual_lab_id - assert thread_read.project_id == project_id - assert thread_read.title == title - assert thread_read.creation_date == creation_date - assert thread_read.update_date == update_date - - -@pytest.mark.asyncio -async def test_update_thread_title(): - user_id = "user_id" - title = "title" - thread_id = "uuid" - project_id = "project_id" - virtual_lab_id = "virtual_lab_id" - creation_date = utc_now() - update_date = utc_now() - mock_session = AsyncMock() - mock_thread_result = Mock() - mock_session.execute.return_value = mock_thread_result - mock_update_thread = Mock() - mock_update_thread.model_dump.return_value = { - "user_id": user_id, - "title": title, - "vlab_id": virtual_lab_id, - "project_id": project_id, - "thread_id": thread_id, - "creation_date": creation_date, - "update_date": update_date - } - mock_thread = Mock() - from swarm_copy.app.routers.threads import update_thread_title - thread_read = await update_thread_title(mock_session, mock_update_thread, mock_thread) - assert mock_session.commit.called - assert mock_session.refresh.called - assert thread_read.thread_id == thread_id - assert thread_read.user_id == user_id - assert thread_read.vlab_id == virtual_lab_id - assert thread_read.project_id == project_id - assert thread_read.title == title - assert thread_read.creation_date == creation_date - assert thread_read.update_date == update_date - - -@pytest.mark.asyncio -async def test_delete_thread(): - mock_session = AsyncMock() - mock_thread_result = Mock() - mock_session.execute.return_value = mock_thread_result - mock_thread = Mock() - from swarm_copy.app.routers.threads import delete_thread - await delete_thread(mock_session, mock_thread) - assert mock_session.delete.called - assert mock_session.commit.called +from swarm_copy.app.config import Settings +from swarm_copy.app.dependencies import get_settings +from swarm_copy.app.main import app + + +def test_create_thread(patch_required_env, httpx_mock, app_client, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + with app_client as app_client: + # Create a thread + create_output = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + assert create_output["thread_id"] + assert create_output["title"] == "New chat" + assert create_output["vlab_id"] == "test_vlab" + assert create_output["project_id"] == "test_project" + + +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +def test_get_threads(patch_required_env, httpx_mock, app_client, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + with app_client as app_client: + threads = app_client.get("/threads/").json() + assert not threads + create_output_1 = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + create_output_2 = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + threads = app_client.get("/threads/").json() + + assert len(threads) == 2 + assert threads[0] == create_output_1 + assert threads[1] == create_output_2 + + +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + with app_client as app_client: + threads = app_client.get("/threads/").json() + assert not threads + + create_thread_response = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + thread_id = create_thread_response["thread_id"] + + updated_title = "Updated Thread Title" + update_response = app_client.patch( + f"/threads/{thread_id}", json={"title": updated_title} + ).json() + + assert update_response["title"] == updated_title + + +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + with app_client as app_client: + threads = app_client.get("/threads/").json() + assert not threads + + create_thread_response = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + thread_id = create_thread_response["thread_id"] + + threads = app_client.get("/threads/").json() + assert len(threads) == 1 + assert threads[0]["thread_id"] == thread_id + + delete_response = app_client.delete(f"/threads/{thread_id}").json() + assert delete_response["Acknowledged"] == "true" + + threads = app_client.get("/threads/").json() + assert not threads @pytest.fixture(autouse=True) diff --git a/swarm_copy_tests/conftest.py b/swarm_copy_tests/conftest.py index 8acd252..54ee423 100644 --- a/swarm_copy_tests/conftest.py +++ b/swarm_copy_tests/conftest.py @@ -6,16 +6,12 @@ import pytest import pytest_asyncio from fastapi.testclient import TestClient -from httpx import AsyncClient -from langchain_core.language_models.fake_chat_models import GenericFakeChatModel -from langchain_core.messages import AIMessage from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from neuroagent.app.config import Settings -from neuroagent.app.dependencies import get_kg_token, get_settings -from neuroagent.app.main import app -from neuroagent.tools import GetMorphoTool +from swarm_copy.app.config import Settings +from swarm_copy.app.dependencies import get_kg_token, get_settings +from swarm_copy.app.main import app @pytest.fixture(name="settings") @@ -126,85 +122,3 @@ def get_resolve_query_output(): def brain_region_json_path(): br_path = Path(__file__).parent / "data" / "brainregion_hierarchy.json" return br_path - - -@pytest.fixture -async def fake_llm_with_tools(brain_region_json_path): - class FakeFuntionChatModel(GenericFakeChatModel): - def bind_tools(self, functions: list): - return self - - def bind_functions(self, **kwargs): - return self - - # If you need another fake response to use different tools, - # you can do in your test - # ```python - # llm, _ = await anext(fake_llm_with_tools) - # llm.responses = my_fake_responses - # ``` - # and simply bind the corresponding tools - fake_responses = [ - AIMessage( - content="", - additional_kwargs={ - "tool_calls": [ - { - "index": 0, - "id": "call_zHhwfNLSvGGHXMoILdIYtDVI", - "function": { - "arguments": '{"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}', - "name": "get-morpho-tool", - }, - "type": "function", - } - ] - }, - response_metadata={"finish_reason": "tool_calls"}, - id="run-3828644d-197b-401b-8634-e6ecf01c2e7c-0", - tool_calls=[ - { - "name": "get-morpho-tool", - "args": { - "brain_region_id": ( - "http://api.brain-map.org/api/v2/data/Structure/549" - ) - }, - "id": "call_zHhwfNLSvGGHXMoILdIYtDVI", - } - ], - ), - AIMessage( - content="Great answer", - response_metadata={"finish_reason": "stop"}, - id="run-42768b30-044a-4263-8c5c-da61429aa9da-0", - ), - ] - - # If you use this tool in your test, DO NOT FORGET to mock the url response with the following snippet: - # - # ```python - # json_path = Path(__file__).resolve().parent.parent / "data" / "knowledge_graph.json" - # with open(json_path) as f: - # knowledge_graph_response = json.load(f) - - # httpx_mock.add_response( - # url="http://fake_url", - # json=knowledge_graph_response, - # ) - # ``` - # The http call is not mocked here because one might want to change the responses - # and the tools used. - async_client = AsyncClient() - tool = GetMorphoTool( - metadata={ - "url": "http://fake_url", - "search_size": 2, - "httpx_client": async_client, - "token": "fake_token", - "brainregion_path": brain_region_json_path, - } - ) - - yield FakeFuntionChatModel(messages=iter(fake_responses)), [tool], fake_responses - await async_client.aclose() From 2901f01239dc03456a8b0bab4bf9b645412bbc27 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Tue, 17 Dec 2024 14:04:08 +0100 Subject: [PATCH 7/9] Fixed fixture --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 88ee7ea..65c4ac7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,7 +107,7 @@ def brain_region_json_path(): return br_path -@pytest.fixture +@pytest_asyncio.fixture async def fake_llm_with_tools(brain_region_json_path): class FakeFuntionChatModel(GenericFakeChatModel): def bind_tools(self, functions: list): From 9edab098b6888cddfddd8826d5602533913105a4 Mon Sep 17 00:00:00 2001 From: Boris-Bergsma Date: Thu, 19 Dec 2024 16:32:09 +0100 Subject: [PATCH 8/9] Add test for get_messages --- swarm_copy_tests/app/routers/test_threads.py | 111 +++++++++++++++++-- 1 file changed, 101 insertions(+), 10 deletions(-) diff --git a/swarm_copy_tests/app/routers/test_threads.py b/swarm_copy_tests/app/routers/test_threads.py index 3d54f02..cfbec4f 100644 --- a/swarm_copy_tests/app/routers/test_threads.py +++ b/swarm_copy_tests/app/routers/test_threads.py @@ -1,13 +1,17 @@ -import logging -from unittest.mock import AsyncMock, Mock, patch - import pytest +from swarm_copy.agent_routine import Agent, AgentsRoutine from swarm_copy.app.config import Settings -from swarm_copy.app.dependencies import get_settings +from swarm_copy.app.dependencies import ( + get_agents_routine, + get_settings, + get_starting_agent, +) from swarm_copy.app.main import app +from swarm_copy_tests.mock_client import create_mock_response +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) def test_create_thread(patch_required_env, httpx_mock, app_client, db_connection): test_settings = Settings( db={"prefix": db_connection}, @@ -52,6 +56,87 @@ def test_get_threads(patch_required_env, httpx_mock, app_client, db_connection): assert threads[1] == create_output_2 +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +@pytest.mark.asyncio +async def test_get_messages( + patch_required_env, + httpx_mock, + app_client, + db_connection, + mock_openai_client, + get_weather_tool, +): + # Put data in the db + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_sequential_responses( + [ + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Geneva"}} + ], + ), + create_mock_response( + {"role": "assistant", "content": "sample response content"} + ), + ] + ) + agent = Agent(tools=[get_weather_tool]) + + app.dependency_overrides[get_agents_routine] = lambda: routine + app.dependency_overrides[get_starting_agent] = lambda: agent + + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + + with app_client as app_client: + # wrong thread ID + wrong_response = app_client.get("/threads/test") + assert wrong_response.status_code == 404 + assert wrong_response.json() == {"detail": {"detail": "Thread not found."}} + + # Create a thread + create_output = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + thread_id = create_output["thread_id"] + + # Fill the thread + app_client.post( + f"/qa/chat/{thread_id}", + json={"query": "This is my query"}, + headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"}, + ) + + create_output = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + empty_thread_id = create_output["thread_id"] + empty_messages = app_client.get(f"/threads/{empty_thread_id}").json() + assert empty_messages == [] + + # Get the messages of the thread + messages = app_client.get(f"/threads/{thread_id}").json() + + assert messages[0]["order"] == 0 + assert messages[0]["entity"] == "user" + assert messages[0]["msg_content"] == "This is my query" + assert messages[0]["message_id"] + assert messages[0]["creation_date"] + + assert messages[1]["order"] == 3 + assert messages[1]["entity"] == "ai_message" + assert messages[1]["msg_content"] == "sample response content" + assert messages[1]["message_id"] + assert messages[1]["creation_date"] + + @pytest.mark.httpx_mock(can_send_already_matched_responses=True) def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_connection): test_settings = Settings( @@ -66,6 +151,13 @@ def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_conn threads = app_client.get("/threads/").json() assert not threads + # Check when wrong thread id + wrong_response = app_client.patch( + "/threads/wrong_id", json={"title": "great_title"} + ) + assert wrong_response.status_code == 404 + assert wrong_response.json() == {"detail": {"detail": "Thread not found."}} + create_thread_response = app_client.post( "/threads/?virtual_lab_id=test_vlab&project_id=test_project" ).json() @@ -93,6 +185,11 @@ def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection threads = app_client.get("/threads/").json() assert not threads + # Check when wrong thread id + wrong_response = app_client.delete("/threads/wrong_id") + assert wrong_response.status_code == 404 + assert wrong_response.json() == {"detail": {"detail": "Thread not found."}} + create_thread_response = app_client.post( "/threads/?virtual_lab_id=test_vlab&project_id=test_project" ).json() @@ -107,9 +204,3 @@ def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection threads = app_client.get("/threads/").json() assert not threads - - -@pytest.fixture(autouse=True) -def stop_patches(): - yield - patch.stopall() From 5e948ded850c8b4487238abc701f9939fd709a85 Mon Sep 17 00:00:00 2001 From: Boris-Bergsma Date: Thu, 19 Dec 2024 17:19:05 +0100 Subject: [PATCH 9/9] small change in messages test --- swarm_copy_tests/app/routers/test_threads.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/swarm_copy_tests/app/routers/test_threads.py b/swarm_copy_tests/app/routers/test_threads.py index cfbec4f..5fd2221 100644 --- a/swarm_copy_tests/app/routers/test_threads.py +++ b/swarm_copy_tests/app/routers/test_threads.py @@ -106,6 +106,8 @@ async def test_get_messages( "/threads/?virtual_lab_id=test_vlab&project_id=test_project" ).json() thread_id = create_output["thread_id"] + empty_messages = app_client.get(f"/threads/{thread_id}").json() + assert empty_messages == [] # Fill the thread app_client.post( @@ -114,13 +116,6 @@ async def test_get_messages( headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"}, ) - create_output = app_client.post( - "/threads/?virtual_lab_id=test_vlab&project_id=test_project" - ).json() - empty_thread_id = create_output["thread_id"] - empty_messages = app_client.get(f"/threads/{empty_thread_id}").json() - assert empty_messages == [] - # Get the messages of the thread messages = app_client.get(f"/threads/{thread_id}").json()