From d37e8ebba3670b888a92931852492329c822ffa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Fri, 27 Sep 2024 21:27:05 -0700 Subject: [PATCH] quality of life improvements (#267) --- livekit-rtc/livekit/rtc/audio_frame.py | 91 ++++++++++++++++++++ livekit-rtc/livekit/rtc/participant.py | 6 ++ livekit-rtc/livekit/rtc/room.py | 7 ++ livekit-rtc/livekit/rtc/track.py | 12 +++ livekit-rtc/livekit/rtc/track_publication.py | 6 ++ livekit-rtc/livekit/rtc/video_frame.py | 90 +++++++++++++++++++ 6 files changed, 212 insertions(+) diff --git a/livekit-rtc/livekit/rtc/audio_frame.py b/livekit-rtc/livekit/rtc/audio_frame.py index da7a2141..e6b8088b 100644 --- a/livekit-rtc/livekit/rtc/audio_frame.py +++ b/livekit-rtc/livekit/rtc/audio_frame.py @@ -21,6 +21,11 @@ class AudioFrame: + """ + A class that represents a frame of audio data with specific properties such as sample rate, + number of channels, and samples per channel. + """ + def __init__( self, data: Union[bytes, bytearray, memoryview], @@ -28,6 +33,19 @@ def __init__( num_channels: int, samples_per_channel: int, ) -> None: + """ + Initialize an AudioFrame instance. + + Args: + data (Union[bytes, bytearray, memoryview]): The raw audio data, which must be at least + `num_channels * samples_per_channel * sizeof(int16)` bytes long. + sample_rate (int): The sample rate of the audio in Hz. + num_channels (int): The number of audio channels (e.g., 1 for mono, 2 for stereo). + samples_per_channel (int): The number of samples per channel. + + Raises: + ValueError: If the length of `data` is smaller than the required size. + """ if len(data) < num_channels * samples_per_channel * ctypes.sizeof( ctypes.c_int16 ): @@ -44,6 +62,18 @@ def __init__( def create( sample_rate: int, num_channels: int, samples_per_channel: int ) -> "AudioFrame": + """ + Create a new empty AudioFrame instance with specified sample rate, number of channels, + and samples per channel. 
+ + Args: + sample_rate (int): The sample rate of the audio in Hz. + num_channels (int): The number of audio channels (e.g., 1 for mono, 2 for stereo). + samples_per_channel (int): The number of samples per channel. + + Returns: + AudioFrame: A new AudioFrame instance with zero-initialized data. + """ size = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16) data = bytearray(size) return AudioFrame(data, sample_rate, num_channels, samples_per_channel) @@ -91,16 +121,77 @@ def _proto_info(self) -> proto_audio.AudioFrameBufferInfo: @property def data(self) -> memoryview: + """ + Returns a memory view of the audio data as 16-bit signed integers. + + Returns: + memoryview: A memory view of the audio data. + """ return memoryview(self._data).cast("h") @property def sample_rate(self) -> int: + """ + Returns the sample rate of the audio frame. + + Returns: + int: The sample rate in Hz. + """ return self._sample_rate @property def num_channels(self) -> int: + """ + Returns the number of channels in the audio frame. + + Returns: + int: The number of audio channels (e.g., 1 for mono, 2 for stereo). + """ return self._num_channels @property def samples_per_channel(self) -> int: + """ + Returns the number of samples per channel. + + Returns: + int: The number of samples per channel. + """ return self._samples_per_channel + + @property + def duration(self) -> float: + """ + Returns the duration of the audio frame in seconds. + + Returns: + float: The duration in seconds. + """ + return self.samples_per_channel / self.sample_rate + + def to_wav_bytes(self) -> bytes: + """ + Convert the audio frame data to a WAV-formatted byte stream. + + Returns: + bytes: The audio data encoded in WAV format. 
+ """ + import wave + import io + + with io.BytesIO() as wav_file: + with wave.open(wav_file, "wb") as wav: + wav.setnchannels(self.num_channels) + wav.setsampwidth(2) + wav.setframerate(self.sample_rate) + wav.writeframes(self._data) + + return wav_file.getvalue() + + def __repr__(self) -> str: + return ( + f"rtc.AudioFrame(sample_rate={self.sample_rate}, " + f"num_channels={self.num_channels}, " + f"samples_per_channel={self.samples_per_channel}, " + f"duration={self.duration:.3f})" + ) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index 3a1c874b..49be9395 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -256,8 +256,14 @@ async def unpublish_track(self, track_sid: str) -> None: finally: self._room_queue.unsubscribe(queue) + def __repr__(self) -> str: + return f"rtc.LocalParticipant(sid={self.sid}, identity={self.identity}, name={self.name})" + class RemoteParticipant(Participant): def __init__(self, owned_info: proto_participant.OwnedParticipant) -> None: super().__init__(owned_info) self.track_publications: dict[str, RemoteTrackPublication] = {} # type: ignore + + def __repr__(self) -> str: + return f"rtc.RemoteParticipant(sid={self.sid}, identity={self.identity}, name={self.name})" diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index d774cb17..82fd2958 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -530,3 +530,10 @@ def _create_remote_participant( participant = RemoteParticipant(owned_info) self.remote_participants[participant.identity] = participant return participant + + def __repr__(self) -> str: + sid = "unknown" + if self._first_sid_future.done(): + sid = self._first_sid_future.result() + + return f"rtc.Room(sid={sid}, name={self.name}, metadata={self.metadata}, connection_state={self.connection_state})" diff --git a/livekit-rtc/livekit/rtc/track.py b/livekit-rtc/livekit/rtc/track.py index 
abfe1adc..9b8fc270 100644 --- a/livekit-rtc/livekit/rtc/track.py +++ b/livekit-rtc/livekit/rtc/track.py @@ -80,6 +80,9 @@ def create_audio_track(name: str, source: "AudioSource") -> "LocalAudioTrack": resp = FfiClient.instance.request(req) return LocalAudioTrack(resp.create_audio_track.track) + def __repr__(self) -> str: + return f"rtc.LocalAudioTrack(sid={self.sid}, name={self.name})" + class LocalVideoTrack(Track): def __init__(self, info: proto_track.OwnedTrack): @@ -94,16 +97,25 @@ def create_video_track(name: str, source: "VideoSource") -> "LocalVideoTrack": resp = FfiClient.instance.request(req) return LocalVideoTrack(resp.create_video_track.track) + def __repr__(self) -> str: + return f"rtc.LocalVideoTrack(sid={self.sid}, name={self.name})" + class RemoteAudioTrack(Track): def __init__(self, info: proto_track.OwnedTrack): super().__init__(info) + def __repr__(self) -> str: + return f"rtc.RemoteAudioTrack(sid={self.sid}, name={self.name})" + class RemoteVideoTrack(Track): def __init__(self, info: proto_track.OwnedTrack): super().__init__(info) + def __repr__(self) -> str: + return f"rtc.RemoteVideoTrack(sid={self.sid}, name={self.name})" + LocalTrack = Union[LocalVideoTrack, LocalAudioTrack] RemoteTrack = Union[RemoteVideoTrack, RemoteAudioTrack] diff --git a/livekit-rtc/livekit/rtc/track_publication.py b/livekit-rtc/livekit/rtc/track_publication.py index 2bbd1a05..86b930f7 100644 --- a/livekit-rtc/livekit/rtc/track_publication.py +++ b/livekit-rtc/livekit/rtc/track_publication.py @@ -77,6 +77,9 @@ def __init__(self, owned_info: proto_track.OwnedTrackPublication): async def wait_for_subscription(self) -> None: await asyncio.shield(self._first_subscription) + def __repr__(self) -> str: + return f"rtc.LocalTrackPublication(sid={self.sid}, name={self.name}, kind={self.kind}, source={self.source})" + class RemoteTrackPublication(TrackPublication): def __init__(self, owned_info: proto_track.OwnedTrackPublication): @@ -88,3 +91,6 @@ def set_subscribed(self, 
subscribed: bool): req.set_subscribed.subscribe = subscribed req.set_subscribed.publication_handle = self._ffi_handle.handle FfiClient.instance.request(req) + + def __repr__(self) -> str: + return f"rtc.RemoteTrackPublication(sid={self.sid}, name={self.name}, kind={self.kind}, source={self.source})" diff --git a/livekit-rtc/livekit/rtc/video_frame.py b/livekit-rtc/livekit/rtc/video_frame.py index 225243d1..6eb7a308 100644 --- a/livekit-rtc/livekit/rtc/video_frame.py +++ b/livekit-rtc/livekit/rtc/video_frame.py @@ -22,6 +22,13 @@ class VideoFrame: + """ + Represents a video frame with associated metadata and pixel data. + + This class provides methods to access video frame properties such as width, height, + and pixel format, as well as methods for manipulating and converting video frames. + """ + def __init__( self, width: int, @@ -29,6 +36,16 @@ def __init__( type: proto_video.VideoBufferType.ValueType, data: Union[bytes, bytearray, memoryview], ) -> None: + """ + Initializes a new VideoFrame instance. + + Args: + width (int): The width of the video frame in pixels. + height (int): The height of the video frame in pixels. + type (proto_video.VideoBufferType.ValueType): The format type of the video frame data + (e.g., RGBA, BGRA, RGB24, etc.). + data (Union[bytes, bytearray, memoryview]): The raw pixel data for the video frame. + """ self._width = width self._height = height self._type = type @@ -36,18 +53,42 @@ def __init__( @property def width(self) -> int: + """ + Returns the width of the video frame in pixels. + + Returns: + int: The width of the video frame. + """ return self._width @property def height(self) -> int: + """ + Returns the height of the video frame in pixels. + + Returns: + int: The height of the video frame. + """ return self._height @property def type(self) -> proto_video.VideoBufferType.ValueType: + """ + Returns the format type (pixel format) of the video frame data. + + Returns: + proto_video.VideoBufferType.ValueType: The format type of the video frame. 
+ """ return self._type @property def data(self) -> memoryview: + """ + Returns a memoryview of the raw pixel data for the video frame. + + Returns: + memoryview: The raw pixel data of the video frame as a memoryview object. + """ return memoryview(self._data) @staticmethod @@ -89,6 +130,19 @@ def _proto_info(self) -> proto_video.VideoBufferInfo: return info def get_plane(self, plane_nth: int) -> Optional[memoryview]: + """ + Returns the memoryview of a specific plane in the video frame, based on its index. + + Some video formats (e.g., I420, NV12) contain multiple planes (Y, U, V channels). + This method allows access to individual planes by index. + + Args: + plane_nth (int): The index of the plane to retrieve (starting from 0). + + Returns: + Optional[memoryview]: A memoryview of the specified plane's data, or None if + the index is out of bounds for the format. + """ plane_infos = _get_plane_infos( get_address(self.data), self.type, self.width, self.height ) @@ -102,6 +156,39 @@ def get_plane(self, plane_nth: int) -> Optional[memoryview]: def convert( self, type: proto_video.VideoBufferType.ValueType, *, flip_y: bool = False ) -> "VideoFrame": + """ + Converts the current video frame to a different format type, optionally flipping + the frame vertically. + + Args: + type (proto_video.VideoBufferType.ValueType): The target format type to convert to + (e.g., RGBA, I420). + flip_y (bool, optional): If True, the frame will be flipped vertically. Defaults to False. + + Returns: + VideoFrame: A new VideoFrame object in the specified format. + + Raises: + Exception: If there is an error during the conversion process. 
+ + Example: + Convert a frame from RGBA to I420 format: + + >>> frame = VideoFrame(width=1920, height=1080, type=proto_video.VideoBufferType.RGBA, data=raw_data) + >>> converted_frame = frame.convert(proto_video.VideoBufferType.I420) + >>> print(converted_frame.type) + VideoBufferType.I420 + + Example: + Convert a frame from BGRA to RGB24 format and flip it vertically: + + >>> frame = VideoFrame(width=1280, height=720, type=proto_video.VideoBufferType.BGRA, data=raw_data) + >>> converted_frame = frame.convert(proto_video.VideoBufferType.RGB24, flip_y=True) + >>> print(converted_frame.type) + VideoBufferType.RGB24 + >>> print(converted_frame.width, converted_frame.height) + 1280 720 + """ req = proto.FfiRequest() req.video_convert.flip_y = flip_y req.video_convert.dst_type = type @@ -112,6 +199,9 @@ def convert( return VideoFrame._from_owned_info(resp.video_convert.buffer) + def __repr__(self) -> str: + return f"rtc.VideoFrame(width={self.width}, height={self.height}, type={self.type})" + def _component_info( data_ptr: int, stride: int, size: int