Skip to content

Commit

Permalink
fix typing of track_publications using covariant Mapping (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Oct 9, 2024
1 parent 135edaa commit c1111ac
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
33 changes: 24 additions & 9 deletions livekit-rtc/livekit/rtc/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from __future__ import annotations

import ctypes
from typing import List, Union
from typing import List, Mapping, Union
from abc import abstractmethod, ABC

from ._ffi_client import FfiClient, FfiHandle
from ._proto import ffi_pb2 as proto_ffi
Expand Down Expand Up @@ -61,18 +62,18 @@ def __init__(self, message: str) -> None:
self.message = message


class Participant:
class Participant(ABC):
def __init__(self, owned_info: proto_participant.OwnedParticipant) -> None:
self._info = owned_info.info
self._ffi_handle = FfiHandle(owned_info.handle.id)
self._track_publications: dict[str, TrackPublication] = {}

@property
def track_publications(self) -> dict[str, TrackPublication]:
@abstractmethod
def track_publications(self) -> Mapping[str, TrackPublication]:
"""
A dictionary of track publications associated with the participant.
"""
return self._track_publications
...

@property
def sid(self) -> str:
Expand Down Expand Up @@ -111,7 +112,14 @@ def __init__(
) -> None:
super().__init__(owned_info)
self._room_queue = room_queue
self.track_publications: dict[str, LocalTrackPublication] = {} # type: ignore
self._track_publications: dict[str, LocalTrackPublication] = {} # type: ignore

@property
def track_publications(self) -> Mapping[str, LocalTrackPublication]:
"""
A dictionary of track publications associated with the participant.
"""
return self._track_publications

async def publish_data(
self,
Expand Down Expand Up @@ -330,7 +338,7 @@ async def publish_track(
track_publication = LocalTrackPublication(cb.publish_track.publication)
track_publication.track = track
track._info.sid = track_publication.sid
self.track_publications[track_publication.sid] = track_publication
self._track_publications[track_publication.sid] = track_publication

queue.task_done()
return track_publication
Expand Down Expand Up @@ -361,7 +369,7 @@ async def unpublish_track(self, track_sid: str) -> None:
if cb.unpublish_track.error:
raise UnpublishTrackError(cb.unpublish_track.error)

publication = self.track_publications.pop(track_sid)
publication = self._track_publications.pop(track_sid)
publication.track = None
queue.task_done()
finally:
Expand All @@ -374,7 +382,14 @@ def __repr__(self) -> str:
class RemoteParticipant(Participant):
def __init__(self, owned_info: proto_participant.OwnedParticipant) -> None:
super().__init__(owned_info)
self.track_publications: dict[str, RemoteTrackPublication] = {} # type: ignore
self._track_publications: dict[str, RemoteTrackPublication] = {} # type: ignore

@property
def track_publications(self) -> Mapping[str, RemoteTrackPublication]:
"""
A dictionary of track publications associated with the participant.
"""
return self._track_publications

def __repr__(self) -> str:
return f"rtc.RemoteParticipant(sid={self.sid}, identity={self.identity}, name={self.name})"
10 changes: 5 additions & 5 deletions livekit-rtc/livekit/rtc/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import ctypes
import logging
from dataclasses import dataclass, field
from typing import Callable, Dict, Literal, Optional, cast
from typing import Callable, Dict, Literal, Optional, cast, Mapping

from .event_emitter import EventEmitter
from ._ffi_client import FfiClient, FfiHandle
Expand Down Expand Up @@ -174,7 +174,7 @@ def connection_state(self) -> ConnectionState.ValueType:
return self._connection_state

@property
def remote_participants(self) -> dict[str, RemoteParticipant]:
def remote_participants(self) -> Mapping[str, RemoteParticipant]:
"""Gets the remote participants in the room.
Returns:
Expand Down Expand Up @@ -389,7 +389,7 @@ def on_participant_connected(participant):
# add the initial remote participant tracks
for owned_publication_info in pt.publications:
publication = RemoteTrackPublication(owned_publication_info)
rp.track_publications[publication.sid] = publication
rp._track_publications[publication.sid] = publication

# start listening to room events
self._task = self._loop.create_task(self._listen_task())
Expand Down Expand Up @@ -466,13 +466,13 @@ def _on_room_event(self, event: proto_room.RoomEvent):
event.track_published.participant_identity
]
rpublication = RemoteTrackPublication(event.track_published.publication)
rparticipant.track_publications[rpublication.sid] = rpublication
rparticipant._track_publications[rpublication.sid] = rpublication
self.emit("track_published", rpublication, rparticipant)
elif which == "track_unpublished":
rparticipant = self._remote_participants[
event.track_unpublished.participant_identity
]
rpublication = rparticipant.track_publications.pop(
rpublication = rparticipant._track_publications.pop(
event.track_unpublished.publication_sid
)
self.emit("track_unpublished", rpublication, rparticipant)
Expand Down

0 comments on commit c1111ac

Please sign in to comment.