Skip to content

Commit

Permalink
me model tool addition (#9)
Browse files Browse the repository at this point in the history
* me model tool addition

* add changelog, fix mypy and lint errors

* rm breakpoint

* ruff

* ruff format

* ruff bump version

* rename protected namespace : model_ field

* rename protected namespace for others

* update changelog

* cell types hierarchy updated

* untrack app/data

---------

Co-authored-by: Mustafa Kerem Kurban <[email protected]>
  • Loading branch information
KeremKurban and Mustafa Kerem Kurban authored Oct 2, 2024
1 parent 198c76d commit 15854ed
Show file tree
Hide file tree
Showing 29 changed files with 654 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- name: Set up environment
run: |
pip install --upgrade pip wheel setuptools
pip install bandit[toml]==1.7.4 ruff==0.5.5
pip install bandit[toml]==1.7.4 ruff==0.6.7
- name: Linting check
run: |
bandit -qr -c pyproject.toml src/
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,8 @@ cython_debug/
# static files generated from Django application using `collectstatic`
media
static

# database stuff
*db
*.db-shm
*.db-wal
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Add get morphoelectric (me) model tool

## [0.1.1] - 26.09.2024

### Fixed
Expand Down
9 changes: 9 additions & 0 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ class SettingsGetMorpho(BaseModel):
model_config = ConfigDict(frozen=True)


class SettingsGetMEModel(BaseModel):
"""Get ME Model settings."""

search_size: int = 10

model_config = ConfigDict(frozen=True)


class SettingsKnowledgeGraph(BaseModel):
"""Knowledge graph API settings."""

Expand Down Expand Up @@ -157,6 +165,7 @@ class SettingsTools(BaseModel):
morpho: SettingsGetMorpho = SettingsGetMorpho()
trace: SettingsTrace = SettingsTrace()
kg_morpho_features: SettingsKGMorpho = SettingsKGMorpho()
me_model: SettingsGetMEModel = SettingsGetMEModel()

model_config = ConfigDict(frozen=True)

Expand Down
22 changes: 22 additions & 0 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent
from neuroagent.tools import (
ElectrophysFeatureTool,
GetMEModelTool,
GetMorphoTool,
GetTracesTool,
KGMorphoFeatureTool,
Expand Down Expand Up @@ -304,6 +305,25 @@ def get_morphology_feature_tool(
return tool


def get_me_model_tool(
settings: Annotated[Settings, Depends(get_settings)],
token: Annotated[str, Depends(get_kg_token)],
httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
) -> GetMEModelTool:
"""Load get ME model tool."""
tool = GetMEModelTool(
metadata={
"url": settings.knowledge_graph.url,
"token": token,
"httpx_client": httpx_client,
"search_size": settings.tools.me_model.search_size,
"brainregion_path": settings.knowledge_graph.br_saving_path,
"celltypes_path": settings.knowledge_graph.ct_saving_path,
}
)
return tool


def get_language_model(
settings: Annotated[Settings, Depends(get_settings)],
) -> ChatOpenAI:
Expand Down Expand Up @@ -369,6 +389,7 @@ def get_agent(
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
me_model_tool: Annotated[GetMEModelTool, Depends(get_me_model_tool)],
settings: Annotated[Settings, Depends(get_settings)],
) -> BaseAgent | BaseMultiAgent:
"""Get the generative question answering service."""
Expand Down Expand Up @@ -397,6 +418,7 @@ def get_agent(
kg_morpho_feature_tool,
electrophys_feature_tool,
traces_tool,
me_model_tool,
]
logger.info("Load simple agent")
return SimpleAgent(llm=llm, tools=tools) # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions src/neuroagent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tools folder."""

from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool, FeaturesOutput
from neuroagent.tools.get_me_model_tool import GetMEModelTool
from neuroagent.tools.get_morpho_tool import GetMorphoTool, KnowledgeGraphOutput
from neuroagent.tools.kg_morpho_features_tool import (
KGMorphoFeatureOutput,
Expand Down Expand Up @@ -35,4 +36,5 @@
"ParagraphMetadata",
"ResolveBrainRegionTool",
"TracesOutput",
"GetMEModelTool",
]
266 changes: 266 additions & 0 deletions src/neuroagent/tools/get_me_model_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""Module defining the Get ME Model tool."""

import logging
from typing import Any, Literal, Optional, Type

from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.cell_types import get_celltypes_descendants
from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
from neuroagent.utils import get_descendants_id

logger = logging.getLogger(__name__)


class InputGetMEModel(BaseModel):
"""Inputs of the knowledge graph API."""

brain_region_id: str = Field(description="ID of the brain region of interest.")
mtype_id: Optional[str] = Field(
default=None, description="ID of the M-type of interest."
)
etype_id: Optional[
Literal[
"bAC",
"bIR",
"bNAC",
"bSTUT",
"cAC",
"cIR",
"cNAC",
"cSTUT",
"dNAC",
"dSTUT",
]
] = Field(default=None, description="ID of the E-type of interest.")


class MEModelOutput(BaseToolOutput):
"""Output schema for the knowledge graph API."""

me_model_id: str
me_model_name: str | None
me_model_description: str | None
mtype: str | None
etype: str | None

brain_region_id: str
brain_region_label: str | None

subject_species_label: str | None
subject_age: str | None


class GetMEModelTool(BasicTool):
"""Class defining the Get ME Model logic."""

name: str = "get-me-model-tool"
description: str = """Searches a neuroscience based knowledge graph to retrieve neuron morpho-electric model names, IDs and descriptions.
Requires a 'brain_region_id' which is the ID of the brain region of interest as registered in the knowledge graph. To get this ID, please use the `resolve-brain-region-tool` first.
Ideally, the user should also provide an 'mtype_id' and/or an 'etype_id' to filter the search results. But in case they are not provided, the search will return all models that match the brain region.
The output is a list of ME models, containing:
- The brain region ID.
- The brain region name.
- The subject species name.
- The subject age.
- The model ID.
- The model name.
- The model description.
The model ID is in the form of an HTTP(S) link such as 'https://bbp.epfl.ch/data/bbp/mmb-point-neuron-framework-model/...'."""
metadata: dict[str, Any]
args_schema: Type[BaseModel] = InputGetMEModel

def _run(self) -> None:
pass

async def _arun(
self,
brain_region_id: str,
mtype_id: str | None = None,
etype_id: str | None = None,
) -> list[MEModelOutput]:
"""From a brain region ID, extract ME models.
Parameters
----------
brain_region_id
ID of the brain region of interest (of the form http://api.brain-map.org/api/v2/data/Structure/...)
mtype_id
ID of the mtype of the model
etype_id
ID of the etype of the model
Returns
-------
list of MEModelOutput to describe the model and its metadata, or an error dict.
"""
logger.info(
f"Entering Get ME Model tool. Inputs: {brain_region_id=}, {mtype_id=}, {etype_id=}"
)
try:
# From the brain region ID, get the descendants.
hierarchy_ids = get_descendants_id(
brain_region_id, json_path=self.metadata["brainregion_path"]
)
logger.info(
f"Found {len(list(hierarchy_ids))} children of the brain ontology."
)

if mtype_id:
mtype_ids = set(
get_celltypes_descendants(mtype_id, self.metadata["celltypes_path"])
)
logger.info(
f"Found {len(list(mtype_ids))} children of the cell types ontology for mtype."
)
else:
mtype_ids = None

if etype_id:
etype_ids = set(
get_celltypes_descendants(etype_id, self.metadata["celltypes_path"])
)
logger.info(
f"Found {len(list(etype_ids))} children of the cell types ontology for etype."
)
else:
etype_ids = None

# Create the ES query to query the KG.
entire_query = self.create_query(
brain_regions_ids=hierarchy_ids,
mtype_ids=mtype_ids,
etype_ids=etype_ids,
)

# Send the query to get ME models.
response = await self.metadata["httpx_client"].post(
url=self.metadata["url"],
headers={"Authorization": f"Bearer {self.metadata['token']}"},
json=entire_query,
)

# Process the output and return.
return self._process_output(response.json())

except Exception as e:
raise ToolException(str(e), self.name)

def create_query(
self,
brain_regions_ids: set[str],
mtype_ids: set[str] | None = None,
etype_ids: set[str] | None = None,
) -> dict[str, Any]:
"""Create ES query out of the BR, mtype, and etype IDs.
Parameters
----------
brain_regions_ids
IDs of the brain region of interest (of the form http://api.brain-map.org/api/v2/data/Structure/...)
mtype_id
ID of the mtype of the model
etype_id
ID of the etype of the model
Returns
-------
dict containing the elasticsearch query to send to the KG.
"""
# At least one of the children brain region should match.
conditions = [
{
"bool": {
"should": [
{"term": {"[email protected]": hierarchy_id}}
for hierarchy_id in brain_regions_ids
]
}
},
{"term": {"@type.keyword": "https://neuroshapes.org/MEModel"}},
{"term": {"deprecated": False}},
]

if mtype_ids:
# The correct mtype should match. For now
# It is a one term should condition, but eventually
# we will resolve the subclasses of the mtypes.
# They will all be appended here.
conditions.append(
{
"bool": {
"should": [
{"match": {"mType.label": mtype_id}}
for mtype_id in mtype_ids
]
}
}
)

if etype_ids:
# The correct etype should match.
conditions.append(
{
"bool": {
"should": [
{"match": {"eType.label": etype_id}}
for etype_id in etype_ids
]
}
}
)

# Assemble the query to return ME models.
entire_query = {
"size": self.metadata["search_size"],
"track_total_hits": True,
"query": {"bool": {"must": conditions}},
"sort": {"createdAt": {"order": "desc", "unmapped_type": "keyword"}},
}
return entire_query

@staticmethod
def _process_output(output: Any) -> list[MEModelOutput]:
"""Process output to fit the MEModelOutput pydantic class defined above.
Parameters
----------
output
Raw output of the _arun method, which comes from the KG
Returns
-------
list of MEModelOutput to describe the model and its metadata.
"""
formatted_output = [
MEModelOutput(
me_model_id=res["_source"]["@id"],
me_model_name=res["_source"].get("name"),
me_model_description=res["_source"].get("description"),
mtype=(
res["_source"]["mType"].get("label")
if "mType" in res["_source"]
else None
),
etype=(
res["_source"]["eType"].get("label")
if "eType" in res["_source"]
else None
),
brain_region_id=res["_source"]["brainRegion"]["@id"],
brain_region_label=res["_source"]["brainRegion"].get("label"),
subject_species_label=(
res["_source"]["subjectSpecies"].get("label")
if "subjectSpecies" in res["_source"]
else None
),
subject_age=(
res["_source"]["subjectAge"].get("label")
if "subjectAge" in res["_source"]
else None
),
)
for res in output["hits"]["hits"]
]
return formatted_output
1 change: 1 addition & 0 deletions tests/agents/test_simple_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import pytest

from neuroagent.agents import AgentOutput, AgentStep, SimpleAgent


Expand Down
1 change: 1 addition & 0 deletions tests/agents/test_simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

from neuroagent.agents import AgentOutput, AgentStep, SimpleChatAgent


Expand Down
Loading

0 comments on commit 15854ed

Please sign in to comment.