diff --git a/src/dumdum/client/async_client.py b/src/dumdum/client/async_client.py index c440c38..4e09be1 100644 --- a/src/dumdum/client/async_client.py +++ b/src/dumdum/client/async_client.py @@ -37,9 +37,13 @@ def __init__( nick: str, *, event_callback: Callable[[ClientEvent], Any], + drain_timeout: float = 30, + close_timeout: float = 5, ) -> None: self.nick = nick self.event_callback = event_callback + self.drain_timeout = drain_timeout + self.close_timeout = close_timeout self._protocol = Client(nick) self._reader = None @@ -92,8 +96,7 @@ async def close(self) -> None: # Any exceptions here will be repeated in _read_loop() self._writer.close() - with contextlib.suppress(Exception): - await self._writer.wait_closed() + await self._wait_closed() async def send_message(self, channel_name: str, content: str) -> None: data = self._protocol.send_message(channel_name, content) @@ -134,7 +137,7 @@ async def _read_loop( events, outgoing = self._protocol.receive_bytes(data) writer.write(outgoing) await self._handle_events(events) - await writer.drain() # exert backpressure + await self._drain() # exert backpressure async def _handshake(self) -> bool | None: assert self._writer is not None @@ -185,4 +188,14 @@ def _dispatch_event(self, event: ClientEvent) -> None: async def _send_and_drain(self, data: bytes) -> None: assert self._writer is not None self._writer.write(data) - await self._writer.drain() + await self._drain() + + async def _drain(self) -> None: + assert self._writer is not None + await asyncio.wait_for(self._writer.drain(), timeout=self.drain_timeout) + + async def _wait_closed(self) -> None: + assert self._writer is not None + timeout = self.close_timeout + with contextlib.suppress(Exception): + await asyncio.wait_for(self._writer.wait_closed(), timeout=timeout) diff --git a/src/dumdum/server/connection.py b/src/dumdum/server/connection.py index c4c5e28..189c275 100644 --- a/src/dumdum/server/connection.py +++ b/src/dumdum/server/connection.py @@ -35,7 +35,11 @@ async def communicate(self) -> None: events, outgoing = self.server.receive_bytes(data) self.writer.write(outgoing) await self._handle_events(events) - await self.writer.drain() # exert backpressure + await self._drain() # exert backpressure async def _handle_events(self, events: list[ServerEvent]) -> None: await self.manager._handle_events(self, events) + + async def _drain(self) -> None: + timeout = self.manager.drain_timeout + await asyncio.wait_for(self.writer.drain(), timeout=timeout) diff --git a/src/dumdum/server/manager.py b/src/dumdum/server/manager.py index 40a3672..0a5743c 100644 --- a/src/dumdum/server/manager.py +++ b/src/dumdum/server/manager.py @@ -23,10 +23,19 @@ class Manager: - def __init__(self, state: ServerState, ssl: ssl.SSLContext | None) -> None: + def __init__( + self, + state: ServerState, + ssl: ssl.SSLContext | None, + *, + drain_timeout: float = 30, + close_timeout: float = 5, + ) -> None: self.state = state self.connections: list[Connection] = [] self.ssl = ssl + self.drain_timeout = drain_timeout + self.close_timeout = close_timeout async def accept_connection( self, @@ -52,14 +61,18 @@ async def accept_connection( log.info("Connection %s has disconnected", addr) writer.close() - with contextlib.suppress(Exception): - await writer.wait_closed() + await self._wait_closed(writer) self._close_connection(connection) def _create_server(self) -> Server: return Server() + async def _wait_closed(self, writer: asyncio.StreamWriter) -> None: + timeout = self.close_timeout + with contextlib.suppress(Exception): + await asyncio.wait_for(writer.wait_closed(), timeout=timeout) + async def _handle_events(self, conn: Connection, events: list[ServerEvent]) -> None: for event in events: await self._handle_event(conn, event)