Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ fix redis stream empty subscribe bug && supported redis stream cached (#148) #149

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ test.db
venv/
build/
dist/
.idea/
.vscode/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Python 3.8+
* `Broadcast('memory://')`
* `Broadcast("redis://localhost:6379")`
* `Broadcast("redis-stream://localhost:6379")`
* `Broadcast("redis-stream-cached://localhost:6379")`
* `Broadcast("postgres://localhost:5432/broadcaster")`
* `Broadcast("kafka://localhost:9092")`

Expand Down
3 changes: 2 additions & 1 deletion broadcaster/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._base import Broadcast, Event
from ._base import Broadcast
from ._event import Event
from .backends.base import BroadcastBackend

__version__ = "0.3.1"
Expand Down
37 changes: 23 additions & 14 deletions broadcaster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,12 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast
from urllib.parse import urlparse

if TYPE_CHECKING: # pragma: no cover
from broadcaster.backends.base import BroadcastBackend


class Event:
def __init__(self, channel: str, message: str) -> None:
self.channel = channel
self.message = message
from broadcaster.backends.base import BroadcastCacheBackend

def __eq__(self, other: object) -> bool:
return isinstance(other, Event) and self.channel == other.channel and self.message == other.message
from ._event import Event

def __repr__(self) -> str:
return f"Event(channel={self.channel!r}, message={self.message!r})"
if TYPE_CHECKING: # pragma: no cover
from broadcaster.backends.base import BroadcastBackend


class Unsubscribed(Exception):
Expand All @@ -43,6 +35,11 @@ def _create_backend(self, url: str) -> BroadcastBackend:

return RedisStreamBackend(url)

elif parsed_url.scheme == "redis-stream-cached":
from broadcaster.backends.redis import RedisStreamCachedBackend

return RedisStreamCachedBackend(url)

elif parsed_url.scheme in ("postgres", "postgresql"):
from broadcaster.backends.postgres import PostgresBackend

Expand Down Expand Up @@ -87,15 +84,27 @@ async def publish(self, channel: str, message: Any) -> None:
await self._backend.publish(channel, message)

@asynccontextmanager
async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
async def subscribe(self, channel: str, history: int | None = None) -> AsyncIterator[Subscriber]:
queue: asyncio.Queue[Event | None] = asyncio.Queue()

try:
if not self._subscribers.get(channel):
await self._backend.subscribe(channel)
self._subscribers[channel] = {queue}
else:
self._subscribers[channel].add(queue)
if isinstance(self._backend, BroadcastCacheBackend):
try:
current_id = await self._backend.get_current_channel_id(channel)
self._backend._ready.clear()
for message in await self._backend.get_history_messages(channel, current_id, history):
queue.put_nowait(message)
self._subscribers[channel].add(queue)
finally:
# wake up the listener after inqueue history messages
# for sorted messages by publish time
self._backend._ready.set()
else:
self._subscribers[channel].add(queue)

yield Subscriber(queue)
finally:
Expand Down
10 changes: 10 additions & 0 deletions broadcaster/_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Event:
def __init__(self, channel: str, message: str) -> None:
self.channel = channel
self.message = message

def __eq__(self, other: object) -> bool:
return isinstance(other, Event) and self.channel == other.channel and self.message == other.message

def __repr__(self) -> str:
return f"Event(channel={self.channel!r}, message={self.message!r})"
20 changes: 19 additions & 1 deletion broadcaster/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import asyncio
from typing import Any

from .._base import Event
from .._event import Event


class BroadcastBackend:
Expand All @@ -24,3 +27,18 @@ async def publish(self, channel: str, message: Any) -> None:

async def next_published(self) -> Event:
raise NotImplementedError()


class BroadcastCacheBackend(BroadcastBackend):
_ready: asyncio.Event

async def get_current_channel_id(self, channel: str) -> str | bytes | memoryview | int:
raise NotImplementedError()

async def get_history_messages(
self,
channel: str,
msg_id: int | bytes | str | memoryview,
count: int | None = None,
) -> list[Event]:
raise NotImplementedError()
2 changes: 1 addition & 1 deletion broadcaster/backends/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aiokafka import AIOKafkaConsumer, AIOKafkaProducer

from .._base import Event
from .._event import Event
from .base import BroadcastBackend


Expand Down
2 changes: 1 addition & 1 deletion broadcaster/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import typing

from .._base import Event
from .._event import Event
from .base import BroadcastBackend


Expand Down
2 changes: 1 addition & 1 deletion broadcaster/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import asyncpg

from .._base import Event
from .._event import Event
from .base import BroadcastBackend


Expand Down
86 changes: 83 additions & 3 deletions broadcaster/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from redis import asyncio as redis

from .._base import Event
from .base import BroadcastBackend
from .._event import Event
from .base import BroadcastBackend, BroadcastCacheBackend


class RedisBackend(BroadcastBackend):
Expand Down Expand Up @@ -88,14 +88,20 @@ async def subscribe(self, channel: str) -> None:

async def unsubscribe(self, channel: str) -> None:
self.streams.pop(channel, None)
if not self.streams:
self._ready.clear()

async def publish(self, channel: str, message: typing.Any) -> None:
await self._producer.xadd(channel, {"message": message})

async def wait_for_messages(self) -> list[StreamMessageType]:
await self._ready.wait()
messages = None
while not messages:
if not self.streams:
# 1. save cpu usage
# 2. redis raise expection when self.streams is empty
self._ready.clear()
await self._ready.wait()
messages = await self._consumer.xread(self.streams, count=1, block=100)
return messages

Expand All @@ -108,3 +114,77 @@ async def next_published(self) -> Event:
channel=stream.decode("utf-8"),
message=message.get(b"message", b"").decode("utf-8"),
)


class RedisStreamCachedBackend(BroadcastCacheBackend):
def __init__(self, url: str):
url = url.replace("redis-stream-cached", "redis", 1)
self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {}
self._ready = asyncio.Event()
self._producer = redis.Redis.from_url(url)
self._consumer = redis.Redis.from_url(url)

async def connect(self) -> None:
pass

async def disconnect(self) -> None:
await self._producer.aclose()
await self._consumer.aclose()

async def subscribe(self, channel: str) -> None:
# read from beginning
last_id = "0"
self.streams[channel] = last_id
self._ready.set()

async def unsubscribe(self, channel: str) -> None:
self.streams.pop(channel, None)
if not self.streams:
self._ready.clear()

async def publish(self, channel: str, message: typing.Any) -> None:
await self._producer.xadd(channel, {"message": message})

async def wait_for_messages(self) -> list[StreamMessageType]:
messages = None
while not messages:
if not self.streams:
# 1. save cpu usage
# 2. redis raise expection when self.streams is empty
self._ready.clear()
await self._ready.wait()
messages = await self._consumer.xread(self.streams, count=1, block=100)
return messages

async def next_published(self) -> Event:
messages = await self.wait_for_messages()
stream, events = messages[0]
_msg_id, message = events[0]
self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8")
return Event(
channel=stream.decode("utf-8"),
message=message.get(b"message", b"").decode("utf-8"),
)

async def get_current_channel_id(self, channel: str) -> int | bytes | str | memoryview:
try:
info = await self._consumer.xinfo_stream(channel)
last_id: int | bytes | str | memoryview = info["last-generated-id"]
except redis.ResponseError:
last_id = "0"
return last_id

async def get_history_messages(
self,
channel: str,
msg_id: int | bytes | str | memoryview,
count: int | None = None,
) -> list[Event]:
messages = await self._consumer.xrevrange(channel, max=msg_id, count=count)
return [
Event(
channel=channel,
message=message.get(b"message", b"").decode("utf-8"),
)
for _, message in reversed(messages or [])
]
4 changes: 3 additions & 1 deletion example/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,7 @@ async def chatroom_ws_sender(websocket):


app = Starlette(
routes=routes, on_startup=[broadcast.connect], on_shutdown=[broadcast.disconnect],
routes=routes,
on_startup=[broadcast.connect],
on_shutdown=[broadcast.disconnect],
)
25 changes: 25 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ async def test_redis_stream():
assert event.message == "hello"


@pytest.mark.asyncio
async def test_redis_stream_cache():
messages = ["hello", "I'm cached"]
async with Broadcast("redis-stream-cached://localhost:6379") as broadcast:
await broadcast.publish("chatroom_cached", messages[0])
await broadcast.publish("chatroom_cached", messages[1])
await broadcast.publish("chatroom_cached", "quit")
sub1_messages = []
async with broadcast.subscribe("chatroom_cached") as subscriber:
async for event in subscriber:
if event:
if event.message == "quit":
break
sub1_messages.append(event.message)
sub2_messages = []
async with broadcast.subscribe("chatroom_cached") as subscriber:
async for event in subscriber:
if event:
if event.message == "quit":
break
sub2_messages.append(event.message)

assert sub1_messages == sub2_messages == messages


@pytest.mark.asyncio
async def test_postgres():
async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast:
Expand Down
Loading