Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change yroom class attribute to instance attribute and stop ystore in stop method #39

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ class YRoom:
_on_message: Callable[[bytes], Awaitable[bool] | bool] | None
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
_started: Event | None = None
_task_group: TaskGroup | None
_started: Event | None
_stopped: Event
__start_lock: Lock | None = None
_subscription: Subscription | None = None
__start_lock: Lock | None
_subscription: Subscription | None

def __init__(
self,
Expand Down Expand Up @@ -82,6 +82,10 @@ def __init__(
self._on_message = None
self.exception_handler = exception_handler
self._stopped = Event()
self._task_group = None
self._started = None
self.__start_lock = None
self._subscription = None

@property
def _start_lock(self) -> Lock:
Expand Down Expand Up @@ -230,6 +234,11 @@ async def stop(self) -> None:
self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
if self.ystore is not None:
try:
await self.ystore.stop()
except RuntimeError:
pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should automatically stop the YStore when the YRoom is stopped.
WebsocketServer has an auto_clean_room parameter, maybe there should be an auto_stop_store as well, that we would use to do so if set to True.
Also, we should have YStore use the same exception handler pattern that we used for YRoom and WebsocketServer, and not catch exceptions here.

Copy link
Collaborator Author

@jzhang20133 jzhang20133 May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a good idea to make it configurable. And also +1 for adding exception handler pattern in other logic to protect task group. I will address those.
In this stop method case, we are trying to except a RuntimeError that is thrown if "YStore not running" in cases like ystore is already stopped or ystore is not started yet but room crashed. I can create a specific exception type for this case.

if self._subscription is not None:
self.ydoc.unobserve(self._subscription)

Expand Down
24 changes: 18 additions & 6 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseYStore(ABC):
metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None
version = 2
_started: Event | None = None
_stopped: Event | None = None
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

Expand All @@ -50,6 +51,12 @@ def started(self) -> Event:
self._started = Event()
return self._started

@property
def stopped(self) -> Event:
if self._stopped is None:
self._stopped = Event()
return self._stopped

@property
def _start_lock(self) -> Lock:
if self.__start_lock is None:
Expand Down Expand Up @@ -96,12 +103,14 @@ async def start(
async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
await self.stopped.wait()

async def stop(self) -> None:
"""Stop the store."""
if self._task_group is None:
raise RuntimeError("YStore not running")

self.stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None

Expand Down Expand Up @@ -309,7 +318,7 @@ class MySQLiteYStore(SQLiteYStore):
document_ttl: int | None = None
path: str
lock: Lock
db_initialized: Event
db_initialized: Event | None
_db: Connection

def __init__(
Expand All @@ -329,6 +338,7 @@ def __init__(
self.metadata_callback = metadata_callback
self.log = log or getLogger(__name__)
self.lock = Lock()
self.db_initialized = None

async def start(
self,
Expand Down Expand Up @@ -356,10 +366,11 @@ async def start(
self._task_group.start_soon(self._init_db)
task_status.started()
self.started.set()
await self.stopped.wait()

async def stop(self) -> None:
"""Stop the store."""
if hasattr(self, "db_initialized") and self.db_initialized.is_set():
if self.db_initialized is not None and self.db_initialized.is_set():
await self._db.close()
await super().stop()

Expand Down Expand Up @@ -405,6 +416,7 @@ async def _init_db(self):
await db.commit()
await db.close()
self._db = await connect(self.db_path)
assert self.db_initialized is not None
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -413,8 +425,8 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Returns:
A tuple of (update, metadata, timestamp) for each update.
"""
if not hasattr(self, "db_initialized"):
raise RuntimeError("ystore is not started")
if self.db_initialized is None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
try:
async with self.lock:
Expand All @@ -438,8 +450,8 @@ async def write(self, data: bytes) -> None:
Arguments:
data: The update to store.
"""
if not hasattr(self, "db_initialized"):
raise RuntimeError("ystore is not started")
if self.db_initialized is None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
async with self.lock:
# first, determine time elapsed since last update
Expand Down
27 changes: 26 additions & 1 deletion tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from unittest.mock import patch

import pytest
from anyio import create_task_group
from anyio import create_task_group, sleep
from pycrdt import Map
from sqlite_anyio import connect
from utils import StartStopContextManager, YDocTest

from pycrdt_websocket.websocket_server import exception_logger
from pycrdt_websocket.yroom import YRoom
from pycrdt_websocket.ystore import SQLiteYStore, TempFileYStore

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -124,3 +127,25 @@ async def test_version(YStore, ystore_api, caplog):
YStore.version = prev_version
async with ystore as ystore:
await ystore.write(b"bar")


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
async def test_yroom_stop(yws_server, yws_provider, YStore):
port, server = yws_server
ystore = YStore("ystore", metadata_callback=MetadataCallback())
yroom = YRoom(ystore=ystore, exception_handler=exception_logger)
yroom.ydoc, _ = yws_provider
await server.start_room(yroom)
yroom.ydoc["map"] = ymap1 = Map()
ymap1["key"] = "value"
ymap1["key2"] = "value2"
await sleep(1)
assert yroom._task_group is not None
assert not yroom._task_group.cancel_scope.cancel_called
assert ystore._task_group is not None
assert not ystore._task_group.cancel_scope.cancel_called
await yroom.stop()
assert yroom._task_group is None
assert ystore._task_group is None
Loading