diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index db2501e..bc42146 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -29,7 +29,7 @@ class InputMEModelGetAll(BaseModel): page_size: int = Field( default=20, description="Number of results returned by the API." ) - simulation_type: Literal["single-neuron-simulation", "synaptome-simulation"] = ( + memodel_type: Literal["single-neuron-simulation", "synaptome-simulation"] = ( Field( default="single-neuron-simulation", description="Type of simulation to retrieve.", diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py index 70774b0..f84acfa 100644 --- a/swarm_copy/tools/bluenaas_memodel_getone.py +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -24,7 +24,7 @@ class MEModelGetOneMetadata(BaseMetadata): class InputMEModelGetOne(BaseModel): """Inputs for the BlueNaaS single-neuron simulation.""" - simulation_id: str = Field( + memodel_id: str = Field( description="ID of the model to retrieve. Should be an https link." ) @@ -45,7 +45,7 @@ async def arun(self) -> MEModelResponse: ) response = await self.metadata.httpx_client.get( - url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.simulation_id)}", + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.memodel_id)}", headers={"Authorization": f"Bearer {self.metadata.token}"}, ) diff --git a/swarm_copy_tests/conftest.py b/swarm_copy_tests/conftest.py index b11dfab..e7d384d 100644 --- a/swarm_copy_tests/conftest.py +++ b/swarm_copy_tests/conftest.py @@ -6,14 +6,12 @@ import pytest import pytest_asyncio from fastapi.testclient import TestClient -from httpx import AsyncClient from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from swarm_copy.app.config import Settings from swarm_copy.app.dependencies import get_kg_token, get_settings from swarm_copy.app.main import app -from swarm_copy.tools import GetMorphoTool @pytest.fixture(name="app_client") diff --git a/swarm_copy_tests/tools/test_literature_search_tool.py b/swarm_copy_tests/tools/test_literature_search_tool.py index a4061d7..4b21481 100644 --- a/swarm_copy_tests/tools/test_literature_search_tool.py +++ b/swarm_copy_tests/tools/test_literature_search_tool.py @@ -11,16 +11,11 @@ class TestLiteratureSearchTool: @pytest.mark.asyncio - async def test_arun(self): - url = "http://fake_url" + async def test_arun(self, httpx_mock): + url = "http://fake_url?query=covid+19&retriever_k=100&use_reranker=true&reranker_k=5" reranker_k = 5 - client = httpx.AsyncClient() - client.get = AsyncMock() - response = Mock() - response.status_code = 200 - client.get.return_value = response - response.json.return_value = [ + fake_response = [ { "article_title": "Article title", "article_authors": ["Author1", "Author2"], @@ -32,11 +27,16 @@ async def test_arun(self): for _ in range(reranker_k) ] + httpx_mock.add_response( + url=url, + json=fake_response, + ) + tool = LiteratureSearchTool( input_schema=LiteratureSearchInput(query="covid 19"), metadata=LiteratureSearchMetadata( literature_search_url=url, - httpx_client=client, + httpx_client=httpx.AsyncClient(), token="fake_token", retriever_k=100, use_reranker=True,