From 3fde461f944b67c6494004b9b491a7f0bc551b0e Mon Sep 17 00:00:00 2001
From: Russ Allbery <rra@lsst.org>
Date: Fri, 1 Sep 2023 15:40:37 -0700
Subject: [PATCH] Add AsyncMultiQueue data structure

Add an asyncio multiple writer, multiple reader queue, and convert
the Kubernetes mock to use it instead of implementing its own. This
data structure is also used in the JupyterHub REST spawner and in
the Nublado lab controller.
---
 changelog.d/20230905_090034_rra_DM_40638.md |   3 +
 docs/user-guide/asyncio-queue.rst           |  51 +++++++
 docs/user-guide/index.rst                   |   1 +
 src/safir/asyncio.py                        | 140 +++++++++++++++++++-
 src/safir/testing/kubernetes.py             |  53 +++-----
 tests/asyncio_test.py                       |  51 ++++++-
 6 files changed, 257 insertions(+), 42 deletions(-)
 create mode 100644 changelog.d/20230905_090034_rra_DM_40638.md
 create mode 100644 docs/user-guide/asyncio-queue.rst

diff --git a/changelog.d/20230905_090034_rra_DM_40638.md b/changelog.d/20230905_090034_rra_DM_40638.md
new file mode 100644
index 00000000..22d5b627
--- /dev/null
+++ b/changelog.d/20230905_090034_rra_DM_40638.md
@@ -0,0 +1,3 @@
+### New features
+
+- Add new `safir.asyncio.AsyncMultiQueue` data structure, which is an asyncio multi-reader queue that delivers all messages to each reader independently.
diff --git a/docs/user-guide/asyncio-queue.rst b/docs/user-guide/asyncio-queue.rst
new file mode 100644
index 00000000..0f19f9c6
--- /dev/null
+++ b/docs/user-guide/asyncio-queue.rst
@@ -0,0 +1,51 @@
+#######################################
+Using the asyncio multiple-reader queue
+#######################################
+
+`~asyncio.Queue`, provided by the standard library, is an asyncio queue implementation, but it delivers each item in the queue to only one reader.
+In some cases, you may need behavior more like a publish/subscribe queue: multiple readers all see the full contents of the queue, independently.
+
+Safir provides the `~safir.asyncio.AsyncMultiQueue` data structure for this use case.
+Its API is somewhat inspired by that of `~asyncio.Queue`, but it is intended for use as an async iterator rather than by calling a ``get`` method.
+
+The writer should use the queue as follows:
+
+.. code-block:: python
+
+   from safir.asyncio import AsyncMultiQueue
+
+
+   queue = AsyncMultiQueue[str]()
+   queue.put("something")
+   queue.put("else")
+
+   # Calling clear marks the end of the queue. Existing readers receive
+   # any items they have not yet seen, and then their iterators stop.
+   queue.clear()
+
+`~safir.asyncio.AsyncMultiQueue` is generic, so its elements may be of any type.
+Note that the writer interface (``put`` and ``clear``) is fully synchronous.
+
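+Because the writer interface is synchronous, items can be published from ordinary synchronous code, as long as that code runs in the same thread as the event loop (the queue is not thread-safe).
+The following minimal sketch shows a synchronous callback publishing structured items; the ``Progress`` dataclass and ``record_progress`` function are invented for illustration and are not part of Safir:
+
+.. code-block:: python
+
+   from dataclasses import dataclass
+
+   from safir.asyncio import AsyncMultiQueue
+
+
+   @dataclass
+   class Progress:
+       percent: int
+       message: str
+
+
+   queue = AsyncMultiQueue[Progress]()
+
+
+   def record_progress(percent: int, message: str) -> None:
+       """Synchronous callback that publishes a Progress item to the queue."""
+       queue.put(Progress(percent=percent, message=message))
+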
+A typical reader looks like this:
+
+.. code-block:: python
+
+   async for item in queue:
+       await do_something(item)
+
+This iterates over the full contents of the queue until ``clear`` is called by the writer.
+
+Readers can also start at any position and specify a timeout.
+The timeout, if given, is the total length of time the iterator is allowed to run, not the time to wait for the next element.
+
+.. code-block:: python
+
+   from datetime import timedelta
+
+
+   timeout = timedelta(seconds=5)
+   async for item in queue.aiter_from(4, timeout):
+       await do_something(item)
+
+This reader will skip the first four elements of the queue and will raise `TimeoutError` after five seconds of total time spent in the iterator.
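+
+Putting the two halves together, the writer and the readers typically run as separate asyncio tasks.
+The following sketch is illustrative only; the ``consume`` and ``main`` helpers are invented for this example and are not part of Safir:
+
+.. code-block:: python
+
+   import asyncio
+
+   from safir.asyncio import AsyncMultiQueue
+
+
+   async def consume(queue: AsyncMultiQueue[str]) -> list[str]:
+       return [item async for item in queue]
+
+
+   async def main() -> None:
+       queue = AsyncMultiQueue[str]()
+
+       # Start two independent readers. The sleep yields control so that
+       # both readers begin iterating before the queue is cleared below.
+       readers = [asyncio.create_task(consume(queue)) for _ in range(2)]
+       await asyncio.sleep(0)
+
+       # Write some items and then mark the end of the stream.
+       for word in ("one", "two", "three"):
+           queue.put(word)
+       queue.clear()
+
+       # Every reader independently sees the full contents of the queue.
+       assert await asyncio.gather(*readers) == [
+           ["one", "two", "three"],
+           ["one", "two", "three"],
+       ]
+
+
+   asyncio.run(main())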
diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst
index ad6ee052..73c951c3 100644
--- a/docs/user-guide/index.rst
+++ b/docs/user-guide/index.rst
@@ -31,3 +31,4 @@ User guide
    slack-webhook
    github-apps/index
    click
+   asyncio-queue
diff --git a/src/safir/asyncio.py b/src/safir/asyncio.py
index 219d621a..a1078d17 100644
--- a/src/safir/asyncio.py
+++ b/src/safir/asyncio.py
@@ -3,13 +3,147 @@
 from __future__ import annotations
 
 import asyncio
-from collections.abc import Callable, Coroutine
+from collections.abc import AsyncIterator, Callable, Coroutine
+from datetime import timedelta
 from functools import wraps
-from typing import Any, TypeVar
+from types import EllipsisType
+from typing import Any, Generic, TypeVar
 
+from .datetime import current_datetime
+
+#: Type variable of objects being stored in `AsyncMultiQueue`.
 T = TypeVar("T")
 
-__all__ = ["run_with_asyncio"]
+__all__ = [
+    "AsyncMultiQueue",
+    "run_with_asyncio",
+    "T",
+]
+
+
+class AsyncMultiQueue(Generic[T]):
+    """An asyncio multiple reader, multiple writer queue.
+
+    Provides a generic queue for asyncio that supports multiple readers (via
+    async iterator) and multiple writers. Readers may start reading at any
+    time; by default they read from the start of the queue, but `aiter_from`
+    supports starting at a later position. There is no maximum queue size;
+    new items are added subject only to the limits of available memory.
+
+    This data structure is not thread-safe: it uses only asyncio
+    synchronization primitives, so it must be used from a single event loop.
+
+    The ellipsis object (``...``) is used as a placeholder to indicate the end
+    of the queue, so it cannot itself be pushed onto the queue.
+    """
+
+    def __init__(self) -> None:
+        self._contents: list[T | EllipsisType] = []
+        self._triggers: list[asyncio.Event] = []
+
+    def __aiter__(self) -> AsyncIterator[T]:
+        """Return an async iterator over the queue."""
+        return self.aiter_from(0)
+
+    def aiter_from(
+        self, start: int, timeout: timedelta | None = None
+    ) -> AsyncIterator[T]:
+        """Return an async iterator over the queue.
+
+        Each call to this function returns a separate iterator over the same
+        underlying contents, and each iterator will be triggered separately.
+
+        Parameters
+        ----------
+        start
+            Starting position in the queue. This can be larger than the
+            current queue size, in which case no items are returned until the
+            queue passes the given starting position.
+        timeout
+            If given, total length of time for the iterator. This isn't the
+            timeout waiting for the next item; this is the total execution
+            time of the iterator.
+
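+        Returns
+        -------
+        AsyncIterator
+            Async iterator over the contents of the queue, starting at the
+            requested position.
+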
+        Raises
+        ------
+        TimeoutError
+            Raised when the timeout is reached.
+        """
+        if timeout:
+            end_time = current_datetime(microseconds=True) + timeout
+        else:
+            end_time = None
+
+        # Grab a reference to the current contents so that the iterator
+        # detaches from the contents on clear.
+        contents = self._contents
+
+        # Add a trigger for this caller and make sure it's set if there are
+        # any existing contents.
+        trigger = asyncio.Event()
+        if contents:
+            trigger.set()
+        self._triggers.append(trigger)
+
+        # Construct the iterator, which waits for the trigger and yields any
+        # new items until it sees the placeholder for the end of the queue
+        # (the ellipsis object).
+        async def iterator() -> AsyncIterator[T]:
+            position = start
+            try:
+                while True:
+                    trigger.clear()
+                    end = len(contents)
+                    if position < end:
+                        for item in contents[position:end]:
+                            if item is Ellipsis:
+                                return
+                            yield item
+                        position = end
+                    elif contents and contents[-1] is Ellipsis:
+                        return
+                    if end_time:
+                        now = current_datetime(microseconds=True)
+                        timeout_left = (end_time - now).total_seconds()
+                        async with asyncio.timeout(timeout_left):
+                            await trigger.wait()
+                    else:
+                        await trigger.wait()
+            finally:
+                self._triggers = [t for t in self._triggers if t != trigger]
+
+        return iterator()
+
+    def clear(self) -> None:
+        """Empty the contents of the queue.
+
+        Any existing readers will still see all items added to the queue
+        before the clear, but will become detached from the queue and will
+        not see any items added after the clear.
+        """
+        contents = self._contents
+        triggers = self._triggers
+        self._contents = []
+        self._triggers = []
+        contents.append(Ellipsis)
+        for trigger in triggers:
+            trigger.set()
+
+    def put(self, item: T) -> None:
+        """Add an item to the queue.
+
+        Parameters
+        ----------
+        item
+            Item to add.
+        """
+        self._contents.append(item)
+        for trigger in self._triggers:
+            trigger.set()
+
+    def qsize(self) -> int:
+        """Return the number of items currently in the queue."""
+        return len(self._contents)
 
 
 def run_with_asyncio(
diff --git a/src/safir/testing/kubernetes.py b/src/safir/testing/kubernetes.py
index 5346e66d..daf91dbf 100644
--- a/src/safir/testing/kubernetes.py
+++ b/src/safir/testing/kubernetes.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import asyncio
 import copy
 import json
 import os
@@ -45,7 +44,7 @@
     V1Status,
 )
 
-from ..datetime import current_datetime
+from ..asyncio import AsyncMultiQueue
 
 __all__ = [
     "MockKubernetesApi",
@@ -166,14 +165,14 @@ class _EventStream:
     a stream of watchable events and a list of `asyncio.Event` triggers. A
     watch can register interest in this event stream, in which case its
     trigger will be notified when anything new is added to the event stream.
+
     The events are generic dicts, which will be interpreted by the Kubernetes
     library differently depending on which underlying API is using this data
     structure.
     """
 
     def __init__(self) -> None:
-        self._events: list[dict[str, Any]] = []
-        self._triggers: list[asyncio.Event] = []
+        self._queue = AsyncMultiQueue[dict[str, Any]]()
 
     @property
     def next_resource_version(self) -> str:
@@ -183,7 +182,7 @@ def next_resource_version(self) -> str:
         special and means to return all known events, so it must be adjusted
         when indexing into a list of events.
         """
-        return str(len(self._events) + 1)
+        return str(self._queue.qsize() + 1)
 
     def add_event(self, event: dict[str, Any]) -> None:
         """Add a new event and notify all watchers.
@@ -193,9 +192,7 @@ def add_event(self, event: dict[str, Any]) -> None:
         event
             New event.
         """
-        self._events.append(event)
-        for trigger in self._triggers:
-            trigger.set()
+        self._queue.put(event)
 
     def build_watch_response(
         self,
@@ -247,7 +244,7 @@ async def readline() -> bytes:
         response.content.readline.side_effect = readline
         return response
 
-    def _build_watcher(  # noqa: C901
+    def _build_watcher(
         self,
         resource_version: str | None,
         timeout_seconds: int | None,
@@ -287,7 +284,7 @@ def _build_watcher(  # noqa: C901
         """
         timeout = None
         if timeout_seconds is not None:
-            timeout = current_datetime() + timedelta(seconds=timeout_seconds)
+            timeout = timedelta(seconds=timeout_seconds)
 
         # Parse the field selector, if one was provided.
         name = None
@@ -297,42 +294,22 @@ def _build_watcher(  # noqa: C901
             assert match.group(1)
             name = match.group(1)
 
-        # Create and register a new trigger.
-        trigger = asyncio.Event()
-        self._triggers.append(trigger)
-
         # Construct the iterator.
         async def next_event() -> AsyncIterator[bytes]:
             if resource_version:
-                position = int(resource_version)
+                start = int(resource_version)
             else:
-                position = len(self._events)
-            while True:
-                for event in self._events[position:]:
-                    position += 1
+                start = self._queue.qsize()
+            try:
+                async for event in self._queue.aiter_from(start, timeout):
                     if name and event["object"]["metadata"]["name"] != name:
                         continue
-                    if not _check_labels(
-                        event["object"]["metadata"]["labels"], label_selector
-                    ):
+                    labels = event["object"]["metadata"]["labels"]
+                    if not _check_labels(labels, label_selector):
                         continue
                     yield json.dumps(event).encode()
-                if not timeout:
-                    await trigger.wait()
-                else:
-                    now = current_datetime()
-                    timeout_left = (timeout - now).total_seconds()
-                    if timeout_left <= 0:
-                        yield b""
-                        break
-                    try:
-                        async with asyncio.timeout(timeout_left):
-                            await trigger.wait()
-                    except TimeoutError:
-                        yield b""
-                        break
-                trigger.clear()
-            self._triggers = [t for t in self._triggers if t != trigger]
+            except TimeoutError:
+                yield b""
 
         # Return the iterator.
         return next_event()
diff --git a/tests/asyncio_test.py b/tests/asyncio_test.py
index 118898e1..9e5783e0 100644
--- a/tests/asyncio_test.py
+++ b/tests/asyncio_test.py
@@ -2,7 +2,56 @@
 
 from __future__ import annotations
 
-from safir.asyncio import run_with_asyncio
+import asyncio
+
+import pytest
+
+from safir.asyncio import AsyncMultiQueue, run_with_asyncio
+
+
+@pytest.mark.asyncio
+async def test_async_multi_queue() -> None:
+    queue = AsyncMultiQueue[str]()
+    queue.put("one")
+    queue.put("two")
+    assert queue.qsize() == 2
+
+    async def watcher(position: int) -> list[str]:
+        result = []
+        async for string in queue.aiter_from(position):
+            result.append(string)  # noqa: PERF402
+        return result
+
+    async def watcher_iter() -> list[str]:
+        result = []
+        async for string in queue:
+            result.append(string)  # noqa: PERF402
+        return result
+
+    start_task = asyncio.create_task(watcher_iter())
+    one_task = asyncio.create_task(watcher(1))
+    current_task = asyncio.create_task(watcher(queue.qsize()))
+    future_task = asyncio.create_task(watcher(3))
+    way_future_task = asyncio.create_task(watcher(10))
+
+    await asyncio.sleep(0.1)
+    queue.put("three")
+    await asyncio.sleep(0.1)
+    queue.put("four")
+    await asyncio.sleep(0.1)
+    queue.clear()
+    assert queue.qsize() == 0
+
+    after_clear_task = asyncio.create_task(watcher_iter())
+    await asyncio.sleep(0.1)
+    queue.clear()
+
+    assert await start_task == ["one", "two", "three", "four"]
+    assert await one_task == ["two", "three", "four"]
+    assert await current_task == ["three", "four"]
+    assert await future_task == ["four"]
+    assert await way_future_task == []
+    assert await after_clear_task == []
 
 
 def test_run_with_asyncio() -> None: