From 26ea0020ac3e680172f9e7d5eb8d74b6def4dbd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Fri, 27 Oct 2023 16:22:26 -0700 Subject: [PATCH] accept more buffer types --- livekit-rtc/livekit/rtc/audio_frame.py | 5 ++-- livekit-rtc/livekit/rtc/room.py | 4 +-- livekit-rtc/livekit/rtc/video_frame.py | 41 +++++++++++++++++--------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/livekit-rtc/livekit/rtc/audio_frame.py b/livekit-rtc/livekit/rtc/audio_frame.py index 7a31fbb5..ca53e44d 100644 --- a/livekit-rtc/livekit/rtc/audio_frame.py +++ b/livekit-rtc/livekit/rtc/audio_frame.py @@ -17,12 +17,13 @@ from ._proto import audio_frame_pb2 as proto_audio from ._proto import ffi_pb2 as proto_ffi from ._utils import get_address +from typing import Union class AudioFrame: def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], sample_rate: int, num_channels: int, samples_per_channel: int, @@ -34,10 +35,10 @@ def __init__( "data length must be >= num_channels * samples_per_channel * sizeof(int16)" ) + self._data = bytearray(data) self._sample_rate = sample_rate self._num_channels = num_channels self._samples_per_channel = samples_per_channel - self._data = data @staticmethod def create( diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index 2087e015..f76edf9c 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -133,8 +133,8 @@ async def connect( if options.e2ee: req.connect.options.e2ee.encryption_type = options.e2ee.encryption_type req.connect.options.e2ee.key_provider_options.shared_key = ( - options.e2ee.key_provider_options.shared_key - ) # type: ignore + options.e2ee.key_provider_options.shared_key # type: ignore + ) req.connect.options.e2ee.key_provider_options.ratchet_salt = ( options.e2ee.key_provider_options.ratchet_salt ) diff --git a/livekit-rtc/livekit/rtc/video_frame.py b/livekit-rtc/livekit/rtc/video_frame.py index d82187a1..776b7cfb 100644 --- a/livekit-rtc/livekit/rtc/video_frame.py +++ b/livekit-rtc/livekit/rtc/video_frame.py @@ -38,12 +38,16 @@ def __init__( class VideoFrameBuffer(ABC): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, buffer_type: VideoFrameBufferType.ValueType, ) -> None: - self._data = data + view = memoryview(data) + if not view.c_contiguous: + raise ValueError("data must be contiguous") + + self._data = bytearray(data) self._width = width self._height = height self._buffer_type = buffer_type @@ -145,7 +149,7 @@ def to_argb(self, dst: "ArgbFrame") -> None: class PlanarYuvBuffer(VideoFrameBuffer, ABC): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, buffer_type: VideoFrameBufferType.ValueType, @@ -198,7 +202,7 @@ def stride_v(self) -> int: class PlanarYuv8Buffer(PlanarYuvBuffer, ABC): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, buffer_type: VideoFrameBufferType.ValueType, @@ -251,7 +255,7 @@ def data_v(self) -> memoryview: class PlanarYuv16Buffer(PlanarYuvBuffer, ABC): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, buffer_type: VideoFrameBufferType.ValueType, @@ -304,7 +308,7 @@ def data_v(self) -> memoryview: class BiplanaraYuv8Buffer(VideoFrameBuffer, ABC): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, buffer_type: VideoFrameBufferType.ValueType, @@ -363,7 +367,7 @@ def data_uv(self) -> memoryview: class I420Buffer(PlanarYuv8Buffer): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, stride_y: int, @@ -423,7 +427,7 @@ def create(width: int, height: int) -> "I420Buffer": class I420ABuffer(PlanarYuv8Buffer): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, stride_y: int, @@ -506,7 +510,7 @@ def data_a(self) -> memoryview: class I422Buffer(PlanarYuv8Buffer): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, stride_y: int, @@ -521,6 +525,10 @@ def __init__( ) ) + view = memoryview(data) + if not view.c_contiguous: + raise ValueError("data must be contiguous") + chroma_width = (width + 1) // 2 chroma_height = height super().__init__( @@ -559,7 +567,7 @@ def calc_data_size(height: int, stride_y: int, stride_u: int, stride_v: int) -> class I444Buffer(PlanarYuv8Buffer): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, stride_y: int, @@ -612,7 +620,7 @@ def calc_data_size(height: int, stride_y: int, stride_u: int, stride_v: int) -> class I010Buffer(PlanarYuv16Buffer): def __init__( self, - data: bytearray, + data: Union[bytes, bytearray, memoryview], width: int, height: int, stride_y: int, @@ -668,7 +676,12 @@ def calc_data_size(height: int, stride_y: int, stride_u: int, stride_v: int) -> class NV12Buffer(BiplanaraYuv8Buffer): def __init__( - self, data: bytearray, width: int, height: int, stride_y: int, stride_uv: int + self, + data: Union[bytes, bytearray, memoryview], + width: int, + height: int, + stride_y: int, + stride_uv: int, ) -> None: if len(data) < NV12Buffer.calc_data_size(height, stride_y, stride_uv): raise ValueError( @@ -759,8 +772,8 @@ def to_i420(self) -> I420Buffer: return I420Buffer._from_owned_info(res.to_i420.buffer) @property - def data(self) -> bytearray: - return self._data + def data(self) -> memoryview: + return memoryview(self._data) @property def width(self) -> int: