Skip to content

Commit

Permalink
Added threads tests (#54)
Browse files Browse the repository at this point in the history
* Added tests

* Run new tests

* Fixed thread unittest issues

* lint

* Added conftest

* Review fixes

* Fixed fixture

* Add test for get_messages

* small change in messages test

---------

Co-authored-by: kanesoban <[email protected]>
Co-authored-by: Boris-Bergsma <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent a2624e3 commit c6a7d0e
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Unit tests for the migrated tools
- CRUDs.
- BlueNaas CRUD tools
- Tests for threads module
- Cell types, resolving and utils tests
- app unit tests
- Tests of AgentsRoutine.
Expand Down
1 change: 1 addition & 0 deletions swarm_copy_tests/app/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for threads module."""
201 changes: 201 additions & 0 deletions swarm_copy_tests/app/routers/test_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import pytest

from swarm_copy.agent_routine import Agent, AgentsRoutine
from swarm_copy.app.config import Settings
from swarm_copy.app.dependencies import (
get_agents_routine,
get_settings,
get_starting_agent,
)
from swarm_copy.app.main import app
from swarm_copy_tests.mock_client import create_mock_response


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_create_thread(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
assert create_output["thread_id"]
assert create_output["title"] == "New chat"
assert create_output["vlab_id"] == "test_vlab"
assert create_output["project_id"] == "test_project"


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_get_threads(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
threads = app_client.get("/threads/").json()
assert not threads
create_output_1 = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
create_output_2 = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
threads = app_client.get("/threads/").json()

assert len(threads) == 2
assert threads[0] == create_output_1
assert threads[1] == create_output_2


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_messages(
patch_required_env,
httpx_mock,
app_client,
db_connection,
mock_openai_client,
get_weather_tool,
):
# Put data in the db
routine = AgentsRoutine(client=mock_openai_client)

mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[
{"name": "get_weather", "args": {"location": "Geneva"}}
],
),
create_mock_response(
{"role": "assistant", "content": "sample response content"}
),
]
)
agent = Agent(tools=[get_weather_tool])

app.dependency_overrides[get_agents_routine] = lambda: routine
app.dependency_overrides[get_starting_agent] = lambda: agent

test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)

with app_client as app_client:
# wrong thread ID
wrong_response = app_client.get("/threads/test")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_output["thread_id"]
empty_messages = app_client.get(f"/threads/{thread_id}").json()
assert empty_messages == []

# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"query": "This is my query"},
headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
)

# Get the messages of the thread
messages = app_client.get(f"/threads/{thread_id}").json()

assert messages[0]["order"] == 0
assert messages[0]["entity"] == "user"
assert messages[0]["msg_content"] == "This is my query"
assert messages[0]["message_id"]
assert messages[0]["creation_date"]

assert messages[1]["order"] == 3
assert messages[1]["entity"] == "ai_message"
assert messages[1]["msg_content"] == "sample response content"
assert messages[1]["message_id"]
assert messages[1]["creation_date"]


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings

httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
threads = app_client.get("/threads/").json()
assert not threads

# Check when wrong thread id
wrong_response = app_client.patch(
"/threads/wrong_id", json={"title": "great_title"}
)
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

create_thread_response = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_thread_response["thread_id"]

updated_title = "Updated Thread Title"
update_response = app_client.patch(
f"/threads/{thread_id}", json={"title": updated_title}
).json()

assert update_response["title"] == updated_title


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings

httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
threads = app_client.get("/threads/").json()
assert not threads

# Check when wrong thread id
wrong_response = app_client.delete("/threads/wrong_id")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

create_thread_response = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_thread_response["thread_id"]

threads = app_client.get("/threads/").json()
assert len(threads) == 1
assert threads[0]["thread_id"] == thread_id

delete_response = app_client.delete(f"/threads/{thread_id}").json()
assert delete_response["Acknowledged"] == "true"

threads = app_client.get("/threads/").json()
assert not threads

0 comments on commit c6a7d0e

Please sign in to comment.