Skip to content

Commit

Permalink
fix: potential dead lock during client connection
Browse files Browse the repository at this point in the history
Connection closure now always sets result of _auth_fut,
preventing attempt_connection() from hanging indefinitely.
  • Loading branch information
thegamecracks committed Mar 23, 2024
1 parent 278a771 commit eb7a397
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
9 changes: 2 additions & 7 deletions src/dumdum/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,13 @@ def submit(self, coro: Awaitable[Any]) -> concurrent.futures.Future:
fut.add_done_callback(log_fut_exception)
return fut

async def attempt_connection(self, host: str, port: int, nick: str) -> bool:
async def attempt_connection(self, host: str, port: int, nick: str) -> None:
self.client = AsyncClient(nick, event_callback=self._handle_event_threadsafe)

coro = self._run_connection(host, port)
self._connection_task = asyncio.create_task(coro)

async with asyncio.TaskGroup() as tg:
auth_task = tg.create_task(self.client._wait_for_authentication())
tasks = [self._connection_task, auth_task]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

return await auth_task
await self.client._wait_for_authentication()

def _connect_lifetime_with_event_thread(self, event_thread: EventThread) -> None:
# In our application we'll be running an asyncio event loop in
Expand Down
35 changes: 26 additions & 9 deletions src/dumdum/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
from .errors import AuthenticationFailedError


def maybe_create_fut(last: asyncio.Future | None) -> asyncio.Future:
if last is None:
return asyncio.get_running_loop().create_future()
return last


class AsyncClient:
_reader: asyncio.StreamReader | None
_writer: asyncio.StreamWriter | None
_read_task: asyncio.Task | None
_addr: str | None
_auth_fut: asyncio.Future[bool] | None
_auth_fut: asyncio.Future[bool | None] | None

def __init__(
self,
Expand Down Expand Up @@ -45,10 +51,18 @@ def addr(self) -> str:
@contextlib.asynccontextmanager
async def connect(self, host: str, port: int) -> AsyncIterator[Self]:
self._addr = f"{host}:{port}" # FIXME: must be canonicalized
self._reader, self._writer = await asyncio.open_connection(host, port)
self._auth_fut = maybe_create_fut(self._auth_fut)
try:
connector = asyncio.open_connection(host, port)
self._reader, self._writer = await connector
except BaseException:
self._set_authentication(None)
raise

async with asyncio.TaskGroup() as tg:
_read_coro = self._read_loop(self._reader, self._writer)
self._read_task = tg.create_task(_read_coro)
self._read_task.add_done_callback(self._on_read_task_done)

try:
success = await self._handshake()
Expand Down Expand Up @@ -101,16 +115,18 @@ async def _read_loop(
self._handle_events(events)
await writer.drain() # exert backpressure

async def _handshake(self) -> bool:
def _on_read_task_done(self, task: asyncio.Task) -> None:
self._set_authentication(None)

async def _handshake(self) -> bool | None:
assert self._writer is not None
data = self._protocol.authenticate()
self._writer.write(data)
return await self._wait_for_authentication()

async def _wait_for_authentication(self) -> bool:
if self._auth_fut is None:
self._auth_fut = asyncio.get_running_loop().create_future()
return await self._auth_fut
async def _wait_for_authentication(self) -> bool | None:
self._auth_fut = maybe_create_fut(self._auth_fut)
return await asyncio.shield(self._auth_fut)

def _handle_events(self, events: list[ClientEvent]) -> None:
for event in events:
Expand All @@ -124,9 +140,10 @@ def _handle_event(self, event: ClientEvent) -> None:
self._set_authentication(event.success)
self._dispatch_event(event)

def _set_authentication(self, success: bool) -> None:
def _set_authentication(self, result: bool | None) -> None:
assert self._auth_fut is not None
self._auth_fut.set_result(success)
if not self._auth_fut.done():
self._auth_fut.set_result(result)

def _dispatch_event(self, event: ClientEvent) -> None:
self.event_callback(event)
Expand Down

0 comments on commit eb7a397

Please sign in to comment.