Skip to content

Commit

Permalink
Merged main
Browse files Browse the repository at this point in the history
  • Loading branch information
kanesoban committed Dec 18, 2024
2 parents cc82dee + 8f54950 commit 386c9f8
Show file tree
Hide file tree
Showing 15 changed files with 1,359 additions and 31 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Unit tests for the migrated tools
- CRUDs.
- BlueNaas CRUD tools
- app unit tests
- Tests of AgentsRoutine.
- Unit tests for database

### Fixed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ convention = "numpy"

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"]
"swarm_copy_tests/*" = ["D"]

[tool.mypy]
mypy_path = "src"
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/run.py → swarm_copy/agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def astream(
) -> AsyncIterator[str | Response]:
"""Stream the agent response."""
active_agent = agent
context_variables = copy.deepcopy(context_variables)

history = copy.deepcopy(messages)
init_len = len(messages)
is_streaming = False
Expand Down Expand Up @@ -251,7 +251,7 @@ async def astream(
stream=True,
)
async for chunk in completion: # type: ignore
delta = json.loads(chunk.choices[0].delta.json())
delta = json.loads(chunk.choices[0].delta.model_dump_json())

# Check for tool calls
if delta["tool_calls"]:
Expand Down
6 changes: 2 additions & 4 deletions swarm_copy/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.status import HTTP_401_UNAUTHORIZED

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.app_utils import validate_project
from swarm_copy.app.config import Settings
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.cell_types import CellTypesMeta
from swarm_copy.new_types import Agent
from swarm_copy.run import AgentsRoutine
from swarm_copy.tools import (
ElectrophysFeatureTool,
GetMorphoTool,
Expand Down Expand Up @@ -204,9 +204,7 @@ async def get_vlab_and_project(
if not thread:
raise HTTPException(
status_code=404,
detail={
"detail": "Thread not found.",
},
detail="Thread not found.",
)
if thread and thread.vlab_id and thread.project_id:
vlab_and_project = {
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.database.db_utils import get_history, get_thread, save_history
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.app.dependencies import (
Expand All @@ -17,7 +18,6 @@
get_user_id,
)
from swarm_copy.new_types import Agent, AgentRequest, AgentResponse
from swarm_copy.run import AgentsRoutine
from swarm_copy.stream import stream_agent_response

router = APIRouter(prefix="/qa", tags=["Run the agent"])
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from openai import AsyncOpenAI
from sqlalchemy.ext.asyncio import AsyncSession

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.database.db_utils import save_history
from swarm_copy.new_types import Agent, Response
from swarm_copy.run import AgentsRoutine


async def stream_agent_response(
Expand Down
1 change: 1 addition & 0 deletions swarm_copy_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Sarm copy tests."""
75 changes: 75 additions & 0 deletions swarm_copy_tests/app/test_app_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Test app utils."""

from unittest.mock import AsyncMock, patch

import pytest
from fastapi.exceptions import HTTPException
from httpx import AsyncClient

from swarm_copy.app.app_utils import setup_engine, validate_project
from swarm_copy.app.config import Settings


@pytest.mark.asyncio
async def test_validate_project(patch_required_env, httpx_mock, monkeypatch):
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true")
httpx_client = AsyncClient()
token = "fake_token"
test_vp = {"vlab_id": "test_vlab_DB", "project_id": "project_id_DB"}
vlab_url = "https://openbluebrain.com/api/virtual-lab-manager/virtual-labs"

# test with bad config
httpx_mock.add_response(
url=f'{vlab_url}/{test_vp["vlab_id"]}/projects/{test_vp["project_id"]}',
status_code=404,
)
with pytest.raises(HTTPException) as error:
await validate_project(
httpx_client=httpx_client,
vlab_id=test_vp["vlab_id"],
project_id=test_vp["project_id"],
token=token,
vlab_project_url=vlab_url,
)
assert error.value.status_code == 401

# test with good config
httpx_mock.add_response(
url=f'{vlab_url}/{test_vp["vlab_id"]}/projects/{test_vp["project_id"]}',
json="test_project_ID",
)
await validate_project(
httpx_client=httpx_client,
vlab_id=test_vp["vlab_id"],
project_id=test_vp["project_id"],
token=token,
vlab_project_url=vlab_url,
)
# we jsut want to assert that the httpx_mock was called.


@patch("neuroagent.app.app_utils.create_async_engine")
def test_setup_engine(create_engine_mock, monkeypatch, patch_required_env):
create_engine_mock.return_value = AsyncMock()

monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix")

settings = Settings()

connection_string = "postgresql+asyncpg://user:password@localhost/dbname"
retval = setup_engine(settings=settings, connection_string=connection_string)
assert retval is not None


@patch("neuroagent.app.app_utils.create_async_engine")
def test_setup_engine_no_connection_string(
create_engine_mock, monkeypatch, patch_required_env
):
create_engine_mock.return_value = AsyncMock()

monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix")

settings = Settings()

retval = setup_engine(settings=settings, connection_string=None)
assert retval is None
71 changes: 71 additions & 0 deletions swarm_copy_tests/app/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Test config"""

import pytest
from pydantic import ValidationError

from swarm_copy.app.config import Settings


def test_required(monkeypatch, patch_required_env):
settings = Settings()

assert settings.tools.literature.url == "https://fake_url"
assert settings.knowledge_graph.base_url == "https://fake_url/api/nexus/v1"
assert settings.openai.token.get_secret_value() == "dummy"

# make sure not case sensitive
monkeypatch.delenv("NEUROAGENT_TOOLS__LITERATURE__URL")
monkeypatch.setenv("neuroagent_tools__literature__URL", "https://new_fake_url")

settings = Settings()
assert settings.tools.literature.url == "https://new_fake_url"


def test_no_settings():
# We get an error when no custom variables provided
with pytest.raises(ValidationError):
Settings()


def test_setup_tools(monkeypatch, patch_required_env):
monkeypatch.setenv("NEUROAGENT_TOOLS__TRACE__SEARCH_SIZE", "20")
monkeypatch.setenv("NEUROAGENT_TOOLS__MORPHO__SEARCH_SIZE", "20")
monkeypatch.setenv("NEUROAGENT_TOOLS__KG_MORPHO_FEATURES__SEARCH_SIZE", "20")

monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "user")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "pass")

settings = Settings()

assert settings.tools.morpho.search_size == 20
assert settings.tools.trace.search_size == 20
assert settings.tools.kg_morpho_features.search_size == 20
assert settings.keycloak.username == "user"
assert settings.keycloak.password.get_secret_value() == "pass"


def test_check_consistency(monkeypatch):
# We get an error when no custom variables provided
url = "https://fake_url"
monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", url)
monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__URL", url)

with pytest.raises(ValueError):
Settings()

monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__TOKEN", "dummy")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true")

with pytest.raises(ValueError):
Settings()

monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "false")

with pytest.raises(ValueError):
Settings()

monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://fake_nexus.com")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "Hello")

Settings()
Loading

0 comments on commit 386c9f8

Please sign in to comment.