Skip to content

Commit

Permalink
Fix streaming and tool outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
BoBer78 committed Dec 3, 2024
1 parent d5a2ccf commit 341c361
Show file tree
Hide file tree
Showing 15 changed files with 36 additions and 32 deletions.
2 changes: 1 addition & 1 deletion swarm_copy/cell_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CellTypesMeta:
"""

def __init__(self) -> None:
self.name_: dict[str, str] = {}
self.name_: dict[str, str | None] = {}
self.descendants_ids: dict[str, set[str]] = {}

def descendants(self, ids: str | set[str]) -> set[str]:
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def astream(
) -> AsyncIterator[str | Response]:
"""Stream the agent response."""
active_agent = agent
context_variables = copy.deepcopy(context_variables)

history = copy.deepcopy(messages)
init_len = len(messages)
is_streaming = False
Expand Down
2 changes: 2 additions & 0 deletions swarm_copy/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any, AsyncIterator

from httpx import AsyncClient
from openai import AsyncOpenAI
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -26,6 +27,7 @@ async def stream_agent_response(
)
else:
connected_agents_routine = AgentsRoutine(client=None)
context_variables["httpx_client"] = AsyncClient(timeout=None, verify=False)

iterator = connected_agents_routine.astream(agent, messages, context_variables)
async for chunk in iterator:
Expand Down
8 changes: 4 additions & 4 deletions swarm_copy/tools/bluenaas_memodel_getall.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar, Literal
from typing import Any, ClassVar, Literal

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -46,7 +46,7 @@ class MEModelGetAllTool(BaseTool):
metadata: MEModelGetAllMetadata
input_schema: InputMEModelGetAll

async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelResponse:
async def arun(self) -> dict[str, Any]:
"""Run the MEModelGetAll tool."""
logger.info(
f"Running MEModelGetAll tool with inputs {self.input_schema.model_dump()}"
Expand All @@ -61,7 +61,7 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo
},
headers={"Authorization": f"Bearer {self.metadata.token}"},
)
breakpoint()

return PaginatedResponseUnionMEModelResponseSynaptomeModelResponse(
**response.json()
)
).model_dump()
6 changes: 3 additions & 3 deletions swarm_copy/tools/bluenaas_memodel_getone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar
from typing import Any, ClassVar
from urllib.parse import quote_plus

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -38,7 +38,7 @@ class MEModelGetOneTool(BaseTool):
metadata: MEModelGetOneMetadata
input_schema: InputMEModelGetOne

async def arun(self) -> MEModelResponse:
async def arun(self) -> dict[str, Any]:
"""Run the MEModelGetOne tool."""
logger.info(
f"Running MEModelGetOne tool with inputs {self.input_schema.model_dump()}"
Expand All @@ -49,4 +49,4 @@ async def arun(self) -> MEModelResponse:
headers={"Authorization": f"Bearer {self.metadata.token}"},
)

return MEModelResponse(**response.json())
return MEModelResponse(**response.json()).model_dump()
8 changes: 5 additions & 3 deletions swarm_copy/tools/bluenaas_scs_getall.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar, Literal
from typing import Any, ClassVar, Literal

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -47,7 +47,7 @@ class SCSGetAllTool(BaseTool):
metadata: SCSGetAllMetadata
input_schema: InputSCSGetAll

async def arun(self) -> PaginatedResponseSimulationDetailsResponse:
async def arun(self) -> dict[str, Any]:
"""Run the SCSGetAll tool."""
logger.info(
f"Running SCSGetAll tool with inputs {self.input_schema.model_dump()}"
Expand All @@ -63,4 +63,6 @@ async def arun(self) -> PaginatedResponseSimulationDetailsResponse:
headers={"Authorization": f"Bearer {self.metadata.token}"},
)

return PaginatedResponseSimulationDetailsResponse(**response.json())
return PaginatedResponseSimulationDetailsResponse(
**response.json()
).model_dump()
6 changes: 3 additions & 3 deletions swarm_copy/tools/bluenaas_scs_getone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar
from typing import Any, ClassVar

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -39,7 +39,7 @@ class SCSGetOneTool(BaseTool):
metadata: SCSGetOneMetadata
input_schema: InputSCSGetOne

async def arun(self) -> SimulationDetailsResponse:
async def arun(self) -> dict[str, Any]:
"""Run the SCSGetOne tool."""
logger.info(
f"Running SCSGetOne tool with inputs {self.input_schema.model_dump()}"
Expand All @@ -50,4 +50,4 @@ async def arun(self) -> SimulationDetailsResponse:
headers={"Authorization": f"Bearer {self.metadata.token}"},
)

return SimulationDetailsResponse(**response.json())
return SimulationDetailsResponse(**response.json()).model_dump()
4 changes: 2 additions & 2 deletions swarm_copy/tools/bluenaas_scs_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class SCSPostTool(BaseTool):
metadata: SCSPostMetadata
input_schema: InputSCSPost

async def arun(self) -> SCSPostOutput:
async def arun(self) -> dict[str, Any]:
"""Run the SCSPost tool."""
logger.info(
f"Running SCSPost tool with inputs {self.input_schema.model_dump()}"
Expand Down Expand Up @@ -126,7 +126,7 @@ async def arun(self) -> SCSPostOutput:
status=json_response["status"],
name=json_response["name"],
error=json_response["error"],
)
).model_dump()

@staticmethod
def create_json_api(
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/tools/electrophys_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class ElectrophysFeatureTool(BaseTool):
input_schema: ElectrophysInput
metadata: ElectrophysMetadata

async def arun(self) -> FeatureOutput:
async def arun(self) -> dict[str, Any]:
"""Give features about trace."""
logger.info(
f"Entering electrophys tool. Inputs: {self.input_schema.trace_id=}, {self.input_schema.calculated_feature=},"
Expand Down Expand Up @@ -329,4 +329,4 @@ async def arun(self) -> FeatureOutput:
)
return FeatureOutput(
brain_region=metadata.brain_region, feature_dict=output_features
)
).model_dump()
4 changes: 2 additions & 2 deletions swarm_copy/tools/get_morpho_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class GetMorphoTool(BaseTool):
input_schema: GetMorphoInput
metadata: GetMorphoMetadata

async def arun(self) -> list[KnowledgeGraphOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""From a brain region ID, extract morphologies.
Returns
Expand Down Expand Up @@ -107,7 +107,7 @@ async def arun(self) -> list[KnowledgeGraphOutput]:
)

# Process the output and return.
return self._process_output(response.json())
return [output.model_dump() for output in self._process_output(response.json())]

def create_query(
self, brain_regions_ids: set[str], mtype_ids: set[str] | None = None
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/tools/kg_morpho_features_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class KGMorphoFeatureTool(BaseTool):
input_schema: KGMorphoFeatureInput
metadata: KGMorphoFeatureMetadata

async def arun(self) -> list[KGMorphoFeatureOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""Run the tool async.
Returns
Expand Down Expand Up @@ -216,7 +216,7 @@ async def arun(self) -> list[KGMorphoFeatureOutput]:
json=entire_query,
)

return self._process_output(response.json())
return [output.model_dump() for output in self._process_output(response.json())]

def create_query(
self, brain_regions_ids: set[str], features: KGFeatureInput
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/tools/literature_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class LiteratureSearchTool(BaseTool):
input_schema: LiteratureSearchInput
metadata: LiteratureSearchMetadata

async def arun(self) -> list[ParagraphMetadata]:
async def arun(self) -> list[dict[str, Any]]:
"""Async search the scientific literature and returns citations.
Returns
Expand All @@ -88,7 +88,7 @@ async def arun(self) -> list[ParagraphMetadata]:
timeout=None,
)

return self._process_output(response.json())
return [output.model_dump() for output in self._process_output(response.json())]

@staticmethod
def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]:
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/tools/morphology_features_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MorphologyFeatureTool(BaseTool):
input_schema: MorphologyFeatureInput
metadata: MorphologyFeatureMetadata

async def arun(self) -> list[MorphologyFeatureOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""Give features about morphology."""
logger.info(
f"Entering morphology feature tool. Inputs: {self.input_schema.morphology_id=}"
Expand All @@ -71,7 +71,7 @@ async def arun(self) -> list[MorphologyFeatureOutput]:
return [
MorphologyFeatureOutput(
brain_region=metadata.brain_region, feature_dict=features
)
).model_dump()
]

def get_features(self, morphology_content: bytes, reader: str) -> dict[str, Any]:
Expand Down
6 changes: 3 additions & 3 deletions swarm_copy/tools/resolve_entities_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tool to resolve the brain region from natural english to a KG ID."""

import logging
from typing import ClassVar
from typing import Any, ClassVar

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -86,7 +86,7 @@ class ResolveEntitiesTool(BaseTool):

async def arun(
self,
) -> list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput]:
) -> list[dict[str, Any]]:
"""Given a brain region in natural language, resolve its ID."""
logger.info(
f"Entering Brain Region resolver tool. Inputs: {self.input_schema.brain_region=}, "
Expand Down Expand Up @@ -141,4 +141,4 @@ async def arun(
)
)

return output
return [out.model_dump() for out in output]
4 changes: 2 additions & 2 deletions swarm_copy/tools/traces_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class GetTracesTool(BaseTool):
input_schema: GetTracesInput
metadata: GetTracesMetadata

async def arun(self) -> list[TracesOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""From a brain region ID, extract traces."""
logger.info(
f"Entering get trace tool. Inputs: {self.input_schema.brain_region_id=}, {self.input_schema.etype_id=}"
Expand All @@ -93,7 +93,7 @@ async def arun(self) -> list[TracesOutput]:
headers={"Authorization": f"Bearer {self.metadata.token}"},
json=entire_query,
)
return self._process_output(response.json())
return [output.model_dump() for output in self._process_output(response.json())]

def create_query(
self,
Expand Down

0 comments on commit 341c361

Please sign in to comment.