From cebfe802ad3e6c760a3b969c81ffa195fb3f46bd Mon Sep 17 00:00:00 2001 From: Ben Cherry Date: Thu, 21 Nov 2024 16:18:34 -0800 Subject: [PATCH] Raise error if EventEmitter used with async callback (#312) --- livekit-rtc/livekit/rtc/event_emitter.py | 6 ++++++ livekit-rtc/tests/test_emitter.py | 19 ++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/livekit-rtc/livekit/rtc/event_emitter.py b/livekit-rtc/livekit/rtc/event_emitter.py index 53e8c67b..7e62aca0 100644 --- a/livekit-rtc/livekit/rtc/event_emitter.py +++ b/livekit-rtc/livekit/rtc/event_emitter.py @@ -1,4 +1,5 @@ import inspect +import asyncio from typing import Callable, Dict, Set, Optional, Generic, TypeVar from .log import logger @@ -156,6 +157,11 @@ def greet(name): ``` """ if callback is not None: + if asyncio.iscoroutinefunction(callback): + raise ValueError( + "Cannot register an async callback with `.on()`. Use `asyncio.create_task` within your synchronous callback instead." + ) + if event not in self._events: self._events[event] = set() self._events[event].add(callback) diff --git a/livekit-rtc/tests/test_emitter.py b/livekit-rtc/tests/test_emitter.py index f626e8c7..830feb3a 100644 --- a/livekit-rtc/tests/test_emitter.py +++ b/livekit-rtc/tests/test_emitter.py @@ -60,26 +60,27 @@ def on_whatever(first, second, third): emitter.emit("whatever", 1, 2, 3) emitter.emit("whatever", 1, 2, 3, 4, 5) # only 3 arguments will be passed - assert len(calls) == 2 - assert calls[0] == (1, 2, 3) - assert calls[1] == (1, 2, 3) - - calls = [] + assert calls == [(1, 2, 3), (1, 2, 3)] with pytest.raises(TypeError): emitter.emit("whatever", 1, 2) - assert len(calls) == 0 + +def test_varargs(): + EventTypes = Literal["whatever"] + + emitter = EventEmitter[EventTypes]() + + calls = [] @emitter.on("whatever") def on_whatever_varargs(*args): calls.append(args) emitter.emit("whatever", 1, 2, 3, 4, 5) + emitter.emit("whatever", 1, 2) - assert len(calls) == 2 - assert calls[0] == (1, 2, 3) - assert calls[1] == (1, 2, 3, 4, 5) + assert calls == [(1, 2, 3, 4, 5), (1, 2)] def test_throw():