diff --git a/pycrdt_websocket/websocket.py b/pycrdt_websocket/websocket.py index 1fecaba..34c4bef 100644 --- a/pycrdt_websocket/websocket.py +++ b/pycrdt_websocket/websocket.py @@ -19,6 +19,10 @@ class Websocket(Protocol): await websocket.send(message) ``` """ + @staticmethod + def exception_handler(exception: Exception, log: Logger) -> bool: + ... + return False @property def path(self) -> str: diff --git a/pycrdt_websocket/websocket_server.py b/pycrdt_websocket/websocket_server.py index 1846346..3a365d0 100644 --- a/pycrdt_websocket/websocket_server.py +++ b/pycrdt_websocket/websocket_server.py @@ -71,17 +71,26 @@ def _start_lock(self) -> Lock: self.__start_lock = Lock() return self.__start_lock - async def get_room(self, name: str) -> YRoom: + async def get_room( + self, name: str, + exception_handler: Callable[[Exception, Logger], bool] | None + ) -> YRoom: """Get or create a room with the given name, and start it. Arguments: name: The room name. + exception_handler: A callable for handling exceptions raised in the YRoom. + If the exception in handled, should return `True`; otherwise, returns `False`. Returns: The room with the given name, or a new one if no room with that name was found. """ if name not in self.rooms.keys(): - self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log) + self.rooms[name] = YRoom( + ready=self.rooms_ready, + log=self.log, + exception_handler=exception_handler + ) room = self.rooms[name] await self.start_room(room) return room @@ -158,7 +167,10 @@ async def serve(self, websocket: Websocket) -> None: try: async with create_task_group(): - room = await self.get_room(websocket.path) + # If the websocket interface includes a customer exception handler + # pass it to the YRoom. + exception_handler = getattr(websocket, "exception_handler", None) + room = await self.get_room(websocket.path, exception_handler=exception_handler) await self.start_room(room) await room.serve(websocket) if self.auto_clean_rooms and not room.clients: @@ -236,5 +248,5 @@ async def stop(self) -> None: def exception_logger(exception: Exception, log: Logger) -> bool: """An exception handler that logs the exception and discards it.""" - log.error("WebsocketServer exception", exc_info=exception) + log.error("PyCRDT WebsocketServer exception", exc_info=exception) return True # the exception was handled diff --git a/pycrdt_websocket/yroom.py b/pycrdt_websocket/yroom.py index f0c12c8..0b6ec01 100644 --- a/pycrdt_websocket/yroom.py +++ b/pycrdt_websocket/yroom.py @@ -39,11 +39,16 @@ class YRoom: _update_receive_stream: MemoryObjectReceiveStream _task_group: TaskGroup | None = None _started: Event | None = None + _stopped: Event __start_lock: Lock | None = None _subscription: Subscription | None = None def __init__( - self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None + self, + ready: bool = True, + ystore: BaseYStore | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, + log: Logger | None = None, ): """Initialize the object. @@ -63,6 +68,8 @@ def __init__( Arguments: ready: Whether the internal YDoc is ready to be synchronized right away. ystore: An optional store in which to persist document updates. + exception_handler: An optional callback to call when an exception is raised, that + returns True if the exception was handled. log: An optional logger. """ self.ydoc = Doc() @@ -76,6 +83,8 @@ def __init__( self.log = log or getLogger(__name__) self.clients = [] self._on_message = None + self.exception_handler = exception_handler + self._stopped = Event() @property def _start_lock(self) -> Lock: @@ -138,12 +147,18 @@ async def _broadcast_updates(self): # broadcast internal ydoc's update to all clients, that includes changes from the # clients and changes from the backend (out-of-band changes) for client in self.clients: - self.log.debug("Sending Y update to client with endpoint: %s", client.path) - message = create_update_message(update) - self._task_group.start_soon(client.send, message) + try: + self.log.debug("Sending Y update to client with endpoint: %s", client.path) + message = create_update_message(update) + self._task_group.start_soon(client.send, message) + except Exception as exception: + self._handle_exception(exception) if self.ystore: - self.log.debug("Writing Y update to YStore") - self._task_group.start_soon(self.ystore.write, update) + try: + self._task_group.start_soon(self.ystore.write, update) + self.log.debug("Writing Y update to YStore") + except Exception as exception: + self._handle_exception(exception) async def __aenter__(self) -> YRoom: async with self._start_lock: @@ -151,10 +166,9 @@ async def __aenter__(self) -> YRoom: raise RuntimeError("YRoom already running") async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) + self._task_group = await exit_stack.enter_async_context(create_task_group()) self._exit_stack = exit_stack.pop_all() - await tg.start(partial(self.start, from_context_manager=True)) + await self._task_group.start(partial(self.start, from_context_manager=True)) return self @@ -162,6 +176,13 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + def _handle_exception(self, exception: Exception) -> None: + exception_handled = False + if self.exception_handler is not None: + exception_handled = self.exception_handler(exception, self.log) + if not exception_handled: + raise exception + async def start( self, *, @@ -177,6 +198,7 @@ async def start( task_status.started() self.started.set() assert self._task_group is not None + self._task_group.start_soon(self._stopped.wait) self._task_group.start_soon(self._broadcast_updates) return @@ -184,20 +206,24 @@ async def start( if self._task_group is not None: raise RuntimeError("YRoom already running") - async with create_task_group() as self._task_group: - task_status.started() - self.started.set() - self._task_group.start_soon(self._broadcast_updates) - self._task_group.start_soon(self._watch_ready) + while True: + try: + async with create_task_group() as self._task_group: + if not self.started.is_set(): + task_status.started() + self.started.set() + self._task_group.start_soon(self._stopped.wait) + self._task_group.start_soon(self._broadcast_updates) + self._task_group.start_soon(self._watch_ready) + return + except Exception as exception: + self._handle_exception(exception) async def stop(self) -> None: """Stop the room.""" if self._task_group is None: raise RuntimeError("YRoom not running") - - if self._task_group is None: - return - + self._stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None if self._subscription is not None: @@ -209,10 +235,10 @@ async def serve(self, websocket: Websocket): Arguments: websocket: The WebSocket through which to serve the client. """ - async with create_task_group() as tg: - self.clients.append(websocket) - await sync(self.ydoc, websocket, self.log) - try: + try: + async with create_task_group() as tg: + self.clients.append(websocket) + await sync(self.ydoc, websocket, self.log) async for message in websocket: # filter messages (e.g. awareness) skip = False @@ -245,8 +271,7 @@ async def serve(self, websocket: Websocket): client.path, ) tg.start_soon(client.send, message) - except Exception as e: - self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e) - - # remove this client - self.clients = [c for c in self.clients if c != websocket] + # remove this client + self.clients = [c for c in self.clients if c != websocket] + except Exception as exception: + self._handle_exception(exception)