Skip to content

Commit

Permalink
Merge pull request #72 from nlioc4/master
Browse files Browse the repository at this point in the history
Add logic to resubscribe to triggers on websocket reconnect
  • Loading branch information
leonhard-s authored Mar 10, 2024
2 parents 23dcf92 + 8fbff23 commit 42a949c
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions auraxium/event/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_EventT = TypeVar('_EventT', bound=Event)
_EventT2 = TypeVar('_EventT2', bound=Event)
_CallbackT = Union[Callable[[_EventT], None],
Callable[[_EventT], Coroutine[Any, Any, None]]]
Callable[[_EventT], Coroutine[Any, Any, None]]]

_log = logging.getLogger('auraxium.ess')

Expand Down Expand Up @@ -170,6 +170,14 @@ def remove_trigger(self, trigger: Union[Trigger, str], *,
_log.info('All triggers have been removed, closing websocket')
self.loop.create_task(self.close())

def _subscribe_all(self):
"""Add subscription messages for every registered trigger.
This will add a subscription message for every trigger currently registered with the client.
Useful for resubscribing to all events after a disconnect.
"""
self._send_queue.extend([trigger.generate_subscription() for trigger in self.triggers])

async def close(self) -> None:
"""Gracefully shut down the client.
Expand All @@ -186,7 +194,7 @@ async def connect(self) -> None:
This will continuously loop until :meth:`EventClient.close` is
called.
If the WebSocket connection encounters and error, it will be
If the WebSocket connection encounters an error, it will be
automatically restarted.
Any event payloads received will be passed to
Expand Down Expand Up @@ -262,9 +270,12 @@ async def _connection_handler(self) -> None:
# NOTE: The following "async for" loop will cleanly restart the
# connection should it go down. Invoking "continue" manually may be
# used to manually force a reconnect if needed.

connection_failed = False
async for websocket in websockets.client.connect(str(url)):
_log.info('Connected to %s', url)
if connection_failed:
self._subscribe_all()
connection_failed = False
self.websocket = websocket

try:
Expand All @@ -273,6 +284,7 @@ async def _connection_handler(self) -> None:

except websockets.exceptions.ConnectionClosed:
_log.info('Connection closed, restarting...')
connection_failed = True
continue

if not self._open:
Expand Down Expand Up @@ -312,22 +324,22 @@ async def _handle_websocket(self, timeout: float = 0.1) -> None:
def trigger(self, event: Type[_EventT], *, name: Optional[str] = None,
**kwargs: Any) -> Callable[[_CallbackT[_EventT]], None]:
# Single event variant (checks callback argument type)
... # pragma: no cover
... # pragma: no cover

@overload
def trigger(self, event: Type[_EventT],
arg1: Type[_EventT], *args: Type[_EventT2],
name: Optional[str] = None, **kwargs: Any) -> Callable[
[_CallbackT[Union[_EventT, _EventT2]]], None]:
[_CallbackT[Union[_EventT, _EventT2]]], None]:
# Two event variant (checks callback argument type)
... # pragma: no cover
... # pragma: no cover

@overload
def trigger(self, event: Union[str, Type[Event]],
*args: Union[str, Type[Event]], name: Optional[str] = None,
**kwargs: Any) -> Callable[[_CallbackT[Event]], None]:
# Generic fallback variant (callback argument type not checked)
... # pragma: no cover
... # pragma: no cover

def trigger(self, event: Union[str, Type[Event]],
*args: Union[str, Type[Event]], name: Optional[str] = None,
Expand Down

0 comments on commit 42a949c

Please sign in to comment.