From 4f69726698fb52c84e9059166884125492e260a6 Mon Sep 17 00:00:00 2001 From: Nicolas Frank <58003267+WonderPG@users.noreply.github.com> Date: Thu, 26 Sep 2024 16:47:08 +0200 Subject: [PATCH] Fix unreachable sqlite db bug + add init in scripts (#12) * Fix unreachable sqlite db bug + add init in scripts * Modify CHANGELOG * Modify docstring * Remove editable install in the CI * Add separator when LLM starts streaming * Prepare for release --------- Co-authored-by: Nicolas Frank --- .github/workflows/ci.yaml | 2 +- CHANGELOG.md | 5 +++++ src/neuroagent/__init__.py | 2 +- src/neuroagent/agents/base_agent.py | 24 ++++++++++++++++++++++ src/neuroagent/agents/simple_chat_agent.py | 10 ++++----- src/neuroagent/app/dependencies.py | 10 ++++----- src/neuroagent/scripts/__init__.py | 1 + tests/agents/test_simple_chat_agent.py | 4 ++-- tests/test_resolving.py | 10 ++++----- tests/test_utils.py | 2 +- 10 files changed, 50 insertions(+), 20 deletions(-) create mode 100644 src/neuroagent/scripts/__init__.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dfbb0ab..369e6de 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -84,7 +84,7 @@ jobs: run: | pip install --upgrade pip pip install mypy==1.8.0 - pip install -e ".[dev]" + pip install ".[dev]" - name: Running mypy and tests run: | mypy src/ diff --git a/CHANGELOG.md b/CHANGELOG.md index eb751ad..d71d1fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.1] - 26.09.2024 + +### Fixed +- Fixed a bug that prevented the AsyncSqlite checkpointer from accessing the DB in streamed endpoints. 
+ ## [0.1.0] - 19.09.2024 ### Added diff --git a/src/neuroagent/__init__.py b/src/neuroagent/__init__.py index b4d4ea2..508d96b 100644 --- a/src/neuroagent/__init__.py +++ b/src/neuroagent/__init__.py @@ -1,3 +1,3 @@ """Neuroagent package.""" -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/neuroagent/agents/base_agent.py b/src/neuroagent/agents/base_agent.py index 9ecf545..347e3bd 100644 --- a/src/neuroagent/agents/base_agent.py +++ b/src/neuroagent/agents/base_agent.py @@ -1,10 +1,12 @@ """Base agent.""" from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from typing import Any, AsyncIterator from langchain.chat_models.base import BaseChatModel from langchain_core.tools import BaseTool +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from pydantic import BaseModel, ConfigDict @@ -47,3 +49,25 @@ def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]: @abstractmethod def _process_output(*args: Any, **kwargs: Any) -> AgentOutput: """Format the output.""" + + +class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver): + """Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix.""" + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, conn_string: str + ) -> AsyncIterator["AsyncSqliteSaver"]: + """Create a new AsyncSqliteSaver instance from a connection string. + + Args: + conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix. + + Yields + ------ + AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance. 
+ """ + conn_string = conn_string.split("///")[-1] + async with super().from_conn_string(conn_string) as memory: + yield AsyncSqliteSaverWithPrefix(memory.conn) diff --git a/src/neuroagent/agents/simple_chat_agent.py b/src/neuroagent/agents/simple_chat_agent.py index b331d75..882b8d7 100644 --- a/src/neuroagent/agents/simple_chat_agent.py +++ b/src/neuroagent/agents/simple_chat_agent.py @@ -17,7 +17,7 @@ class SimpleChatAgent(BaseAgent): """Simple Agent class.""" - memory: BaseCheckpointSaver + memory: BaseCheckpointSaver[Any] @model_validator(mode="before") @classmethod @@ -73,12 +73,9 @@ async def astream( streamed_response = self.agent.astream_events( {"messages": query}, version="v2", config=config ) + is_streaming = False async for event in streamed_response: kind = event["event"] - - # newline everytime model starts streaming. - if kind == "on_chat_model_start": - yield "\n\n" # check for the model stream. if kind == "on_chat_model_stream": # check if we are calling the tools. @@ -95,6 +92,9 @@ async def astream( content = data_chunk.content if content: + if not is_streaming: + yield "\n\n" + is_streaming = True yield content yield "\n" diff --git a/src/neuroagent/app/dependencies.py b/src/neuroagent/app/dependencies.py index be00639..1294660 100644 --- a/src/neuroagent/app/dependencies.py +++ b/src/neuroagent/app/dependencies.py @@ -11,7 +11,6 @@ from langchain_openai import ChatOpenAI from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError @@ -23,6 +22,7 @@ SimpleAgent, SimpleChatAgent, ) +from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix from neuroagent.app.config import Settings from neuroagent.cell_types import CellTypesMeta from neuroagent.multi_agents import BaseMultiAgent, 
SupervisorMultiAgent @@ -321,12 +321,12 @@ def get_language_model( async def get_agent_memory( connection_string: Annotated[str | None, Depends(get_connection_string)], -) -> AsyncIterator[BaseCheckpointSaver | None]: +) -> AsyncIterator[BaseCheckpointSaver[Any] | None]: """Get the agent checkpointer.""" if connection_string: if connection_string.startswith("sqlite"): - async with AsyncSqliteSaver.from_conn_string( - connection_string.split("///")[-1] + async with AsyncSqliteSaverWithPrefix.from_conn_string( + connection_string ) as memory: await memory.setup() yield memory @@ -404,7 +404,7 @@ def get_agent( def get_chat_agent( llm: Annotated[ChatOpenAI, Depends(get_language_model)], - memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)], + memory: Annotated[BaseCheckpointSaver[Any], Depends(get_agent_memory)], literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)], br_resolver_tool: Annotated[ ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool) diff --git a/src/neuroagent/scripts/__init__.py b/src/neuroagent/scripts/__init__.py new file mode 100644 index 0000000..cc662d0 --- /dev/null +++ b/src/neuroagent/scripts/__init__.py @@ -0,0 +1 @@ +"""Neuroagent scripts.""" diff --git a/tests/agents/test_simple_chat_agent.py b/tests/agents/test_simple_chat_agent.py index 6ec5474..c8420a3 100644 --- a/tests/agents/test_simple_chat_agent.py +++ b/tests/agents/test_simple_chat_agent.py @@ -84,8 +84,8 @@ async def test_astream(fake_llm_with_tools, httpx_mock): msg_list = "".join([el async for el in response]) assert ( - msg_list == "\n\n\nCalling tool : get-morpho-tool with arguments :" - ' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat' + msg_list == "\nCalling tool : get-morpho-tool with arguments :" + ' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat' " answer\n" ) diff --git a/tests/test_resolving.py b/tests/test_resolving.py index eeafbad..4393b76 100644 
--- a/tests/test_resolving.py +++ b/tests/test_resolving.py @@ -32,7 +32,7 @@ async def test_sparql_exact_resolve(httpx_mock, get_resolve_query_output): } ] - httpx_mock.reset(assert_all_responses_were_requested=False) + httpx_mock.reset() mtype = "Interneuron" mocked_response = get_resolve_query_output[1] @@ -83,7 +83,7 @@ async def test_sparql_fuzzy_resolve(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/463", }, ] - httpx_mock.reset(assert_all_responses_were_requested=False) + httpx_mock.reset() mtype = "Interneu" mocked_response = get_resolve_query_output[3] @@ -142,7 +142,7 @@ async def test_es_resolve(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/184", }, ] - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() mtype = "Ventral neuron" mocked_response = get_resolve_query_output[5] @@ -221,7 +221,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/463", }, ] - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() httpx_mock.add_response(url=url, json=get_resolve_query_output[0]) @@ -252,7 +252,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/549", } ] - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() httpx_mock.add_response( url=url, json={ diff --git a/tests/test_utils.py b/tests/test_utils.py index 742388b..152232c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -184,7 +184,7 @@ async def test_get_file_from_KG_errors(httpx_mock): ) assert not_found.value.args[0] == "No file url was found." - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() # no file found corresponding to file_url test_file_url = "http://test_url.com" json_response = {