From e8cf31de6f06f09a2679b0a945e598d40b6f2909 Mon Sep 17 00:00:00 2001 From: Birnendampf <93873756+Birnendampf@users.noreply.github.com> Date: Thu, 21 Nov 2024 22:05:53 +0100 Subject: [PATCH 1/2] implement tasks set --- asyncssh/connection.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/asyncssh/connection.py b/asyncssh/connection.py index f7ee231..1375c2b 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -926,6 +926,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._close_event = asyncio.Event() + self._tasks: Set[asyncio.Task[None]] = set() + self._server_host_key_algs: Optional[Sequence[bytes]] = None self._logger = logger.get_child( @@ -1083,6 +1085,7 @@ def _reap_task(self, task_logger: Optional[SSHLogger], task: 'asyncio.Task[None]') -> None: """Collect result of an async task, reporting errors""" + self._tasks.discard(task) # pylint: disable=broad-except try: task.result() @@ -1101,6 +1104,7 @@ def create_task(self, coro: Awaitable[None], task = asyncio.ensure_future(coro) task.add_done_callback(partial(self._reap_task, task_logger)) + self._tasks.add(task) return task def is_client(self) -> bool: @@ -2727,6 +2731,9 @@ def abort(self) -> None: self.logger.info('Aborting connection') + # cancel all running tasks + for task in self._tasks: + task.cancel() self._force_close(None) def close(self) -> None: @@ -2754,6 +2761,7 @@ async def wait_closed(self) -> None: await self._agent.wait_closed() await self._close_event.wait() + await asyncio.gather(*self._tasks, return_exceptions=True) def disconnect(self, code: int, reason: str, lang: str = DEFAULT_LANG) -> None: From dc8ce6c645c20b716e7cfdf58883ad4779dae46f Mon Sep 17 00:00:00 2001 From: Birnendampf <93873756+Birnendampf@users.noreply.github.com> Date: Thu, 21 Nov 2024 23:46:41 +0100 Subject: [PATCH 2/2] fix a bug where cancelling wait_closed would also cancel the tasks. should this be intended? it would mean putting a timeout on wait_closed would also cancel the tasks if it runs out, which seems counterintuitive --- asyncssh/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 1375c2b..d0c698f 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -2761,7 +2761,8 @@ async def wait_closed(self) -> None: await self._agent.wait_closed() await self._close_event.wait() - await asyncio.gather(*self._tasks, return_exceptions=True) + if self._tasks: + await asyncio.wait(self._tasks) def disconnect(self, code: int, reason: str, lang: str = DEFAULT_LANG) -> None: