Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flow the YRoom exception handler down from the websocket server #1

Open
wants to merge 2 commits into
base: yroom-exception
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pycrdt_websocket/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class Websocket(Protocol):
await websocket.send(message)
```
"""
@staticmethod
def exception_handler(exception: Exception, log: Logger) -> bool:
...
return False

@property
def path(self) -> str:
Expand Down
20 changes: 16 additions & 4 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,26 @@ def _start_lock(self) -> Lock:
self.__start_lock = Lock()
return self.__start_lock

async def get_room(self, name: str) -> YRoom:
async def get_room(
self, name: str,
exception_handler: Callable[[Exception, Logger], bool] | None
) -> YRoom:
"""Get or create a room with the given name, and start it.

Arguments:
name: The room name.
exception_handler: A callable for handling exceptions raised in the YRoom.
If the exception in handled, should return `True`; otherwise, returns `False`.

Returns:
The room with the given name, or a new one if no room with that name was found.
"""
if name not in self.rooms.keys():
self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log)
self.rooms[name] = YRoom(
ready=self.rooms_ready,
log=self.log,
exception_handler=exception_handler
)
room = self.rooms[name]
await self.start_room(room)
return room
Expand Down Expand Up @@ -158,7 +167,10 @@ async def serve(self, websocket: Websocket) -> None:

try:
async with create_task_group():
room = await self.get_room(websocket.path)
# If the websocket interface includes a customer exception handler
# pass it to the YRoom.
exception_handler = getattr(websocket, "exception_handler", None)
room = await self.get_room(websocket.path, exception_handler=exception_handler)
await self.start_room(room)
await room.serve(websocket)
if self.auto_clean_rooms and not room.clients:
Expand Down Expand Up @@ -236,5 +248,5 @@ async def stop(self) -> None:

def exception_logger(exception: Exception, log: Logger) -> bool:
"""An exception handler that logs the exception and discards it."""
log.error("WebsocketServer exception", exc_info=exception)
log.error("PyCRDT WebsocketServer exception", exc_info=exception)
return True # the exception was handled
79 changes: 52 additions & 27 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ class YRoom:
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
_started: Event | None = None
_stopped: Event
__start_lock: Lock | None = None
_subscription: Subscription | None = None

def __init__(
self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None
self,
ready: bool = True,
ystore: BaseYStore | None = None,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
):
"""Initialize the object.

Expand All @@ -63,6 +68,8 @@ def __init__(
Arguments:
ready: Whether the internal YDoc is ready to be synchronized right away.
ystore: An optional store in which to persist document updates.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.ydoc = Doc()
Expand All @@ -76,6 +83,8 @@ def __init__(
self.log = log or getLogger(__name__)
self.clients = []
self._on_message = None
self.exception_handler = exception_handler
self._stopped = Event()

@property
def _start_lock(self) -> Lock:
Expand Down Expand Up @@ -138,30 +147,42 @@ async def _broadcast_updates(self):
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
try:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
except Exception as exception:
self._handle_exception(exception)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)
try:
self._task_group.start_soon(self.ystore.write, update)
self.log.debug("Writing Y update to YStore")
except Exception as exception:
self._handle_exception(exception)

async def __aenter__(self) -> YRoom:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -177,27 +198,32 @@ async def start(
task_status.started()
self.started.set()
assert self._task_group is not None
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._broadcast_updates)
return

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
self._task_group.start_soon(self._broadcast_updates)
self._task_group.start_soon(self._watch_ready)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._broadcast_updates)
self._task_group.start_soon(self._watch_ready)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the room."""
if self._task_group is None:
raise RuntimeError("YRoom not running")

if self._task_group is None:
return

self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
if self._subscription is not None:
Expand All @@ -209,10 +235,10 @@ async def serve(self, websocket: Websocket):
Arguments:
websocket: The WebSocket through which to serve the client.
"""
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
try:
try:
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
async for message in websocket:
# filter messages (e.g. awareness)
skip = False
Expand Down Expand Up @@ -245,8 +271,7 @@ async def serve(self, websocket: Websocket):
client.path,
)
tg.start_soon(client.send, message)
except Exception as e:
self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e)

# remove this client
self.clients = [c for c in self.clients if c != websocket]
# remove this client
self.clients = [c for c in self.clients if c != websocket]
except Exception as exception:
self._handle_exception(exception)