Skip to content

Commit

Permalink
change yroom class attribute to instance attribute and stop ystore in…
Browse files Browse the repository at this point in the history
… stop method
  • Loading branch information
Jialin Zhang committed May 3, 2024
1 parent 1cd727f commit 44b8f55
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 14 deletions.
21 changes: 15 additions & 6 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class YRoom:
_on_message: Callable[[bytes], Awaitable[bool] | bool] | None
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
_started: Event | None = None
_task_group: TaskGroup | None
_started: Event | None
_stopped: Event
__start_lock: Lock | None = None
_subscription: Subscription | None = None

__start_lock: Lock | None
_subscription: Subscription | None
def __init__(
self,
ready: bool = True,
Expand Down Expand Up @@ -82,7 +82,11 @@ def __init__(
self._on_message = None
self.exception_handler = exception_handler
self._stopped = Event()

self._task_group = None
self._started = None
self.__start_lock = None
self._subscription = None

@property
def _start_lock(self) -> Lock:
if self.__start_lock is None:
Expand Down Expand Up @@ -230,6 +234,11 @@ async def stop(self) -> None:
self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
if self.ystore is not None:
try:
await self.ystore.stop()
except RuntimeError:
pass
if self._subscription is not None:
self.ydoc.unobserve(self._subscription)

Expand Down
26 changes: 19 additions & 7 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseYStore(ABC):
metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None
version = 2
_started: Event | None = None
_stopped: Event | None = None
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

Expand All @@ -50,6 +51,12 @@ def started(self) -> Event:
self._started = Event()
return self._started

@property
def stopped(self) -> Event:
if self._stopped is None:
self._stopped = Event()
return self._stopped

@property
def _start_lock(self) -> Lock:
if self.__start_lock is None:
Expand Down Expand Up @@ -96,12 +103,14 @@ async def start(
async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
await self.stopped.wait()

async def stop(self) -> None:
"""Stop the store."""
if self._task_group is None:
raise RuntimeError("YStore not running")

self.stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None

Expand Down Expand Up @@ -309,7 +318,7 @@ class MySQLiteYStore(SQLiteYStore):
document_ttl: int | None = None
path: str
lock: Lock
db_initialized: Event
db_initialized: Event | None
_db: Connection

def __init__(
Expand All @@ -329,6 +338,7 @@ def __init__(
self.metadata_callback = metadata_callback
self.log = log or getLogger(__name__)
self.lock = Lock()
self.db_initialized = None

async def start(
self,
Expand Down Expand Up @@ -356,10 +366,11 @@ async def start(
self._task_group.start_soon(self._init_db)
task_status.started()
self.started.set()
await self.stopped.wait()

async def stop(self) -> None:
"""Stop the store."""
if hasattr(self, "db_initialized") and self.db_initialized.is_set():
if self.db_initialized is not None and self.db_initialized.is_set():
await self._db.close()
await super().stop()

Expand Down Expand Up @@ -405,6 +416,7 @@ async def _init_db(self):
await db.commit()
await db.close()
self._db = await connect(self.db_path)
assert self.db_initialized is not None
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -413,8 +425,8 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Returns:
A tuple of (update, metadata, timestamp) for each update.
"""
if not hasattr(self, "db_initialized"):
raise RuntimeError("ystore is not started")
if self.db_initialized is None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
try:
async with self.lock:
Expand All @@ -438,8 +450,8 @@ async def write(self, data: bytes) -> None:
Arguments:
data: The update to store.
"""
if not hasattr(self, "db_initialized"):
raise RuntimeError("ystore is not started")
if self.db_initialized is None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
async with self.lock:
# first, determine time elapsed since last update
Expand Down Expand Up @@ -473,4 +485,4 @@ async def write(self, data: bytes) -> None:
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
await self._db.commit()
28 changes: 27 additions & 1 deletion tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from pathlib import Path
from unittest.mock import patch

from pycrdt_websocket.websocket_server import exception_logger
from pycrdt_websocket.yroom import YRoom
import pytest
from anyio import create_task_group
from pycrdt import Map
from anyio import create_task_group, sleep
from sqlite_anyio import connect
from utils import StartStopContextManager, YDocTest

Expand Down Expand Up @@ -124,3 +127,26 @@ async def test_version(YStore, ystore_api, caplog):
YStore.version = prev_version
async with ystore as ystore:
await ystore.write(b"bar")


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
async def test_yroom_stop(yws_server, yws_provider, YStore):
port, server = yws_server
ystore = YStore("ystore", metadata_callback=MetadataCallback())
yroom = YRoom(ystore=ystore, exception_handler=exception_logger)
yroom.ydoc, _ = yws_provider
await server.start_room(yroom)
yroom.ydoc["map"] = ymap1 = Map()
ymap1["key"] = "value"
ymap1["key2"] = "value2"
await sleep(1)
assert yroom._task_group is not None
assert not yroom._task_group.cancel_scope.cancel_called
assert ystore._task_group is not None
assert not ystore._task_group.cancel_scope.cancel_called
await yroom.stop()
assert yroom._task_group is None
assert ystore._task_group is None

0 comments on commit 44b8f55

Please sign in to comment.