diff --git a/src/dumdum/client/app.py b/src/dumdum/client/app.py index 55fc9f4..6f10d70 100644 --- a/src/dumdum/client/app.py +++ b/src/dumdum/client/app.py @@ -67,18 +67,13 @@ def submit(self, coro: Awaitable[Any]) -> concurrent.futures.Future: fut.add_done_callback(log_fut_exception) return fut - async def attempt_connection(self, host: str, port: int, nick: str) -> bool: + async def attempt_connection(self, host: str, port: int, nick: str) -> None: self.client = AsyncClient(nick, event_callback=self._handle_event_threadsafe) coro = self._run_connection(host, port) self._connection_task = asyncio.create_task(coro) - async with asyncio.TaskGroup() as tg: - auth_task = tg.create_task(self.client._wait_for_authentication()) - tasks = [self._connection_task, auth_task] - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - - return await auth_task + await self.client._wait_for_authentication() def _connect_lifetime_with_event_thread(self, event_thread: EventThread) -> None: # In our application we'll be running an asyncio event loop in diff --git a/src/dumdum/client/async_client.py b/src/dumdum/client/async_client.py index 50c5bfa..8ce135a 100644 --- a/src/dumdum/client/async_client.py +++ b/src/dumdum/client/async_client.py @@ -12,12 +12,18 @@ from .errors import AuthenticationFailedError +def maybe_create_fut(last: asyncio.Future | None) -> asyncio.Future: + if last is None: + return asyncio.get_running_loop().create_future() + return last + + class AsyncClient: _reader: asyncio.StreamReader | None _writer: asyncio.StreamWriter | None _read_task: asyncio.Task | None _addr: str | None - _auth_fut: asyncio.Future[bool] | None + _auth_fut: asyncio.Future[bool | None] | None def __init__( self, @@ -45,10 +51,18 @@ def addr(self) -> str: @contextlib.asynccontextmanager async def connect(self, host: str, port: int) -> AsyncIterator[Self]: self._addr = f"{host}:{port}" # FIXME: must be canonicalized - self._reader, self._writer = await asyncio.open_connection(host, port) + self._auth_fut = maybe_create_fut(self._auth_fut) + try: + connector = asyncio.open_connection(host, port) + self._reader, self._writer = await connector + except BaseException: + self._set_authentication(None) + raise + async with asyncio.TaskGroup() as tg: _read_coro = self._read_loop(self._reader, self._writer) self._read_task = tg.create_task(_read_coro) + self._read_task.add_done_callback(self._on_read_task_done) try: success = await self._handshake() @@ -101,16 +115,18 @@ async def _read_loop( self._handle_events(events) await writer.drain() # exert backpressure - async def _handshake(self) -> bool: + def _on_read_task_done(self, task: asyncio.Task) -> None: + self._set_authentication(None) + + async def _handshake(self) -> bool | None: assert self._writer is not None data = self._protocol.authenticate() self._writer.write(data) return await self._wait_for_authentication() - async def _wait_for_authentication(self) -> bool: - if self._auth_fut is None: - self._auth_fut = asyncio.get_running_loop().create_future() - return await self._auth_fut + async def _wait_for_authentication(self) -> bool | None: + self._auth_fut = maybe_create_fut(self._auth_fut) + return await asyncio.shield(self._auth_fut) def _handle_events(self, events: list[ClientEvent]) -> None: for event in events: @@ -124,9 +140,10 @@ def _handle_event(self, event: ClientEvent) -> None: self._set_authentication(event.success) self._dispatch_event(event) - def _set_authentication(self, success: bool) -> None: + def _set_authentication(self, result: bool | None) -> None: assert self._auth_fut is not None - self._auth_fut.set_result(success) + if not self._auth_fut.done(): + self._auth_fut.set_result(result) def _dispatch_event(self, event: ClientEvent) -> None: self.event_callback(event)