Skip to content

Commit

Permalink
quality of life improvements (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Sep 28, 2024
1 parent 2495e04 commit d37e8eb
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 0 deletions.
91 changes: 91 additions & 0 deletions livekit-rtc/livekit/rtc/audio_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,31 @@


class AudioFrame:
"""
A class that represents a frame of audio data with specific properties such as sample rate,
number of channels, and samples per channel.
"""

def __init__(
self,
data: Union[bytes, bytearray, memoryview],
sample_rate: int,
num_channels: int,
samples_per_channel: int,
) -> None:
"""
Initialize an AudioFrame instance.
Args:
data (Union[bytes, bytearray, memoryview]): The raw audio data, which must be at least
`num_channels * samples_per_channel * sizeof(int16)` bytes long.
sample_rate (int): The sample rate of the audio in Hz.
num_channels (int): The number of audio channels (e.g., 1 for mono, 2 for stereo).
samples_per_channel (int): The number of samples per channel.
Raises:
ValueError: If the length of `data` is smaller than the required size.
"""
if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
ctypes.c_int16
):
Expand All @@ -44,6 +62,18 @@ def __init__(
def create(
sample_rate: int, num_channels: int, samples_per_channel: int
) -> "AudioFrame":
"""
Create a new empty AudioFrame instance with specified sample rate, number of channels,
and samples per channel.
Args:
sample_rate (int): The sample rate of the audio in Hz.
num_channels (int): The number of audio channels (e.g., 1 for mono, 2 for stereo).
samples_per_channel (int): The number of samples per channel.
Returns:
AudioFrame: A new AudioFrame instance with uninitialized (zeroed) data.
"""
size = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
data = bytearray(size)
return AudioFrame(data, sample_rate, num_channels, samples_per_channel)
Expand Down Expand Up @@ -91,16 +121,77 @@ def _proto_info(self) -> proto_audio.AudioFrameBufferInfo:

@property
def data(self) -> memoryview:
"""
Returns a memory view of the audio data as 16-bit signed integers.
Returns:
memoryview: A memory view of the audio data.
"""
return memoryview(self._data).cast("h")

@property
def sample_rate(self) -> int:
"""
Returns the sample rate of the audio frame.
Returns:
int: The sample rate in Hz.
"""
return self._sample_rate

@property
def num_channels(self) -> int:
"""
Returns the number of channels in the audio frame.
Returns:
int: The number of audio channels (e.g., 1 for mono, 2 for stereo).
"""
return self._num_channels

@property
def samples_per_channel(self) -> int:
"""
Returns the number of samples per channel.
Returns:
int: The number of samples per channel.
"""
return self._samples_per_channel

@property
def duration(self) -> float:
"""
Returns the duration of the audio frame in seconds.
Returns:
float: The duration in seconds.
"""
return self.samples_per_channel / self.sample_rate

def to_wav_bytes(self) -> bytes:
"""
Convert the audio frame data to a WAV-formatted byte stream.
Returns:
bytes: The audio data encoded in WAV format.
"""
import wave
import io

with io.BytesIO() as wav_file:
with wave.open(wav_file, "wb") as wav:
wav.setnchannels(self.num_channels)
wav.setsampwidth(2)
wav.setframerate(self.sample_rate)
wav.writeframes(self._data)

return wav_file.getvalue()

def __repr__(self) -> str:
return (
f"rtc.AudioFrame(sample_rate={self.sample_rate}, "
f"num_channels={self.num_channels}, "
f"samples_per_channel={self.samples_per_channel}, "
f"duration={self.duration:.3f})"
)
6 changes: 6 additions & 0 deletions livekit-rtc/livekit/rtc/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,14 @@ async def unpublish_track(self, track_sid: str) -> None:
finally:
self._room_queue.unsubscribe(queue)

def __repr__(self) -> str:
return f"rtc.LocalParticipant(sid={self.sid}, identity={self.identity}, name={self.name})"


class RemoteParticipant(Participant):
def __init__(self, owned_info: proto_participant.OwnedParticipant) -> None:
super().__init__(owned_info)
self.track_publications: dict[str, RemoteTrackPublication] = {} # type: ignore

def __repr__(self) -> str:
return f"rtc.RemoteParticipant(sid={self.sid}, identity={self.identity}, name={self.name})"
7 changes: 7 additions & 0 deletions livekit-rtc/livekit/rtc/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,10 @@ def _create_remote_participant(
participant = RemoteParticipant(owned_info)
self.remote_participants[participant.identity] = participant
return participant

def __repr__(self) -> str:
sid = "unknown"
if self._first_sid_future.done():
sid = self._first_sid_future.result()

return f"rtc.Room(sid={sid}, name={self.name}, metadata={self.metadata}, connection_state={self.connection_state})"
12 changes: 12 additions & 0 deletions livekit-rtc/livekit/rtc/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def create_audio_track(name: str, source: "AudioSource") -> "LocalAudioTrack":
resp = FfiClient.instance.request(req)
return LocalAudioTrack(resp.create_audio_track.track)

def __repr__(self) -> str:
return f"rtc.LocalAudioTrack(sid={self.sid}, name={self.name})"


class LocalVideoTrack(Track):
def __init__(self, info: proto_track.OwnedTrack):
Expand All @@ -94,16 +97,25 @@ def create_video_track(name: str, source: "VideoSource") -> "LocalVideoTrack":
resp = FfiClient.instance.request(req)
return LocalVideoTrack(resp.create_video_track.track)

def __repr__(self) -> str:
return f"rtc.LocalVideoTrack(sid={self.sid}, name={self.name})"


class RemoteAudioTrack(Track):
def __init__(self, info: proto_track.OwnedTrack):
super().__init__(info)

def __repr__(self) -> str:
return f"rtc.RemoteAudioTrack(sid={self.sid}, name={self.name})"


class RemoteVideoTrack(Track):
def __init__(self, info: proto_track.OwnedTrack):
super().__init__(info)

def __repr__(self) -> str:
return f"rtc.RemoteVideoTrack(sid={self.sid}, name={self.name})"


LocalTrack = Union[LocalVideoTrack, LocalAudioTrack]
RemoteTrack = Union[RemoteVideoTrack, RemoteAudioTrack]
Expand Down
6 changes: 6 additions & 0 deletions livekit-rtc/livekit/rtc/track_publication.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def __init__(self, owned_info: proto_track.OwnedTrackPublication):
async def wait_for_subscription(self) -> None:
await asyncio.shield(self._first_subscription)

def __repr__(self) -> str:
return f"rtc.LocalTrackPublication(sid={self.sid}, name={self.name}, kind={self.kind}, source={self.source})"


class RemoteTrackPublication(TrackPublication):
def __init__(self, owned_info: proto_track.OwnedTrackPublication):
Expand All @@ -88,3 +91,6 @@ def set_subscribed(self, subscribed: bool):
req.set_subscribed.subscribe = subscribed
req.set_subscribed.publication_handle = self._ffi_handle.handle
FfiClient.instance.request(req)

def __repr__(self) -> str:
return f"rtc.RemoteTrackPublication(sid={self.sid}, name={self.name}, kind={self.kind}, source={self.source})"
90 changes: 90 additions & 0 deletions livekit-rtc/livekit/rtc/video_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,73 @@


class VideoFrame:
"""
Represents a video frame with associated metadata and pixel data.
This class provides methods to access video frame properties such as width, height,
and pixel format, as well as methods for manipulating and converting video frames.
"""

def __init__(
self,
width: int,
height: int,
type: proto_video.VideoBufferType.ValueType,
data: Union[bytes, bytearray, memoryview],
) -> None:
"""
Initializes a new VideoFrame instance.
Args:
width (int): The width of the video frame in pixels.
height (int): The height of the video frame in pixels.
type (proto_video.VideoBufferType.ValueType): The format type of the video frame data
(e.g., RGBA, BGRA, RGB24, etc.).
data (Union[bytes, bytearray, memoryview]): The raw pixel data for the video frame.
"""
self._width = width
self._height = height
self._type = type
self._data = bytearray(data)

@property
def width(self) -> int:
"""
Returns the width of the video frame in pixels.
Returns:
int: The width of the video frame.
"""
return self._width

@property
def height(self) -> int:
"""
Returns the height of the video frame in pixels.
Returns:
int: The height of the video frame.
"""
return self._height

@property
def type(self) -> proto_video.VideoBufferType.ValueType:
"""
Returns the height of the video frame in pixels.
Returns:
int: The height of the video frame.
"""
return self._type

@property
def data(self) -> memoryview:
"""
Returns a memoryview of the raw pixel data for the video frame.
Returns:
memoryview: The raw pixel data of the video frame as a memoryview object.
"""
return memoryview(self._data)

@staticmethod
Expand Down Expand Up @@ -89,6 +130,19 @@ def _proto_info(self) -> proto_video.VideoBufferInfo:
return info

def get_plane(self, plane_nth: int) -> Optional[memoryview]:
"""
Returns the memoryview of a specific plane in the video frame, based on its index.
Some video formats (e.g., I420, NV12) contain multiple planes (Y, U, V channels).
This method allows access to individual planes by index.
Args:
plane_nth (int): The index of the plane to retrieve (starting from 0).
Returns:
Optional[memoryview]: A memoryview of the specified plane's data, or None if
the index is out of bounds for the format.
"""
plane_infos = _get_plane_infos(
get_address(self.data), self.type, self.width, self.height
)
Expand All @@ -102,6 +156,39 @@ def get_plane(self, plane_nth: int) -> Optional[memoryview]:
def convert(
self, type: proto_video.VideoBufferType.ValueType, *, flip_y: bool = False
) -> "VideoFrame":
"""
Converts the current video frame to a different format type, optionally flipping
the frame vertically.
Args:
type (proto_video.VideoBufferType.ValueType): The target format type to convert to
(e.g., RGBA, I420).
flip_y (bool, optional): If True, the frame will be flipped vertically. Defaults to False.
Returns:
VideoFrame: A new VideoFrame object in the specified format.
Raises:
Exception: If there is an error during the conversion process.
Example:
Convert a frame from RGBA to I420 format:
>>> frame = VideoFrame(width=1920, height=1080, type=proto_video.VideoBufferType.RGBA, data=raw_data)
>>> converted_frame = frame.convert(proto_video.VideoBufferType.I420)
>>> print(converted_frame.type)
VideoBufferType.I420
Example:
Convert a frame from BGRA to RGB24 format and flip it vertically:
>>> frame = VideoFrame(width=1280, height=720, type=proto_video.VideoBufferType.BGRA, data=raw_data)
>>> converted_frame = frame.convert(proto_video.VideoBufferType.RGB24, flip_y=True)
>>> print(converted_frame.type)
VideoBufferType.RGB24
>>> print(converted_frame.width, converted_frame.height)
1280 720
"""
req = proto.FfiRequest()
req.video_convert.flip_y = flip_y
req.video_convert.dst_type = type
Expand All @@ -112,6 +199,9 @@ def convert(

return VideoFrame._from_owned_info(resp.video_convert.buffer)

def __repr__(self) -> str:
return f"rtc.VideoFrame(width={self.width}, height={self.height}, type={self.type})"


def _component_info(
data_ptr: int, stride: int, size: int
Expand Down

0 comments on commit d37e8eb

Please sign in to comment.