From c1111ac36a74da459f4909ff83c0ec4cee668ecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Wed, 9 Oct 2024 16:58:17 -0700 Subject: [PATCH] fix typing of track_publications using covariant Mapping (#287) --- livekit-rtc/livekit/rtc/participant.py | 33 +++++++++++++++++++------- livekit-rtc/livekit/rtc/room.py | 10 ++++---- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index cca7c02c..0314a39c 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -15,7 +15,8 @@ from __future__ import annotations import ctypes -from typing import List, Union +from typing import List, Mapping, Union +from abc import abstractmethod, ABC from ._ffi_client import FfiClient, FfiHandle from ._proto import ffi_pb2 as proto_ffi @@ -61,18 +62,18 @@ def __init__(self, message: str) -> None: self.message = message -class Participant: +class Participant(ABC): def __init__(self, owned_info: proto_participant.OwnedParticipant) -> None: self._info = owned_info.info self._ffi_handle = FfiHandle(owned_info.handle.id) - self._track_publications: dict[str, TrackPublication] = {} @property - def track_publications(self) -> dict[str, TrackPublication]: + @abstractmethod + def track_publications(self) -> Mapping[str, TrackPublication]: """ A dictionary of track publications associated with the participant. """ - return self._track_publications + ... @property def sid(self) -> str: @@ -111,7 +112,14 @@ def __init__( ) -> None: super().__init__(owned_info) self._room_queue = room_queue - self.track_publications: dict[str, LocalTrackPublication] = {} # type: ignore + self._track_publications: dict[str, LocalTrackPublication] = {} # type: ignore + + @property + def track_publications(self) -> Mapping[str, LocalTrackPublication]: + """ + A dictionary of track publications associated with the participant. + """ + return self._track_publications async def publish_data( self, @@ -330,7 +338,7 @@ async def publish_track( track_publication = LocalTrackPublication(cb.publish_track.publication) track_publication.track = track track._info.sid = track_publication.sid - self.track_publications[track_publication.sid] = track_publication + self._track_publications[track_publication.sid] = track_publication queue.task_done() return track_publication @@ -361,7 +369,7 @@ async def unpublish_track(self, track_sid: str) -> None: if cb.unpublish_track.error: raise UnpublishTrackError(cb.unpublish_track.error) - publication = self.track_publications.pop(track_sid) + publication = self._track_publications.pop(track_sid) publication.track = None queue.task_done() finally: @@ -374,7 +382,14 @@ def __repr__(self) -> str: class RemoteParticipant(Participant): def __init__(self, owned_info: proto_participant.OwnedParticipant) -> None: super().__init__(owned_info) - self.track_publications: dict[str, RemoteTrackPublication] = {} # type: ignore + self._track_publications: dict[str, RemoteTrackPublication] = {} # type: ignore + + @property + def track_publications(self) -> Mapping[str, RemoteTrackPublication]: + """ + A dictionary of track publications associated with the participant. + """ + return self._track_publications def __repr__(self) -> str: return f"rtc.RemoteParticipant(sid={self.sid}, identity={self.identity}, name={self.name})" diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index deafb828..8e3e7f51 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -17,7 +17,7 @@ import ctypes import logging from dataclasses import dataclass, field -from typing import Callable, Dict, Literal, Optional, cast +from typing import Callable, Dict, Literal, Optional, cast, Mapping from .event_emitter import EventEmitter from ._ffi_client import FfiClient, FfiHandle @@ -174,7 +174,7 @@ def connection_state(self) -> ConnectionState.ValueType: return self._connection_state @property - def remote_participants(self) -> dict[str, RemoteParticipant]: + def remote_participants(self) -> Mapping[str, RemoteParticipant]: """Gets the remote participants in the room. Returns: @@ -389,7 +389,7 @@ def on_participant_connected(participant): # add the initial remote participant tracks for owned_publication_info in pt.publications: publication = RemoteTrackPublication(owned_publication_info) - rp.track_publications[publication.sid] = publication + rp._track_publications[publication.sid] = publication # start listening to room events self._task = self._loop.create_task(self._listen_task()) @@ -466,13 +466,13 @@ def _on_room_event(self, event: proto_room.RoomEvent): event.track_published.participant_identity ] rpublication = RemoteTrackPublication(event.track_published.publication) - rparticipant.track_publications[rpublication.sid] = rpublication + rparticipant._track_publications[rpublication.sid] = rpublication self.emit("track_published", rpublication, rparticipant) elif which == "track_unpublished": rparticipant = self._remote_participants[ event.track_unpublished.participant_identity ] - rpublication = rparticipant.track_publications.pop( + rpublication = rparticipant._track_publications.pop( event.track_unpublished.publication_sid ) self.emit("track_unpublished", rpublication, rparticipant)