Skip to content

Commit

Permalink
Fix unreachable sqlite db bug + add init in scripts (#12)
Browse files Browse the repository at this point in the history
* Fix unreachable sqlite db bug + add init in scripts

* Modify CHANGELOG

* Modify docstring

* Remove editable install in the CI

* Add separator when LLM starts streaming

* Prepare for release

---------

Co-authored-by: Nicolas Frank <[email protected]>
  • Loading branch information
WonderPG and Nicolas Frank authored Sep 26, 2024
1 parent dcfe50f commit 4f69726
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ jobs:
run: |
pip install --upgrade pip
pip install mypy==1.8.0
pip install -e ".[dev]"
pip install ".[dev]"
- name: Running mypy and tests
run: |
mypy src/
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.1] - 26.09.2024

### Fixed
- Fixed a bug that prevented AsyncSqlite checkpoint to access the DB in streamed endpoints.

## [0.1.0] - 19.09.2024

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Neuroagent package."""

__version__ = "0.1.0"
__version__ = "0.1.1"
24 changes: 24 additions & 0 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Base agent."""

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from pydantic import BaseModel, ConfigDict


Expand Down Expand Up @@ -47,3 +49,25 @@ def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]:
@abstractmethod
def _process_output(*args: Any, **kwargs: Any) -> AgentOutput:
"""Format the output."""


class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver):
"""Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix."""

@classmethod
@asynccontextmanager
async def from_conn_string(
cls, conn_string: str
) -> AsyncIterator["AsyncSqliteSaver"]:
"""Create a new AsyncSqliteSaver instance from a connection string.
Args:
conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix.
Yields
------
AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance.
"""
conn_string = conn_string.split("///")[-1]
async with super().from_conn_string(conn_string) as memory:
yield AsyncSqliteSaverWithPrefix(memory.conn)
10 changes: 5 additions & 5 deletions src/neuroagent/agents/simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class SimpleChatAgent(BaseAgent):
"""Simple Agent class."""

memory: BaseCheckpointSaver
memory: BaseCheckpointSaver[Any]

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -73,12 +73,9 @@ async def astream(
streamed_response = self.agent.astream_events(
{"messages": query}, version="v2", config=config
)
is_streaming = False
async for event in streamed_response:
kind = event["event"]

# newline everytime model starts streaming.
if kind == "on_chat_model_start":
yield "\n\n"
# check for the model stream.
if kind == "on_chat_model_stream":
# check if we are calling the tools.
Expand All @@ -95,6 +92,9 @@ async def astream(

content = data_chunk.content
if content:
if not is_streaming:
yield "\n<begin_llm_response>\n"
is_streaming = True
yield content
yield "\n"

Expand Down
10 changes: 5 additions & 5 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
Expand All @@ -23,6 +22,7 @@
SimpleAgent,
SimpleChatAgent,
)
from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix
from neuroagent.app.config import Settings
from neuroagent.cell_types import CellTypesMeta
from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent
Expand Down Expand Up @@ -321,12 +321,12 @@ def get_language_model(

async def get_agent_memory(
connection_string: Annotated[str | None, Depends(get_connection_string)],
) -> AsyncIterator[BaseCheckpointSaver | None]:
) -> AsyncIterator[BaseCheckpointSaver[Any] | None]:
"""Get the agent checkpointer."""
if connection_string:
if connection_string.startswith("sqlite"):
async with AsyncSqliteSaver.from_conn_string(
connection_string.split("///")[-1]
async with AsyncSqliteSaverWithPrefix.from_conn_string(
connection_string
) as memory:
await memory.setup()
yield memory
Expand Down Expand Up @@ -404,7 +404,7 @@ def get_agent(

def get_chat_agent(
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)],
memory: Annotated[BaseCheckpointSaver[Any], Depends(get_agent_memory)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
Expand Down
1 change: 1 addition & 0 deletions src/neuroagent/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Neuroagent scripts."""
4 changes: 2 additions & 2 deletions tests/agents/test_simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ async def test_astream(fake_llm_with_tools, httpx_mock):

msg_list = "".join([el async for el in response])
assert (
msg_list == "\n\n\nCalling tool : get-morpho-tool with arguments :"
' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat'
msg_list == "\nCalling tool : get-morpho-tool with arguments :"
' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n<begin_llm_response>\nGreat'
" answer\n"
)

Expand Down
10 changes: 5 additions & 5 deletions tests/test_resolving.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def test_sparql_exact_resolve(httpx_mock, get_resolve_query_output):
}
]

httpx_mock.reset(assert_all_responses_were_requested=False)
httpx_mock.reset()

mtype = "Interneuron"
mocked_response = get_resolve_query_output[1]
Expand Down Expand Up @@ -83,7 +83,7 @@ async def test_sparql_fuzzy_resolve(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/463",
},
]
httpx_mock.reset(assert_all_responses_were_requested=False)
httpx_mock.reset()

mtype = "Interneu"
mocked_response = get_resolve_query_output[3]
Expand Down Expand Up @@ -142,7 +142,7 @@ async def test_es_resolve(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/184",
},
]
httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()

mtype = "Ventral neuron"
mocked_response = get_resolve_query_output[5]
Expand Down Expand Up @@ -221,7 +221,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/463",
},
]
httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()

httpx_mock.add_response(url=url, json=get_resolve_query_output[0])

Expand Down Expand Up @@ -252,7 +252,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/549",
}
]
httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()
httpx_mock.add_response(
url=url,
json={
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def test_get_file_from_KG_errors(httpx_mock):
)
assert not_found.value.args[0] == "No file url was found."

httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()
# no file found corresponding to file_url
test_file_url = "http://test_url.com"
json_response = {
Expand Down

0 comments on commit 4f69726

Please sign in to comment.