From 44b8f55e1aba8b0bd5c2fbabfb3e2482f0465fc6 Mon Sep 17 00:00:00 2001 From: Jialin Zhang Date: Mon, 29 Apr 2024 17:47:37 -0700 Subject: [PATCH] change yroom class attribute to instance attribute and stop ystore in stop method --- pycrdt_websocket/yroom.py | 21 +++++++++++++++------ pycrdt_websocket/ystore.py | 26 +++++++++++++++++++------- tests/test_ystore.py | 28 +++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/pycrdt_websocket/yroom.py b/pycrdt_websocket/yroom.py index a2a8b3d..540df12 100644 --- a/pycrdt_websocket/yroom.py +++ b/pycrdt_websocket/yroom.py @@ -37,12 +37,12 @@ class YRoom: _on_message: Callable[[bytes], Awaitable[bool] | bool] | None _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream - _task_group: TaskGroup | None = None - _started: Event | None = None + _task_group: TaskGroup | None + _started: Event | None _stopped: Event - __start_lock: Lock | None = None - _subscription: Subscription | None = None - + __start_lock: Lock | None + _subscription: Subscription | None + def __init__( self, ready: bool = True, @@ -82,7 +82,11 @@ def __init__( self._on_message = None self.exception_handler = exception_handler self._stopped = Event() - + self._task_group = None + self._started = None + self.__start_lock = None + self._subscription = None + @property def _start_lock(self) -> Lock: if self.__start_lock is None: @@ -230,6 +234,11 @@ async def stop(self) -> None: self._stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None + if self.ystore is not None: + try: + await self.ystore.stop() + except RuntimeError: + pass if self._subscription is not None: self.ydoc.unobserve(self._subscription) diff --git a/pycrdt_websocket/ystore.py b/pycrdt_websocket/ystore.py index 5772189..a6a7058 100644 --- a/pycrdt_websocket/ystore.py +++ b/pycrdt_websocket/ystore.py @@ -28,6 +28,7 @@ class BaseYStore(ABC): metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None version = 2 _started: Event | None = None + _stopped: Event | None = None _task_group: TaskGroup | None = None __start_lock: Lock | None = None @@ -50,6 +51,12 @@ def started(self) -> Event: self._started = Event() return self._started + @property + def stopped(self) -> Event: + if self._stopped is None: + self._stopped = Event() + return self._stopped + @property def _start_lock(self) -> Lock: if self.__start_lock is None: @@ -96,12 +103,14 @@ async def start( async with create_task_group() as self._task_group: task_status.started() self.started.set() + await self.stopped.wait() async def stop(self) -> None: """Stop the store.""" if self._task_group is None: raise RuntimeError("YStore not running") + self.stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None @@ -309,7 +318,7 @@ class MySQLiteYStore(SQLiteYStore): document_ttl: int | None = None path: str lock: Lock - db_initialized: Event + db_initialized: Event | None _db: Connection def __init__( @@ -329,6 +338,7 @@ def __init__( self.metadata_callback = metadata_callback self.log = log or getLogger(__name__) self.lock = Lock() + self.db_initialized = None async def start( self, @@ -356,10 +366,11 @@ async def start( self._task_group.start_soon(self._init_db) task_status.started() self.started.set() + await self.stopped.wait() async def stop(self) -> None: """Stop the store.""" - if hasattr(self, "db_initialized") and self.db_initialized.is_set(): + if self.db_initialized is not None and self.db_initialized.is_set(): await self._db.close() await super().stop() @@ -405,6 +416,7 @@ async def _init_db(self): await db.commit() await db.close() self._db = await connect(self.db_path) + assert self.db_initialized is not None self.db_initialized.set() async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: @@ -413,8 +425,8 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: Returns: A tuple of (update, metadata, timestamp) for each update. """ - if not hasattr(self, "db_initialized"): - raise RuntimeError("ystore is not started") + if self.db_initialized is None: + raise RuntimeError("YStore not started") await self.db_initialized.wait() try: async with self.lock: @@ -438,8 +450,8 @@ async def write(self, data: bytes) -> None: Arguments: data: The update to store. """ - if not hasattr(self, "db_initialized"): - raise RuntimeError("ystore is not started") + if self.db_initialized is None: + raise RuntimeError("YStore not started") await self.db_initialized.wait() async with self.lock: # first, determine time elapsed since last update @@ -473,4 +485,4 @@ async def write(self, data: bytes) -> None: "INSERT INTO yupdates VALUES (?, ?, ?, ?)", (self.path, data, metadata, time.time()), ) - await self._db.commit() + await self._db.commit() \ No newline at end of file diff --git a/tests/test_ystore.py b/tests/test_ystore.py index c7bf729..cbe5de5 100644 --- a/tests/test_ystore.py +++ b/tests/test_ystore.py @@ -3,8 +3,11 @@ from pathlib import Path from unittest.mock import patch +from pycrdt_websocket.websocket_server import exception_logger +from pycrdt_websocket.yroom import YRoom import pytest -from anyio import create_task_group +from pycrdt import Map +from anyio import create_task_group, sleep from sqlite_anyio import connect from utils import StartStopContextManager, YDocTest @@ -124,3 +127,26 @@ async def test_version(YStore, ystore_api, caplog): YStore.version = prev_version async with ystore as ystore: await ystore.write(b"bar") + + +@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True) +@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True) +@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore)) +async def test_yroom_stop(yws_server, yws_provider, YStore): + port, server = yws_server + ystore = YStore("ystore", metadata_callback=MetadataCallback()) + yroom = YRoom(ystore=ystore, exception_handler=exception_logger) + yroom.ydoc, _ = yws_provider + await server.start_room(yroom) + yroom.ydoc["map"] = ymap1 = Map() + ymap1["key"] = "value" + ymap1["key2"] = "value2" + await sleep(1) + assert yroom._task_group is not None + assert not yroom._task_group.cancel_scope.cancel_called + assert ystore._task_group is not None + assert not ystore._task_group.cancel_scope.cancel_called + await yroom.stop() + assert yroom._task_group is None + assert ystore._task_group is None +