Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added threads tests #54

Merged
merged 14 commits into from
Dec 19, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Unit tests for the migrated tools
- CRUDs.
- BlueNaas CRUD tools
- Tests for threads module
- Cell types, resolving and utils tests
- app unit tests
- Tests of AgentsRoutine.
Expand Down
1 change: 1 addition & 0 deletions swarm_copy_tests/app/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for threads module."""
201 changes: 201 additions & 0 deletions swarm_copy_tests/app/routers/test_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import pytest

from swarm_copy.agent_routine import Agent, AgentsRoutine
from swarm_copy.app.config import Settings
from swarm_copy.app.dependencies import (
get_agents_routine,
get_settings,
get_starting_agent,
)
from swarm_copy.app.main import app
from swarm_copy_tests.mock_client import create_mock_response


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_create_thread(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
assert create_output["thread_id"]
assert create_output["title"] == "New chat"
assert create_output["vlab_id"] == "test_vlab"
assert create_output["project_id"] == "test_project"


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_get_threads(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
threads = app_client.get("/threads/").json()
assert not threads
create_output_1 = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
create_output_2 = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
threads = app_client.get("/threads/").json()

assert len(threads) == 2
assert threads[0] == create_output_1
assert threads[1] == create_output_2


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_messages(
patch_required_env,
httpx_mock,
app_client,
db_connection,
mock_openai_client,
get_weather_tool,
):
# Put data in the db
routine = AgentsRoutine(client=mock_openai_client)

mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[
{"name": "get_weather", "args": {"location": "Geneva"}}
],
),
create_mock_response(
{"role": "assistant", "content": "sample response content"}
),
]
)
agent = Agent(tools=[get_weather_tool])

app.dependency_overrides[get_agents_routine] = lambda: routine
app.dependency_overrides[get_starting_agent] = lambda: agent

test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)

with app_client as app_client:
# wrong thread ID
wrong_response = app_client.get("/threads/test")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_output["thread_id"]
empty_messages = app_client.get(f"/threads/{thread_id}").json()
assert empty_messages == []

# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"query": "This is my query"},
headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
)

# Get the messages of the thread
messages = app_client.get(f"/threads/{thread_id}").json()

assert messages[0]["order"] == 0
assert messages[0]["entity"] == "user"
assert messages[0]["msg_content"] == "This is my query"
assert messages[0]["message_id"]
assert messages[0]["creation_date"]

assert messages[1]["order"] == 3
assert messages[1]["entity"] == "ai_message"
assert messages[1]["msg_content"] == "sample response content"
assert messages[1]["message_id"]
assert messages[1]["creation_date"]


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings

httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
threads = app_client.get("/threads/").json()
assert not threads

# Check when wrong thread id
wrong_response = app_client.patch(
"/threads/wrong_id", json={"title": "great_title"}
)
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

create_thread_response = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_thread_response["thread_id"]

updated_title = "Updated Thread Title"
update_response = app_client.patch(
f"/threads/{thread_id}", json={"title": updated_title}
).json()

assert update_response["title"] == updated_title


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings

httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)
with app_client as app_client:
threads = app_client.get("/threads/").json()
assert not threads

# Check when wrong thread id
wrong_response = app_client.delete("/threads/wrong_id")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

create_thread_response = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_thread_response["thread_id"]

threads = app_client.get("/threads/").json()
assert len(threads) == 1
assert threads[0]["thread_id"] == thread_id

delete_response = app_client.delete(f"/threads/{thread_id}").json()
assert delete_response["Acknowledged"] == "true"

threads = app_client.get("/threads/").json()
assert not threads
Loading