From b5fbdd9e7508c4bcacc0481fb4fb602a5f345baa Mon Sep 17 00:00:00 2001 From: Nicolas Frank <58003267+WonderPG@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:06:02 +0100 Subject: [PATCH 1/2] Bluenaas tools (#50) Co-authored-by: Nicolas Frank --- CHANGELOG.md | 1 + swarm_copy/README.md | 5 + swarm_copy/app/config.py | 2 +- swarm_copy/app/dependencies.py | 19 +- swarm_copy/app/routers/qa.py | 19 +- swarm_copy/bluenaas_models.py | 307 ++++++++++++++++++++ swarm_copy/tools/__init__.py | 12 +- swarm_copy/tools/base_tool.py | 21 +- swarm_copy/tools/bluenaas_memodel_getall.py | 67 +++++ swarm_copy/tools/bluenaas_memodel_getone.py | 52 ++++ swarm_copy/tools/bluenaas_scs_getall.py | 66 +++++ swarm_copy/tools/bluenaas_scs_getone.py | 53 ++++ swarm_copy/tools/bluenaas_scs_post.py | 175 +++++++++++ swarm_copy/tools/get_me_model_tool.py | 223 -------------- 14 files changed, 779 insertions(+), 243 deletions(-) create mode 100644 swarm_copy/README.md create mode 100644 swarm_copy/bluenaas_models.py create mode 100644 swarm_copy/tools/bluenaas_memodel_getall.py create mode 100644 swarm_copy/tools/bluenaas_memodel_getone.py create mode 100644 swarm_copy/tools/bluenaas_scs_getall.py create mode 100644 swarm_copy/tools/bluenaas_scs_getone.py create mode 100644 swarm_copy/tools/bluenaas_scs_post.py delete mode 100644 swarm_copy/tools/get_me_model_tool.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 687691c..f9fd72f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Integrated Alembic for managing chat history migrations - Tool implementations without langchain or langgraph dependencies - CRUDs. +- BlueNaas CRUD tools ## [0.3.3] - 30.10.2024 diff --git a/swarm_copy/README.md b/swarm_copy/README.md new file mode 100644 index 0000000..b7bdaa2 --- /dev/null +++ b/swarm_copy/README.md @@ -0,0 +1,5 @@ +To generate the pydantic models from an API openapi.json spec, run the following command: +```bash +pip install datamodel-code-generator +datamodel-codegen --enum-field-as-literal=all --target-python-version=3.10 --use-annotated --reuse-model --input-file-type=openapi --url=TARGET_URL/openapi.json --output=OUTPUT --output-model-type=pydantic_v2.BaseModel +``` diff --git a/swarm_copy/app/config.py b/swarm_copy/app/config.py index 0ea0450..f0334e9 100644 --- a/swarm_copy/app/config.py +++ b/swarm_copy/app/config.py @@ -115,7 +115,7 @@ class SettingsGetMEModel(BaseModel): class SettingsBlueNaaS(BaseModel): """BlueNaaS settings.""" - url: str = "https://openbluebrain.com/api/bluenaas/simulation/single-neuron/run" + url: str = "https://openbluebrain.com/api/bluenaas" model_config = ConfigDict(frozen=True) diff --git a/swarm_copy/app/dependencies.py b/swarm_copy/app/dependencies.py index ee27367..f021c3e 100644 --- a/swarm_copy/app/dependencies.py +++ b/swarm_copy/app/dependencies.py @@ -21,13 +21,17 @@ from swarm_copy.run import AgentsRoutine from swarm_copy.tools import ( ElectrophysFeatureTool, - GetMEModelTool, GetMorphoTool, GetTracesTool, KGMorphoFeatureTool, LiteratureSearchTool, + MEModelGetAllTool, + MEModelGetOneTool, MorphologyFeatureTool, ResolveEntitiesTool, + SCSGetAllTool, + SCSGetOneTool, + SCSPostTool, ) from swarm_copy.utils import RegionMeta, get_file_from_KG @@ -186,8 +190,8 @@ async def get_vlab_and_project( } elif not settings.keycloak.validate_token: vlab_and_project = { - "vlab_id": "430108e9-a81d-4b13-b7b6-afca00195908", - "project_id": "eff09ea1-be16-47f0-91b6-52a3ea3ee575", + "vlab_id": "32c83739-f39c-49d1-833f-58c981ebd2a2", + "project_id": "123251a1-be18-4146-87b5-5ca2f8bfaf48", } else: thread_id = request.path_params.get("thread_id") @@ -237,9 +241,13 @@ def get_starting_agent( You must always specify in your answers from which brain regions the information is extracted. Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""", tools=[ + SCSGetAllTool, + SCSGetOneTool, + SCSPostTool, + MEModelGetAllTool, + MEModelGetOneTool, LiteratureSearchTool, ElectrophysFeatureTool, - GetMEModelTool, GetMorphoTool, KGMorphoFeatureTool, MorphologyFeatureTool, @@ -261,6 +269,8 @@ def get_context_variables( return { "starting_agent": starting_agent, "token": token, + "vlab_id": "32c83739-f39c-49d1-833f-58c981ebd2a2", # New god account vlab. Replaced by actual id in endpoint for now. Meant for usage without history + "project_id": "123251a1-be18-4146-87b5-5ca2f8bfaf48", # New god account proj. Replaced by actual id in endpoint for now. Meant for usage without history "retriever_k": settings.tools.literature.retriever_k, "reranker_k": settings.tools.literature.reranker_k, "use_reranker": settings.tools.literature.use_reranker, @@ -274,6 +284,7 @@ def get_context_variables( "trace_search_size": settings.tools.trace.search_size, "kg_sparql_url": settings.knowledge_graph.sparql_url, "kg_class_view_url": settings.knowledge_graph.class_view_url, + "bluenaas_url": settings.tools.bluenaas.url, "httpx_client": httpx_client, } diff --git a/swarm_copy/app/routers/qa.py b/swarm_copy/app/routers/qa.py index c2b1245..f7b2d3a 100644 --- a/swarm_copy/app/routers/qa.py +++ b/swarm_copy/app/routers/qa.py @@ -7,7 +7,8 @@ from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession -from swarm_copy.app.database.db_utils import get_history, save_history +from swarm_copy.app.database.db_utils import get_history, get_thread, save_history +from swarm_copy.app.database.sql_schemas import Threads from swarm_copy.app.dependencies import ( get_agents_routine, get_context_variables, @@ -46,17 +47,21 @@ async def run_chat_agent( context_variables: Annotated[dict[str, Any], Depends(get_context_variables)], session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], - thread_id: str, + thread: Annotated[Threads, Depends(get_thread)], messages: Annotated[list[dict[str, Any]], Depends(get_history)], ) -> AgentResponse: """Run a single agent query.""" + # Temporary solution + context_variables["vlab_id"] = thread.vlab_id + context_variables["project_id"] = thread.project_id + messages.append({"role": "user", "content": user_request.query}) response = await agent_routine.arun(agent, messages, context_variables) await save_history( user_id=user_id, history=response.messages, offset=len(messages) - 1, - thread_id=thread_id, + thread_id=thread.thread_id, session=session, ) return AgentResponse(message=response.messages[-1]["content"]) @@ -70,10 +75,14 @@ async def stream_chat_agent( context_variables: Annotated[dict[str, Any], Depends(get_context_variables)], session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], - thread_id: str, + thread: Annotated[Threads, Depends(get_thread)], messages: Annotated[list[dict[str, Any]], Depends(get_history)], ) -> StreamingResponse: """Run a single agent query in a streamed fashion.""" + # Temporary solution + context_variables["vlab_id"] = thread.vlab_id + context_variables["project_id"] = thread.project_id + messages.append({"role": "user", "content": user_request.query}) stream_generator = stream_agent_response( agents_routine, @@ -81,7 +90,7 @@ async def stream_chat_agent( messages, context_variables, user_id, - thread_id, + thread.thread_id, session, ) return StreamingResponse(stream_generator, media_type="text/event-stream") diff --git a/swarm_copy/bluenaas_models.py b/swarm_copy/bluenaas_models.py new file mode 100644 index 0000000..9b48b25 --- /dev/null +++ b/swarm_copy/bluenaas_models.py @@ -0,0 +1,307 @@ +# generated by datamodel-codegen: +# filename: https://openbluebrain.com/api/bluenaas/openapi.json +# timestamp: 2024-11-13T15:10:19+00:00 +"""Pydantic models of the BlueNaaS API.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, RootModel + + +class BodyPlaceSynapsesApiBluenaasValidationSynapseFormulaPost(BaseModel): + """Placeholder.""" + + formula: Annotated[str, Field(title="Formula")] + + +class BrainRegion(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + label: Annotated[str, Field(title="Label")] + + +class DeprecateNexusResponse(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + deprecated: Annotated[bool, Field(title="Deprecated")] + updated_at: Annotated[datetime, Field(title="Updated At")] + + +class ExclusionRule(BaseModel): + """Placeholder.""" + + distance_soma_gte: Annotated[Optional[float], Field(title="Distance Soma Gte")] = ( + None + ) + distance_soma_lte: Annotated[Optional[float], Field(title="Distance Soma Lte")] = ( + None + ) + + +class ExperimentSetupConfig(BaseModel): + """Placeholder.""" + + celsius: Annotated[float, Field(title="Celsius")] + vinit: Annotated[float, Field(title="Vinit")] + hypamp: Annotated[float, Field(title="Hypamp")] + max_time: Annotated[float, Field(le=3000.0, title="Max Time")] + time_step: Annotated[float, Field(title="Time Step")] + seed: Annotated[int, Field(title="Seed")] + + +class RecordingLocation(BaseModel): + """Placeholder.""" + + section: Annotated[str, Field(title="Section")] + offset: Annotated[float, Field(ge=0.0, le=1.0, title="Offset")] + + +class SectionTarget(RootModel[Literal["apic", "basal", "dend", "soma", "axon"]]): + """Placeholder.""" + + root: Annotated[ + Literal["apic", "basal", "dend", "soma", "axon"], Field(title="SectionTarget") + ] + + +class SimulationStimulusConfig(BaseModel): + """Placeholder.""" + + stimulus_type: Annotated[ + Literal["current_clamp", "voltage_clamp", "conductance"], + Field(title="Stimulus Type"), + ] + stimulus_protocol: Annotated[ + Optional[Literal["ap_waveform", "idrest", "iv", "fire_pattern"]], + Field(title="Stimulus Protocol"), + ] = None + amplitudes: Annotated[Union[List[float], float], Field(title="Amplitudes")] + + +class StimulationItemResponse(BaseModel): + """Placeholder.""" + + x: Annotated[List[float], Field(title="X")] + y: Annotated[List[float], Field(title="Y")] + name: Annotated[str, Field(title="Name")] + amplitude: Annotated[float, Field(title="Amplitude")] + + +class StimulationPlotConfig(BaseModel): + """Placeholder.""" + + stimulus_protocol: Annotated[ + Optional[Literal["ap_waveform", "idrest", "iv", "fire_pattern"]], + Field(title="Stimulus Protocol"), + ] = None + amplitudes: Annotated[List[float], Field(title="Amplitudes")] + + +class SynapseConfig(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + target: Optional[SectionTarget] = None + type: Annotated[int, Field(title="Type")] + distribution: Annotated[ + Literal["exponential", "linear", "formula"], Field(title="Distribution") + ] + formula: Annotated[Optional[str], Field(title="Formula")] = None + soma_synapse_count: Annotated[Optional[int], Field(title="Soma Synapse Count")] = ( + None + ) + seed: Annotated[int, Field(title="Seed")] + exclusion_rules: Annotated[ + Optional[List[ExclusionRule]], Field(title="Exclusion Rules") + ] = None + + +class SynapsePlacementBody(BaseModel): + """Placeholder.""" + + seed: Annotated[int, Field(title="Seed")] + config: SynapseConfig + + +class SynapsePosition(BaseModel): + """Placeholder.""" + + segment_id: Annotated[int, Field(title="Segment Id")] + coordinates: Annotated[List[float], Field(title="Coordinates")] + position: Annotated[float, Field(title="Position")] + + +class Frequency(RootModel[float]): + """Placeholder.""" + + root: Annotated[float, Field(gt=0.0, title="Frequency")] + + +class FrequencyItem(RootModel[float]): + """Placeholder.""" + + root: Annotated[float, Field(gt=0.0)] + + +class SynapseSimulationConfig(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + delay: Annotated[int, Field(title="Delay")] + duration: Annotated[int, Field(le=3000, title="Duration")] + frequency: Annotated[ + Union[Frequency, List[FrequencyItem]], Field(title="Frequency") + ] + weight_scalar: Annotated[float, Field(gt=0.0, title="Weight Scalar")] + + +class UsedModel(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + type: Annotated[ + Literal["me-model", "synaptome", "m-model", "e-model"], Field(title="Type") + ] + name: Annotated[str, Field(title="Name")] + + +class ValidationError(BaseModel): + """Placeholder.""" + + loc: Annotated[List[Union[str, int]], Field(title="Location")] + msg: Annotated[str, Field(title="Message")] + type: Annotated[str, Field(title="Error Type")] + + +class CurrentInjectionConfig(BaseModel): + """Placeholder.""" + + inject_to: Annotated[str, Field(title="Inject To")] + stimulus: SimulationStimulusConfig + + +class HTTPValidationError(BaseModel): + """Placeholder.""" + + detail: Annotated[Optional[List[ValidationError]], Field(title="Detail")] = None + + +class MEModelResponse(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + description: Annotated[Optional[str], Field(title="Description")] = None + type: Annotated[ + Literal["me-model", "synaptome", "m-model", "e-model"], Field(title="Type") + ] + created_by: Annotated[str, Field(title="Created By")] + created_at: Annotated[datetime, Field(title="Created At")] + brain_region: BrainRegion + m_model: UsedModel + e_model: UsedModel + + +class SectionSynapses(BaseModel): + """Placeholder.""" + + section_id: Annotated[str, Field(title="Section Id")] + synapses: Annotated[List[SynapsePosition], Field(title="Synapses")] + + +class SingleNeuronSimulationConfigInput(BaseModel): + """Placeholder.""" + + synaptome: Annotated[ + Optional[List[SynapseSimulationConfig]], Field(title="Synaptome") + ] = None + current_injection: CurrentInjectionConfig + record_from: Annotated[List[RecordingLocation], Field(title="Record From")] + conditions: ExperimentSetupConfig + type: Annotated[ + Literal["single-neuron-simulation", "synaptome-simulation"], Field(title="Type") + ] + duration: Annotated[int, Field(title="Duration")] + + +class SingleNeuronSimulationConfigOutput(SingleNeuronSimulationConfigInput): + """Placeholder.""" + + pass + + +class SynapsePlacementResponse(BaseModel): + """Placeholder.""" + + synapses: Annotated[List[SectionSynapses], Field(title="Synapses")] + + +class SynaptomeModelResponse(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + description: Annotated[Optional[str], Field(title="Description")] = None + type: Annotated[ + Literal["me-model", "synaptome", "m-model", "e-model"], Field(title="Type") + ] + created_by: Annotated[str, Field(title="Created By")] + created_at: Annotated[datetime, Field(title="Created At")] + brain_region: BrainRegion + me_model: UsedModel + synapses: Annotated[List[SynapseConfig], Field(title="Synapses")] + + +class PaginatedResponseUnionMEModelResponseSynaptomeModelResponse(BaseModel): + """Placeholder.""" + + offset: Annotated[int, Field(title="Offset")] + page_size: Annotated[int, Field(title="Page Size")] + total: Annotated[int, Field(title="Total")] + results: Annotated[ + List[Union[MEModelResponse, SynaptomeModelResponse]], Field(title="Results") + ] + + +class SimulationDetailsResponse(BaseModel): + """Placeholder.""" + + id: Annotated[str, Field(title="Id")] + status: Annotated[ + Optional[Literal["pending", "started", "success", "failure"]], + Field(title="Status"), + ] = None + results: Annotated[Optional[Dict[str, Any]], Field(title="Results")] = None + error: Annotated[Optional[str], Field(title="Error")] = None + type: Annotated[ + Literal["single-neuron-simulation", "synaptome-simulation"], Field(title="Type") + ] + name: Annotated[str, Field(title="Name")] + description: Annotated[str, Field(title="Description")] + created_by: Annotated[str, Field(title="Created By")] + created_at: Annotated[datetime, Field(title="Created At")] + injection_location: Annotated[str, Field(title="Injection Location")] + recording_location: Annotated[ + Union[List[str], str], Field(title="Recording Location") + ] + brain_region: BrainRegion + config: Optional[SingleNeuronSimulationConfigOutput] = None + me_model_id: Annotated[str, Field(title="Me Model Id")] + synaptome_model_id: Annotated[Optional[str], Field(title="Synaptome Model Id")] = ( + None + ) + + +class PaginatedResponseSimulationDetailsResponse(BaseModel): + """Placeholder.""" + + offset: Annotated[int, Field(title="Offset")] + page_size: Annotated[int, Field(title="Page Size")] + total: Annotated[int, Field(title="Total")] + results: Annotated[List[SimulationDetailsResponse], Field(title="Results")] diff --git a/swarm_copy/tools/__init__.py b/swarm_copy/tools/__init__.py index e51fbcb..8a3365d 100644 --- a/swarm_copy/tools/__init__.py +++ b/swarm_copy/tools/__init__.py @@ -1,7 +1,11 @@ """Tools package.""" +from swarm_copy.tools.bluenaas_memodel_getall import MEModelGetAllTool +from swarm_copy.tools.bluenaas_memodel_getone import MEModelGetOneTool +from swarm_copy.tools.bluenaas_scs_getall import SCSGetAllTool +from swarm_copy.tools.bluenaas_scs_getone import SCSGetOneTool +from swarm_copy.tools.bluenaas_scs_post import SCSPostTool from swarm_copy.tools.electrophys_tool import ElectrophysFeatureTool, FeatureOutput -from swarm_copy.tools.get_me_model_tool import GetMEModelTool from swarm_copy.tools.get_morpho_tool import GetMorphoTool, KnowledgeGraphOutput from swarm_copy.tools.kg_morpho_features_tool import ( KGMorphoFeatureOutput, @@ -22,6 +26,9 @@ from swarm_copy.tools.traces_tool import GetTracesTool, TracesOutput __all__ = [ + "SCSGetAllTool", + "SCSGetOneTool", + "SCSPostTool", "BRResolveOutput", "ElectrophysFeatureTool", "FeatureOutput", @@ -31,10 +38,11 @@ "KGMorphoFeatureTool", "KnowledgeGraphOutput", "LiteratureSearchTool", + "MEModelGetAllTool", + "MEModelGetOneTool", "MorphologyFeatureOutput", "MorphologyFeatureTool", "ParagraphMetadata", "ResolveEntitiesTool", "TracesOutput", - "GetMEModelTool", ] diff --git a/swarm_copy/tools/base_tool.py b/swarm_copy/tools/base_tool.py index 71af055..e366ab6 100644 --- a/swarm_copy/tools/base_tool.py +++ b/swarm_copy/tools/base_tool.py @@ -5,8 +5,6 @@ from typing import Any, ClassVar, Literal from httpx import AsyncClient -from openai.lib._tools import pydantic_function_tool -from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel, ConfigDict logger = logging.getLogger(__name__) @@ -72,13 +70,20 @@ class BaseTool(BaseModel, ABC): input_schema: BaseModel @classmethod - def pydantic_to_openai_schema(cls) -> ChatCompletionToolParam: + def pydantic_to_openai_schema(cls) -> dict[str, Any]: """Convert pydantic schema to OpenAI json.""" - return pydantic_function_tool( - model=cls.__annotations__["input_schema"], - name=cls.name, - description=cls.description, - ) + new_retval: dict[str, Any] = { + "type": "function", + "function": { + "name": cls.name, + "description": cls.description, + "strict": False, + "parameters": cls.__annotations__["input_schema"].model_json_schema(), + }, + } + new_retval["function"]["parameters"]["additionalProperties"] = False + + return new_retval @abstractmethod async def arun(self) -> Any: diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py new file mode 100644 index 0000000..8bda00e --- /dev/null +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -0,0 +1,67 @@ +"""BlueNaaS single cell stimulation, simulation and synapse placement tool.""" + +import logging +from typing import ClassVar, Literal + +from pydantic import BaseModel, Field + +from swarm_copy.bluenaas_models import ( + PaginatedResponseUnionMEModelResponseSynaptomeModelResponse, +) +from swarm_copy.tools.base_tool import BaseMetadata, BaseTool + +logger = logging.getLogger(__name__) + + +class MEModelGetAllMetadata(BaseMetadata): + """Metadata class for the get all me models api.""" + + token: str + vlab_id: str + project_id: str + bluenaas_url: str + + +class InputMEModelGetAll(BaseModel): + """Inputs for the BlueNaaS single-neuron simulation.""" + + offset: int = Field(default=0, description="Pagination offset") + page_size: int = Field( + default=20, description="Number of results returned by the API." + ) + model_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( + default="single-neuron-simulation", + description="Type of simulation to retrieve.", + ) + + +class MEModelGetAllTool(BaseTool): + """Class defining the MEModelGetAll tool.""" + + name: ClassVar[str] = "memodelgetall-tool" + description: ClassVar[str] = """Get multiple me models from the user. + Returns `page_size` ME-models that belong to the user's project. + If the user requests an ME-model with specific criteria, use this tool + to retrieve multiple of its ME-models and chose yourself the one(s) that fit the user's request.""" + metadata: MEModelGetAllMetadata + input_schema: InputMEModelGetAll + + async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelResponse: + """Run the MEModelGetAll tool.""" + logger.info( + f"Running MEModelGetAll tool with inputs {self.input_schema.model_dump()}" + ) + + response = await self.metadata.httpx_client.get( + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/me-models", + params={ + "simulation_type": self.input_schema.model_type, + "offset": self.input_schema.offset, + "page_size": self.input_schema.page_size, + }, + headers={"Authorization": f"Bearer {self.metadata.token}"}, + ) + breakpoint() + return PaginatedResponseUnionMEModelResponseSynaptomeModelResponse( + **response.json() + ) diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py new file mode 100644 index 0000000..4f4a3b3 --- /dev/null +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -0,0 +1,52 @@ +"""BlueNaaS single cell stimulation, simulation and synapse placement tool.""" + +import logging +from typing import ClassVar +from urllib.parse import quote_plus + +from pydantic import BaseModel, Field + +from swarm_copy.bluenaas_models import MEModelResponse +from swarm_copy.tools.base_tool import BaseMetadata, BaseTool + +logger = logging.getLogger(__name__) + + +class MEModelGetOneMetadata(BaseMetadata): + """Metadata class for the get one me models api.""" + + token: str + vlab_id: str + project_id: str + bluenaas_url: str + + +class InputMEModelGetOne(BaseModel): + """Inputs for the BlueNaaS single-neuron simulation.""" + + model_id: str = Field( + description="ID of the model to retrieve. Should be an https link." + ) + + +class MEModelGetOneTool(BaseTool): + """Class defining the MEModelGetOne tool.""" + + name: ClassVar[str] = "memodelgetone-tool" + description: ClassVar[str] = """Get one specific me model from a user. + The id can be retrieved using the 'memodelgetall-tool' or directly specified by the user.""" + metadata: MEModelGetOneMetadata + input_schema: InputMEModelGetOne + + async def arun(self) -> MEModelResponse: + """Run the MEModelGetOne tool.""" + logger.info( + f"Running MEModelGetOne tool with inputs {self.input_schema.model_dump()}" + ) + + response = await self.metadata.httpx_client.get( + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.model_id)}", + headers={"Authorization": f"Bearer {self.metadata.token}"}, + ) + + return MEModelResponse(**response.json()) diff --git a/swarm_copy/tools/bluenaas_scs_getall.py b/swarm_copy/tools/bluenaas_scs_getall.py new file mode 100644 index 0000000..95897dc --- /dev/null +++ b/swarm_copy/tools/bluenaas_scs_getall.py @@ -0,0 +1,66 @@ +"""BlueNaaS single cell stimulation, simulation and synapse placement tool.""" + +import logging +from typing import ClassVar, Literal + +from pydantic import BaseModel, Field + +from swarm_copy.bluenaas_models import PaginatedResponseSimulationDetailsResponse +from swarm_copy.tools.base_tool import BaseMetadata, BaseTool + +logger = logging.getLogger(__name__) + + +class SCSGetAllMetadata(BaseMetadata): + """Metadata class for the get all simulations api.""" + + token: str + vlab_id: str + project_id: str + bluenaas_url: str + + +class InputSCSGetAll(BaseModel): + """Inputs for the BlueNaaS single-neuron simulation.""" + + offset: int = Field(default=0, description="Pagination offset") + page_size: int = Field( + default=20, description="Number of results returned by the API." + ) + simulation_type: Literal["single-neuron-simulation", "synaptome-simulation"] = ( + Field( + default="single-neuron-simulation", + description="Type of simulation to retrieve.", + ) + ) + + +class SCSGetAllTool(BaseTool): + """Class defining the SCSGetAll tool.""" + + name: ClassVar[str] = "scsgetall-tool" + description: ClassVar[ + str + ] = """Retrieve `page_size` simulations' metadata from a user's project. + If the user requests a simulation with specific criteria, use this tool + to retrieve multiple of its simulations and chose yourself the one(s) that fit the user's request.""" + metadata: SCSGetAllMetadata + input_schema: InputSCSGetAll + + async def arun(self) -> PaginatedResponseSimulationDetailsResponse: + """Run the SCSGetAll tool.""" + logger.info( + f"Running SCSGetAll tool with inputs {self.input_schema.model_dump()}" + ) + + response = await self.metadata.httpx_client.get( + url=f"{self.metadata.bluenaas_url}/simulation/single-neuron/{self.metadata.vlab_id}/{self.metadata.project_id}", + params={ + "simulation_type": self.input_schema.simulation_type, + "offset": self.input_schema.offset, + "page_size": self.input_schema.page_size, + }, + headers={"Authorization": f"Bearer {self.metadata.token}"}, + ) + + return PaginatedResponseSimulationDetailsResponse(**response.json()) diff --git a/swarm_copy/tools/bluenaas_scs_getone.py b/swarm_copy/tools/bluenaas_scs_getone.py new file mode 100644 index 0000000..4957be9 --- /dev/null +++ b/swarm_copy/tools/bluenaas_scs_getone.py @@ -0,0 +1,53 @@ +"""BlueNaaS single cell stimulation, simulation and synapse placement tool.""" + +import logging +from typing import ClassVar + +from pydantic import BaseModel, Field + +from swarm_copy.bluenaas_models import SimulationDetailsResponse +from swarm_copy.tools.base_tool import BaseMetadata, BaseTool + +logger = logging.getLogger(__name__) + + +class SCSGetOneMetadata(BaseMetadata): + """Metadata class for the get all simulations api.""" + + token: str + vlab_id: str + project_id: str + bluenaas_url: str + + +class InputSCSGetOne(BaseModel): + """Inputs for the BlueNaaS single-neuron simulation.""" + + simulation_id: str = Field( + description="ID of the simulation to retrieve. Should be an https link." + ) + + +class SCSGetOneTool(BaseTool): + """Class defining the SCSGetOne tool.""" + + name: ClassVar[str] = "scsgetone-tool" + description: ClassVar[ + str + ] = """Get one specific simulations from a user based on its id. + The id can be retrieved using the 'scsgetall-tool' or directly specified by the user.""" + metadata: SCSGetOneMetadata + input_schema: InputSCSGetOne + + async def arun(self) -> SimulationDetailsResponse: + """Run the SCSGetOne tool.""" + logger.info( + f"Running SCSGetOne tool with inputs {self.input_schema.model_dump()}" + ) + + response = await self.metadata.httpx_client.get( + url=f"{self.metadata.bluenaas_url}/simulation/single-neuron/{self.metadata.vlab_id}/{self.metadata.project_id}/{self.input_schema.simulation_id}", + headers={"Authorization": f"Bearer {self.metadata.token}"}, + ) + + return SimulationDetailsResponse(**response.json()) diff --git a/swarm_copy/tools/bluenaas_scs_post.py b/swarm_copy/tools/bluenaas_scs_post.py new file mode 100644 index 0000000..6c8e154 --- /dev/null +++ b/swarm_copy/tools/bluenaas_scs_post.py @@ -0,0 +1,175 @@ +"""BlueNaaS single cell stimulation, simulation and synapse placement tool.""" + +import logging +from typing import Any, ClassVar, Literal + +from pydantic import BaseModel, Field + +from swarm_copy.tools.base_tool import BaseMetadata, BaseTool + +logger = logging.getLogger(__name__) + + +class SCSPostMetadata(BaseMetadata): + """Metadata class for the get all simulations api.""" + + token: str + vlab_id: str + project_id: str + bluenaas_url: str + + +class RecordingLocation(BaseModel): + """Configuration for the recording location in the simulation.""" + + section: str = Field(default="soma[0]", description="Section to record from") + offset: float = Field( + default=0.5, ge=0, le=1, description="Offset in the section to record from" + ) + + +class InputSCSPost(BaseModel): + """Inputs for the BlueNaaS single-neuron simulation.""" + + me_model_id: str = Field( + description=( + "ID of the neuron model to be used in the simulation. The model ID can be" + " fetched using the 'memodelgetall-tool'." + ) + ) + current_injection__inject_to: str = Field( + default="soma[0]", description="Section to inject the current to." + ) + current_injection__stimulus__stimulus_type: Literal[ + "current_clamp", "voltage_clamp", "conductance" + ] = Field(default="current_clamp", description="Type of stimulus to be used.") + current_injection__stimulus__stimulus_protocol: Literal[ + "ap_waveform", "idrest", "iv", "fire_pattern" + ] = Field(default="ap_waveform", description="Stimulus protocol to be used.") + + current_injection__stimulus__amplitudes: list[float] = Field( + default=[0.1], + min_length=1, + description="List of amplitudes for the stimulus", + ) + record_from: list[RecordingLocation] = Field( + default=[RecordingLocation()], + description=( + "List of sections to record from during the simulation. Each record" + " configuration includes the section name and offset." + ), + ) + conditions__celsius: int = Field( + default=34, ge=0, le=50, description="Temperature in celsius" + ) + conditions__vinit: int = Field(default=-73, description="Initial voltage in mV") + conditions__hypamp: int = Field(default=0, description="Holding current in nA") + conditions__max_time: int = Field( + default=100, le=3000, description="Maximum simulation time in ms" + ) + conditions__time_step: float = Field( + default=0.05, ge=0.001, le=10, description="Time step in ms" + ) + conditions__seed: int = Field(default=100, description="Random seed") + + +class SCSPostOutput(BaseModel): + """Should return a successful POST request.""" + + id: str + name: str + status: Literal["success", "pending", "error"] + error: str | None + + +class SCSPostTool(BaseTool): + """Class defining the SCSPost tool.""" + + name: ClassVar[str] = "scspost-tool" + description: ClassVar[str] = """Runs a single-neuron simulation. + Requires a "me_model_id" which must be fetched through the 'memodelgetall-tool' or directly provided by the user. + Optionally, the user can specify simulation parameters. + Returns the id of the simulation along with metadatas to fetch the simulation result and analyse it at a later stage. + """ + metadata: SCSPostMetadata + input_schema: InputSCSPost + + async def arun(self) -> SCSPostOutput: + """Run the SCSPost tool.""" + logger.info( + f"Running SCSPost tool with inputs {self.input_schema.model_dump()}" + ) + + json_api = self.create_json_api( + current_injection__inject_to=self.input_schema.current_injection__inject_to, + current_injection__stimulus__stimulus_type=self.input_schema.current_injection__stimulus__stimulus_type, + current_injection__stimulus__stimulus_protocol=self.input_schema.current_injection__stimulus__stimulus_protocol, + current_injection__stimulus__amplitudes=self.input_schema.current_injection__stimulus__amplitudes, + record_from=self.input_schema.record_from, + conditions__celsius=self.input_schema.conditions__celsius, + conditions__vinit=self.input_schema.conditions__vinit, + conditions__hypamp=self.input_schema.conditions__hypamp, + conditions__max_time=self.input_schema.conditions__max_time, + conditions__time_step=self.input_schema.conditions__time_step, + conditions__seed=self.input_schema.conditions__seed, + ) + + response = await self.metadata.httpx_client.post( + url=f"{self.metadata.bluenaas_url}/simulation/single-neuron/{self.metadata.vlab_id}/{self.metadata.project_id}/run", + params={"model_id": self.input_schema.me_model_id, "realtime": "False"}, + headers={"Authorization": f"Bearer {self.metadata.token}"}, + json=json_api, + ) + json_response = response.json() + return SCSPostOutput( + id=json_response["id"], + status=json_response["status"], + name=json_response["name"], + error=json_response["error"], + ) + + @staticmethod + def create_json_api( + current_injection__inject_to: str = "soma[0]", + current_injection__stimulus__stimulus_type: Literal[ + "current_clamp", "voltage_clamp", "conductance" + ] = "current_clamp", + current_injection__stimulus__stimulus_protocol: Literal[ + "ap_waveform", "idrest", "iv", "fire_pattern" + ] = "ap_waveform", + current_injection__stimulus__amplitudes: list[float] | None = None, + record_from: list[RecordingLocation] | None = None, + conditions__celsius: int = 34, + conditions__vinit: int = -73, + conditions__hypamp: int = 0, + conditions__max_time: int = 100, + conditions__time_step: float = 0.05, + conditions__seed: int = 100, + ) -> dict[str, Any]: + """Based on the simulation config, create a valid JSON for the API.""" + if not current_injection__stimulus__amplitudes: + current_injection__stimulus__amplitudes = [0.1] + if not record_from: + record_from = [RecordingLocation()] + json_api = { + "current_injection": { + "inject_to": current_injection__inject_to, + "stimulus": { + "stimulus_type": current_injection__stimulus__stimulus_type, + "stimulus_protocol": current_injection__stimulus__stimulus_protocol, + "amplitudes": current_injection__stimulus__amplitudes, + }, + }, + "record_from": [recording.model_dump() for recording in record_from], + "conditions": { + "celsius": conditions__celsius, + "vinit": conditions__vinit, + "hypamp": conditions__hypamp, + "max_time": conditions__max_time, + "time_step": conditions__time_step, + "seed": conditions__seed, + }, + "type": "single-neuron-simulation", + "duration": conditions__max_time, + } + return json_api diff --git a/swarm_copy/tools/get_me_model_tool.py b/swarm_copy/tools/get_me_model_tool.py deleted file mode 100644 index 7c47b9a..0000000 --- a/swarm_copy/tools/get_me_model_tool.py +++ /dev/null @@ -1,223 +0,0 @@ -"""Module defining the Get ME Model tool.""" - -import logging -from typing import Any, ClassVar - -from pydantic import BaseModel, Field - -from swarm_copy.cell_types import get_celltypes_descendants -from swarm_copy.tools.base_tool import BaseMetadata, BaseTool -from swarm_copy.utils import get_descendants_id - -logger = logging.getLogger(__name__) - - -class GetMEModelInput(BaseModel): - """Inputs of the knowledge graph API.""" - - brain_region_id: str = Field( - description="ID of the brain region of interest. To get this ID, please use the `resolve-entities-tool` first." - ) - mtype_id: str | None = Field( - default=None, - description="ID of the M-type of interest. To get this ID, please use the `resolve-entities-tool` first.", - ) - etype_id: str | None = Field( - default=None, - description="ID of the electrical type of the cell. Can be obtained through the 'resolve-entities-tool'.", - ) - - -class GetMEModelMetadata(BaseMetadata): - """Metadata class for GetMEModelTool.""" - - knowledge_graph_url: str - token: str - me_model_search_size: int - brainregion_path: str - celltypes_path: str - - -class MEModelOutput(BaseModel): - """Output schema for the knowledge graph API.""" - - me_model_id: str - me_model_name: str | None - me_model_description: str | None - mtype: str | None - etype: str | None - - brain_region_id: str - brain_region_label: str | None - - subject_species_label: str | None - subject_age: str | None - - -class GetMEModelTool(BaseTool): - """Class defining the Get ME Model logic.""" - - name: ClassVar[str] = "get-me-model-tool" - description: ClassVar[ - str - ] = """Searches a neuroscience based knowledge graph to retrieve neuron morpho-electric model (ME models) names, IDs and descriptions. - Requires a 'brain_region_id' which is the ID of the brain region of interest as registered in the knowledge graph. - Optionally accepts an mtype_id and/or an etype_id. - The output is a list of ME models, containing: - - The brain region ID. - - The brain region name. - - The subject species name. - - The subject age. - - The model ID. - - The model name. - - The model description. - The model ID is in the form of an HTTP(S) link such as 'https://bbp.epfl.ch/data/bbp/mmb-point-neuron-framework-model/...'.""" - input_schema: GetMEModelInput - metadata: GetMEModelMetadata - - async def arun(self) -> list[MEModelOutput]: - """From a brain region ID, extract ME models.""" - logger.info( - f"Entering Get ME Model tool. Inputs: {self.input_schema.brain_region_id=}, {self.input_schema.mtype_id=}, {self.input_schema.etype_id=}" - ) - # From the brain region ID, get the descendants. - hierarchy_ids = get_descendants_id( - self.input_schema.brain_region_id, - json_path=self.metadata.brainregion_path, - ) - logger.info(f"Found {len(list(hierarchy_ids))} children of the brain ontology.") - - if self.input_schema.mtype_id: - mtype_ids = get_celltypes_descendants( - self.input_schema.mtype_id, self.metadata.celltypes_path - ) - logger.info( - f"Found {len(list(mtype_ids))} children of the cell types ontology for mtype." - ) - else: - mtype_ids = None - - # Create the ES query to query the KG. - entire_query = self.create_query( - brain_regions_ids=hierarchy_ids, - mtype_ids=mtype_ids, - etype_id=self.input_schema.etype_id, - ) - - # Send the query to get ME models. - response = await self.metadata.httpx_client.post( - url=self.metadata.knowledge_graph_url, - headers={"Authorization": f"Bearer {self.metadata.token}"}, - json=entire_query, - ) - - # Process the output and return. - return self._process_output(response.json()) - - def create_query( - self, - brain_regions_ids: set[str], - mtype_ids: set[str] | None = None, - etype_id: str | None = None, - ) -> dict[str, Any]: - """Create ES query out of the BR, mtype, and etype IDs. - - Parameters - ---------- - brain_regions_ids - IDs of the brain region of interest (of the form http://api.brain-map.org/api/v2/data/Structure/...) - mtype_id - ID of the mtype of the model - etype_id - ID of the etype of the model - - Returns - ------- - dict containing the elasticsearch query to send to the KG. - """ - # At least one of the children brain region should match. - conditions = [ - { - "bool": { - "should": [ - {"term": {"brainRegion.@id.keyword": hierarchy_id}} - for hierarchy_id in brain_regions_ids - ] - } - }, - {"term": {"@type.keyword": "https://neuroshapes.org/MEModel"}}, - {"term": {"deprecated": False}}, - ] - - if mtype_ids: - # The correct mtype should match. For now - # It is a one term should condition, but eventually - # we will resolve the subclasses of the mtypes. - # They will all be appended here. - conditions.append( - { - "bool": { - "should": [ - {"term": {"mType.@id.keyword": mtype_id}} - for mtype_id in mtype_ids - ] - } - } - ) - - if etype_id: - # The correct etype should match. - conditions.append({"term": {"eType.@id.keyword": etype_id}}) - - # Assemble the query to return ME models. - entire_query = { - "size": self.metadata.me_model_search_size, - "track_total_hits": True, - "query": {"bool": {"must": conditions}}, - } - return entire_query - - @staticmethod - def _process_output(output: Any) -> list[MEModelOutput]: - """Process output to fit the MEModelOutput pydantic class defined above. - - Parameters - ---------- - output - Raw output of the _arun method, which comes from the KG - - Returns - ------- - list of MEModelOutput to describe the model and its metadata. - """ - formatted_output = [ - MEModelOutput( - me_model_id=res["_source"]["_self"], - me_model_name=res["_source"].get("name"), - me_model_description=res["_source"].get("description"), - mtype=( - res["_source"]["mType"].get("label") - if "mType" in res["_source"] - else None - ), - etype=( - res["_source"]["eType"].get("label") - if "eType" in res["_source"] - else None - ), - brain_region_id=res["_source"]["brainRegion"]["@id"], - brain_region_label=res["_source"]["brainRegion"].get("label"), - subject_species_label=( - res["_source"]["subjectSpecies"].get("label") - if "subjectSpecies" in res["_source"] - else None - ), - subject_age=( - res["_source"]["subjectAge"].get("label") - if "subjectAge" in res["_source"] - else None - ), - ) - for res in output["hits"]["hits"] - ] - return formatted_output From d5a2ccf40bc3a1d2fbbde25ef8af47cdf9e515d8 Mon Sep 17 00:00:00 2001 From: Kerem Kurban Date: Tue, 19 Nov 2024 16:58:14 +0100 Subject: [PATCH 2/2] Llm eval script (#46) * move tool validation to script and add tests * add optional and forbidden tools * update validation logic * replace with async validation * add semaphore for concurrent async req limit * add error reasoning --------- Co-authored-by: Kerem Kurban --- .gitignore | 5 +- CHANGELOG.md | 3 + .../scripts/avalidate_tool_calls.py | 258 +++++++++++++++++ tests/agents/test_tool_calls.py | 112 -------- tests/data/tool_calls.json | 187 ++++++------ tests/tools/test_validate_tool_call.py | 268 ++++++++++++++++++ tool_call_evaluation.csv | 7 - 7 files changed, 640 insertions(+), 200 deletions(-) create mode 100644 src/neuroagent/scripts/avalidate_tool_calls.py delete mode 100644 tests/agents/test_tool_calls.py create mode 100644 tests/tools/test_validate_tool_call.py delete mode 100644 tool_call_evaluation.csv diff --git a/.gitignore b/.gitignore index cabbcce..15096e7 100644 --- a/.gitignore +++ b/.gitignore @@ -156,4 +156,7 @@ static # database stuff *db *.db-shm -*.db-wal \ No newline at end of file +*.db-wal + +# ignore csvs in root dir +*.csv \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index f9fd72f..52955bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CRUDs. - BlueNaas CRUD tools +### Fixed +- Migrate LLM Evaluation logic to scripts and add tests + ## [0.3.3] - 30.10.2024 ### Changed diff --git a/src/neuroagent/scripts/avalidate_tool_calls.py b/src/neuroagent/scripts/avalidate_tool_calls.py new file mode 100644 index 0000000..44b8916 --- /dev/null +++ b/src/neuroagent/scripts/avalidate_tool_calls.py @@ -0,0 +1,258 @@ +"""Run validation on tool calls.""" + +import argparse +import asyncio +import json +import logging +from typing import Any + +import aiohttp +import pandas as pd + +logging.basicConfig( + level=logging.INFO, # Set the logging level + format="%(asctime)s - %(levelname)s - %(message)s", # Define the log message format +) + + +async def fetch_tool_call( + session: aiohttp.ClientSession, + query: dict[str, Any], + base_url: str, + semaphore: asyncio.Semaphore, +) -> dict[str, Any]: + """ + Fetch the tool call results for a given test case. + + This function sends an asynchronous POST request to the API with the provided + test case data and validates the tool calls against expected outcomes. + + Args: + ---- + session (aiohttp.ClientSession): The aiohttp session used to make the HTTP request. + query (dict): A dictionary containing the test case data, including the prompt, + expected tools, optional tools, and forbidden tools. + base_url (str): The base URL of the API. + + Returns + ------- + dict: A dictionary containing the prompt, actual tool calls, expected tool calls, + and whether the actual calls match the expected ones. + """ + async with semaphore: + prompt = query["prompt"] + expected_tool_calls = query["expected_tools"] + optional_tools = query["optional_tools"] + forbidden_tools = query["forbidden_tools"] + + logging.info(f"Testing prompt: {prompt}") + + # Send a request to the API + async with session.post( + f"{base_url}/qa/run", + headers={"Content-Type": "application/json"}, + json={ + "query": prompt, + }, + ) as response: + if response.status == 200: + steps = await response.json() + called_tool_names = [ + step.get("tool_name", None) for step in steps.get("steps", []) + ] + expected_tool_names = [ + tool_call.get("tool_name", None) + for tool_call in expected_tool_calls + ] + match, reason = validate_tool( + expected_tool_names, + called_tool_names, + optional_tools=optional_tools, + forbidden_tools=forbidden_tools, + ) + return { + "Prompt": prompt, + "Actual": called_tool_names, + "Expected": expected_tool_names, + "Optional": optional_tools, + "Forbidden": forbidden_tools, + "Match": "Yes" if match else "No", + "Reason": reason if not match else "N/A", + } + else: + # Attempt to parse the error message from the response content + try: + error_content = await response.json() + error_message = error_content.get("content", "Unknown error") + except Exception as e: + error_message = f"Failed to parse error message: {str(e)}" + + error_info = { + "status_code": response.status, + "response_content": error_message, + } + logging.error( + f"API call failed for prompt: {prompt} with error: {error_info}" + ) + return { + "Prompt": prompt, + "Actual": f"API call failed: {error_info}", + "Expected": expected_tool_calls, + "Optional": optional_tools, + "Forbidden": forbidden_tools, + "Match": "No", + "Reason": f"API call failed: {error_info}", + } + + +async def validate_tool_calls_async( + base_url: str, + data_file: str, + output_file: str = "tool_call_evaluation.csv", + max_concurrent_requests: int = 10, +) -> None: + """ + Run asynchronous tool call tests and save the results to a CSV file. + + Args: + ---- + base_url (str): The base URL of the API. + data_file (str): The path to the JSON file containing test case data. + output_file (str): The name of the output CSV file where the results will + be saved. Defaults to 'tool_call_evaluation.csv'. + max_concurrent_requests (int): Maximum number of concurrent API requests. + Defaults to 10. + + Returns + ------- + None: This function does not return any value. It writes the results to a + CSV file. + """ + with open(data_file) as f: + tool_calls_data = json.load(f) + + results_list = [] + semaphore = asyncio.Semaphore(max_concurrent_requests) + + async with aiohttp.ClientSession() as session: + tasks = [ + fetch_tool_call(session, query, base_url, semaphore) + for query in tool_calls_data + ] + results_list = await asyncio.gather(*tasks) + + results_df = pd.DataFrame(results_list) + results_df.to_csv(output_file, index=False) + + +def validate_tool( + required_tools: list[str], + actual_tool_calls: list[str], + optional_tools: list[str], + forbidden_tools: list[str], +) -> tuple[bool, str]: + """ + Validate the sequence of tool calls against required, optional, and forbidden tools. + + Args: + ---- + required_tools (List): A list of tools that must be called in the specified order. + actual_tool_calls (List): A list of tools that were actually called. + optional_tools (List): A list of tools that can be called but are not required. + forbidden_tools (List): A list of tools that must not be called. + + Returns + ------- + tuple: A tuple containing a boolean and a string message. The boolean is True if the + validation is successful, otherwise False. The string message provides details + about the validation result. + """ + # Check for forbidden tools + if inter := set(actual_tool_calls) & set(forbidden_tools): + return False, f"Forbidden tool(s) called: {inter}" + + # Validate required tools order + order = 0 + for tool in actual_tool_calls: + if order < len(required_tools) and tool == required_tools[order]: + order += 1 + elif tool in optional_tools or ( + order > 0 and tool == required_tools[order - 1] + ): + continue + elif tool not in required_tools[:order]: + return False, f"Unexpected tool called: {tool}" + + # Check if all required tools were called + if order != len(required_tools): + return False, "Not all required tools were called" + + return True, "All required tools called correctly" + + +def main() -> None: + """ + Execute the tool call validation process. + + This function sets up the argument parser to handle command-line arguments, + specifically for specifying the base URL, port, data file path, and output + CSV file name. It then calls the validate_tool_calls_async function with + the provided arguments to perform the validation of tool calls and save + the results. + + The function is designed to be the entry point when the script is run + directly from the command line. + """ + parser = argparse.ArgumentParser( + description="Run tool call tests and save results." + ) + parser.add_argument( + "--base_url", + type=str, + default="http://localhost", + help="Base URL for the API", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port number for the API", + ) + parser.add_argument( + "--data", + type=str, + default="tests/data/tool_calls.json", + help="Path to the JSON file containing test case data", + ) + parser.add_argument( + "--output", + type=str, + default="tool_call_evaluation.csv", + help="Output CSV file for results", + ) + args = parser.parse_args() + + # Construct the full base URL + full_url = f"{args.base_url}:{args.port}" + asyncio.run(validate_tool_calls_async(full_url, args.data, args.output)) + + +if __name__ == "__main__": + """ + Validate tool calls against expected outcomes and logs the results. + + The script reads a set of prompts and their expected tool calls, executes the tool calls, + and compares the actual tool calls made with the expected ones. It logs whether the actual + tool calls match the expected ones and saves the results to a CSV file. + + Usage: + python validate_tool_calls.py --output + + Arguments: + --output: The name of the output CSV file where the results will be saved. + Defaults to 'tool_call_evaluation.csv'. + + The script is intended to be run as a standalone module. + """ + + main() diff --git a/tests/agents/test_tool_calls.py b/tests/agents/test_tool_calls.py deleted file mode 100644 index dbe6b08..0000000 --- a/tests/agents/test_tool_calls.py +++ /dev/null @@ -1,112 +0,0 @@ -import argparse -import json - -import pandas as pd -import pytest -import requests -from tqdm import tqdm - -# Base URL for the local API -base_url = "http://localhost:8000" - - -def is_subsequence(expected, actual): - it = iter(actual) - # Check if all items in expected are in it and in the correct order - return all(item in it for item in expected) and all( - item in expected for item in actual - ) - - -@pytest.mark.skip(reason="Skipping this test by default unless provoked") -def test_tool_calls(output_file="tool_call_evaluation.csv"): - # Load the expected tool calls from the JSON file - with open("tests/data/tool_calls.json") as f: - tool_calls_data = json.load(f) - - # List to store results - results_list = [] - - # Iterate over each test case with a progress bar - for test_case in tqdm(tool_calls_data, desc="Processing test cases"): - prompt = test_case["prompt"] - expected_tool_calls = test_case["expected_tool_calls"] - - print(f"Testing prompt: {prompt}") # Verbose output - - # Send a request to the API - response = requests.post( - f"{base_url}/qa/run", # Replace with the actual endpoint - headers={ - "Content-Type": "application/json" - }, # Ensure the correct header is set - json={ - "query": prompt, # Add the 'query' field with the prompt as its value - "messages": [{"role": "user", "content": prompt}], - }, - ) - - # Check if the response is successful - if response.status_code == 200: - # Parse the response - steps = response.json().get("steps", []) - called_tool_names = [step.get("tool_name", None) for step in steps] - expected_tool_names = [ - tool_call.get("tool_name", None) for tool_call in expected_tool_calls - ] - match = is_subsequence(expected_tool_names, called_tool_names) - - # Append the result to the list - results_list.append( - { - "Prompt": prompt, - "Actual": called_tool_names, - "Expected": expected_tool_names, - "Match": "Yes" if match else "No", - } - ) - else: - # Log the response status code and content for debugging - error_info = { - "status_code": response.status_code, - "response_content": response.text, - } - - print( - f"API call failed for prompt: {prompt} with error: {error_info}" - ) # Verbose output - - # Handle the case where the API call fails - results_list.append( - { - "Prompt": prompt, - "Actual": f"API call failed: {error_info}", - "Expected": expected_tool_calls, - "Match": "No", - } - ) - - # Create a DataFrame from the results - results_df = pd.DataFrame(results_list) - - # Save the results to a CSV file - results_df.to_csv(output_file) - - -def main(): - parser = argparse.ArgumentParser( - description="Run tool call tests and save results." - ) - parser.add_argument( - "--output", - type=str, - default="tool_call_evaluation.csv", - help="Output CSV file for results", - ) - args = parser.parse_args() - - test_tool_calls(args.output) - - -if __name__ == "__main__": - main() diff --git a/tests/data/tool_calls.json b/tests/data/tool_calls.json index c8d9066..98b6aa3 100644 --- a/tests/data/tool_calls.json +++ b/tests/data/tool_calls.json @@ -1,102 +1,129 @@ [ - { - "prompt": "What are the morphological features of neurons in the thalamus?", - "expected_tool_calls": [ - { - "tool_name": "resolve-entities-tool", - "arguments": {"brain_region": "thalamus"} - }, - { - "tool_name": "get-morpho-tool", - "arguments": {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"} + { + "prompt": "What are the morphological features of neurons in the thalamus?", + "expected_tools": [ + { + "tool_name": "resolve-entities-tool", + "arguments": {"brain_region": "thalamus"} + }, + { + "tool_name": "get-morpho-tool", + "arguments": {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"} + } + ], + "optional_tools": ["literature-search-tool"], + "forbidden_tools": ["get-traces-tool", "electrophys-features-tool", "get-me-model-tool", "bluenaas-tool"] + }, + { + "prompt": "Find me articles about the role of the hippocampus in memory formation.", + "expected_tools": [ + { + "tool_name": "literature-search-tool", + "arguments": { + "query": "hippocampus memory formation" } - ] - }, - { - "prompt": "Find me articles about the role of the hippocampus in memory formation.", - "expected_tool_calls": [ - { - "tool_name": "literature-search-tool", + } + ], + "optional_tools": ["resolve-entities-tool"], + "forbidden_tools": ["get-morpho-tool", "get-traces-tool", "electrophys-features-tool", "get-me-model-tool", "bluenaas-tool"] + }, + { + "prompt": "Retrieve electrophysiological features of cortical neurons.", + "expected_tools": [ + { + "tool_name": "resolve-entities-tool", + "arguments": { + "brain_region": "cortex" + } + }, + { + "tool_name": "get-traces-tool", + "arguments": { + "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/134" + } + }, + { + "tool_name": "electrophys-features-tool", "arguments": { - "query": "hippocampus memory formation" + "brain_region": "cortex" } - } - ] - }, + } + ], + "optional_tools": ["literature-search-tool"], + "forbidden_tools": ["get-morpho-tool", "get-me-model-tool", "bluenaas-tool"] + }, { - "prompt": "Retrieve electrophysiological features of cortical neurons.", - "expected_tool_calls": [ - { + "prompt": "Get traces for neurons in the hippocampus.", + "expected_tools": [ + { "tool_name": "resolve-entities-tool", "arguments": { - "brain_region": "cortex" - } - }, - { + "brain_region": "hippocampus"} + }, + { "tool_name": "get-traces-tool", "arguments": { "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/134" } - }, - { - "tool_name": "electrophys-feature-tool", - "arguments": { - "brain_region": "cortex" - } - } - ] + }, + { + "tool_name": "electrophys-features-tool" + } + ], + "optional_tools": ["literature-search-tool"], + "forbidden_tools": ["get-morpho-tool", "get-me-model-tool", "bluenaas-tool"] }, { - "prompt": "Get traces for neurons in the hippocampus.", - "expected_tool_calls": [ - { - "tool_name": "resolve-entities-tool", - "arguments": { - "brain_region": "hippocampus"} - }, - { - "tool_name": "get-traces-tool", - "arguments": { - "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/134" - } - }, - { - "tool_name": "electrophys-features-tool" - } - ] + "prompt": "Get traces for neurons in the primary somatosensory area.", + "expected_tools": [ + { + "tool_name": "resolve-entities-tool", + "arguments": { + "brain_region": "primary somatosensory area"} + }, + { + "tool_name": "get-traces-tool" + } + ], + "optional_tools": ["literature-search-tool","electrophys-features-tool"], + "forbidden_tools": ["get-morpho-tool", "get-me-model-tool", "bluenaas-tool"] }, { - "prompt": "Search for literature on synaptic plasticity.", - "expected_tool_calls": [ - { - "tool_name": "literature-search-tool", - "arguments": { - "query": "synaptic plasticity" - } - } - ] + "prompt": "Search for literature on synaptic plasticity.", + "expected_tools": [ + { + "tool_name": "literature-search-tool", + "arguments": { + "query": "synaptic plasticity" + } + } + ], + "optional_tools": ["resolve-entities-tool"], + "forbidden_tools": ["get-morpho-tool", "get-traces-tool", "electrophys-features-tool", "get-me-model-tool", "bluenaas-tool"] }, { "prompt": "Run 1000 ms of simulation of a me model from somatosensory cortex with 34 degree temperature, current clamp stimulation mode with step current for fire pattern detection. use 1 number of step and 0.05 nA current stimulation. Record from soma.", - "expected_tool_calls": [ - { - "tool_name": "resolve-entities-tool", - "arguments": { - "brain_region": "somatosensory area" - } - }, - { - "tool_name": "literature-search-tool" - }, - { - "tool_name": "get-me-model-tool", + "expected_tools": [ + { + "tool_name": "resolve-entities-tool", "arguments": { - "brain_region_id" : "http://api.brain-map.org/api/v2/data/Structure/322" + "brain_region": "somatosensory area" } - }, - { - "tool_name": "bluenaas-tool" + }, + { + "tool_name": "literature-search-tool" + }, + { + "tool_name": "get-me-model-tool", + "arguments": { + "brain_region_id" : "http://api.brain-map.org/api/v2/data/Structure/322" } - ] -} + }, + { + "tool_name": "bluenaas-tool" + } + ], + "optional_tools": [], + "forbidden_tools": ["get-morpho-tool", "get-traces-tool", "electrophys-features-tool"] + } ] \ No newline at end of file diff --git a/tests/tools/test_validate_tool_call.py b/tests/tools/test_validate_tool_call.py new file mode 100644 index 0000000..4e34b95 --- /dev/null +++ b/tests/tools/test_validate_tool_call.py @@ -0,0 +1,268 @@ +import asyncio +import json +import unittest +from unittest.mock import AsyncMock, mock_open, patch + +import aiohttp +import pytest + +from src.neuroagent.scripts.avalidate_tool_calls import ( + fetch_tool_call, + validate_tool, + validate_tool_calls_async, +) + + +class TestValidateTool(unittest.TestCase): + def test_no_tools_called(self): + result, message = validate_tool( + required_tools=["tool1", "tool2"], + actual_tool_calls=[], + optional_tools=[], + forbidden_tools=[], + ) + self.assertFalse(result) + self.assertEqual(message, "Not all required tools were called") + + def test_all_required_tools_called_in_order(self): + result, message = validate_tool( + required_tools=["tool1", "tool2"], + actual_tool_calls=["tool1", "tool2"], + optional_tools=[], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_required_tools_called_out_of_order(self): + result, message = validate_tool( + required_tools=["tool1", "tool2"], + actual_tool_calls=["tool2", "tool1"], + optional_tools=[], + forbidden_tools=[], + ) + self.assertFalse(result) + + def test_forbidden_tool_called(self): + result, message = validate_tool( + required_tools=["tool1"], + actual_tool_calls=["tool1", "tool3"], + optional_tools=[], + forbidden_tools=["tool3"], + ) + self.assertFalse(result) + self.assertEqual(message, "Forbidden tool(s) called: {'tool3'}") + + def test_optional_tools_called(self): + result, message = validate_tool( + required_tools=["tool1"], + actual_tool_calls=["tool1", "tool2"], + optional_tools=["tool2"], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_unexpected_tool_called(self): + result, message = validate_tool( + required_tools=["tool1"], + actual_tool_calls=["tool1", "tool3"], + optional_tools=[], + forbidden_tools=[], + ) + self.assertFalse(result) + self.assertEqual(message, "Unexpected tool called: tool3") + + def test_all_required_tools_called_with_optional_and_forbidden(self): + result, message = validate_tool( + required_tools=["tool1", "tool2"], + actual_tool_calls=["tool1", "tool2", "tool3"], + optional_tools=["tool3"], + forbidden_tools=["tool4"], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_only_optional_tools_called(self): + result, message = validate_tool( + required_tools=[], + actual_tool_calls=["tool2"], + optional_tools=["tool2"], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_only_forbidden_tools_called(self): + result, message = validate_tool( + required_tools=[], + actual_tool_calls=["tool3"], + optional_tools=[], + forbidden_tools=["tool3"], + ) + self.assertFalse(result) + self.assertEqual(message, "Forbidden tool(s) called: {'tool3'}") + + def test_mixed_tools_called(self): + result, message = validate_tool( + required_tools=["tool1"], + actual_tool_calls=["tool1", "tool2", "tool3"], + optional_tools=["tool2"], + forbidden_tools=["tool3"], + ) + self.assertFalse(result) + self.assertEqual(message, "Forbidden tool(s) called: {'tool3'}") + + def test_repeated_required_tools(self): + result, message = validate_tool( + required_tools=["tool1", "tool1"], + actual_tool_calls=["tool1", "tool1"], + optional_tools=[], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_repeated_forbidden_tools(self): + result, message = validate_tool( + required_tools=["tool1"], + actual_tool_calls=["tool1", "tool3", "tool3"], + optional_tools=[], + forbidden_tools=["tool3"], + ) + self.assertFalse(result) + self.assertEqual(message, "Forbidden tool(s) called: {'tool3'}") + + def test_overrepeated_tools(self): + result, message = validate_tool( + required_tools=["tool1", "tool2", "tool3"], + actual_tool_calls=["tool1", "tool2", "tool2", "tool2", "tool3"], + optional_tools=[], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_overrepeated_tools2(self): + result, message = validate_tool( + required_tools=["tool1", "tool2", "tool3"], + actual_tool_calls=["tool1", "tool2", "tool3", "tool3", "tool3"], + optional_tools=[], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + def test_overrepeated_tools3(self): + result, message = validate_tool( + required_tools=["tool1", "tool2", "tool3"], + actual_tool_calls=[ + "tool1", + "tool1", + "tool1", + "tool2", + "tool2", + "tool3", + "tool3", + ], + optional_tools=[], + forbidden_tools=[], + ) + self.assertTrue(result) + self.assertEqual(message, "All required tools called correctly") + + +@pytest.mark.asyncio +async def test_fetch_tool_call_success(): + test_case = { + "prompt": "Test prompt", + "expected_tools": [{"tool_name": "tool1"}, {"tool_name": "tool2"}], + "optional_tools": ["tool3"], + "forbidden_tools": ["tool4"], + } + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json.return_value = { + "steps": [{"tool_name": "tool1"}, {"tool_name": "tool2"}] + } + + # Mock the context manager behavior + with patch("aiohttp.ClientSession.post") as mock_post: + mock_post.return_value.__aenter__.return_value = mock_response + + async with aiohttp.ClientSession() as session: + base_url = "http://localhost:8000" # Define the base URL for testing + semaphore = asyncio.Semaphore(1) # Create a semaphore for testing + result = await fetch_tool_call( + session, test_case, base_url, semaphore + ) # Pass semaphore + + assert result["Prompt"] == "Test prompt" + assert result["Actual"] == ["tool1", "tool2"] + assert result["Expected"] == ["tool1", "tool2"] + assert result["Optional"] == ["tool3"] + assert result["Forbidden"] == ["tool4"] + assert result["Match"] == "Yes" + + +@pytest.mark.asyncio +async def test_fetch_tool_call_failure(): + test_case = { + "prompt": "Test prompt", + "expected_tools": [{"tool_name": "tool1"}, {"tool_name": "tool2"}], + "optional_tools": ["tool3"], + "forbidden_tools": ["tool4"], + } + + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.text.return_value = "Internal Server Error" + + with patch("aiohttp.ClientSession.post") as mock_post: + mock_post.return_value.__aenter__.return_value = mock_response + + async with aiohttp.ClientSession() as session: + base_url = "http://localhost:8000" # Define the base URL for testing + semaphore = asyncio.Semaphore(1) # Create a semaphore for testing + result = await fetch_tool_call( + session, test_case, base_url, semaphore + ) # Pass semaphore + + assert result["Prompt"] == "Test prompt" + assert "API call failed" in result["Actual"] + assert result["Expected"] == [ + {"tool_name": "tool1"}, + {"tool_name": "tool2"}, + ] + assert result["Optional"] == ["tool3"] + assert result["Forbidden"] == ["tool4"] + assert result["Match"] == "No" + + +@pytest.mark.asyncio +async def test_validate_tool_calls_async(): + mock_data = json.dumps( + [ + { + "prompt": "Test prompt", + "expected_tools": [{"tool_name": "tool1"}], + "optional_tools": [], + "forbidden_tools": [], + } + ] + ) + + with patch("builtins.open", mock_open(read_data=mock_data)): + with patch( + "src.neuroagent.scripts.avalidate_tool_calls.fetch_tool_call", + new_callable=AsyncMock, + ) as mock_fetch: + mock_fetch.return_value = {"Match": "Yes"} + base_url = "http://localhost:8000" # Define the base URL for testing + data_file = "mock_data.json" # Mock data file path + await validate_tool_calls_async(base_url, data_file, "test_output.csv") + + +if __name__ == "__main__": + unittest.main() diff --git a/tool_call_evaluation.csv b/tool_call_evaluation.csv deleted file mode 100644 index 0a8e00c..0000000 --- a/tool_call_evaluation.csv +++ /dev/null @@ -1,7 +0,0 @@ -,Prompt,Actual,Expected,Match -0,What are the morphological features of neurons in the thalamus?,"['resolve-entities-tool', 'get-morpho-tool']","['resolve-entities-tool', 'get-morpho-tool']",Yes -1,Find me articles about the role of the hippocampus in memory formation.,['literature-search-tool'],['literature-search-tool'],Yes -2,Retrieve electrophysiological features of cortical neurons.,"['resolve-entities-tool', 'get-traces-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool', 'electrophys-features-tool']","['resolve-entities-tool', 'get-traces-tool', 'electrophys-feature-tool']",No -3,Get traces for neurons in the hippocampus.,"['resolve-entities-tool', 'literature-search-tool', 'resolve-entities-tool', 'literature-search-tool', 'resolve-entities-tool', 'resolve-entities-tool', 'resolve-entities-tool', 'literature-search-tool', 'resolve-entities-tool', 'resolve-entities-tool']","['resolve-entities-tool', 'get-traces-tool', 'electrophys-features-tool']",No -4,Search for literature on synaptic plasticity.,['literature-search-tool'],['literature-search-tool'],Yes -5,"Run 1000 ms of simulation of a me model from somatosensory cortex with 34 degree temperature, current clamp stimulation mode with step current for fire pattern detection. use 1 number of step and 0.05 nA current stimulation. Record from soma.","['resolve-entities-tool', 'literature-search-tool', 'get-me-model-tool', 'bluenaas-tool']","['resolve-entities-tool', 'literature-search-tool', 'get-me-model-tool', 'bluenaas-tool']",Yes