Skip to content

Commit

Permalink
accept more buffer types
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom committed Oct 27, 2023
1 parent a788af3 commit 26ea002
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 18 deletions.
5 changes: 3 additions & 2 deletions livekit-rtc/livekit/rtc/audio_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from ._proto import audio_frame_pb2 as proto_audio
from ._proto import ffi_pb2 as proto_ffi
from ._utils import get_address
from typing import Union


class AudioFrame:
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
sample_rate: int,
num_channels: int,
samples_per_channel: int,
Expand All @@ -34,10 +35,10 @@ def __init__(
"data length must be >= num_channels * samples_per_channel * sizeof(int16)"
)

self._data = bytearray(data)
self._sample_rate = sample_rate
self._num_channels = num_channels
self._samples_per_channel = samples_per_channel
self._data = data

@staticmethod
def create(
Expand Down
4 changes: 2 additions & 2 deletions livekit-rtc/livekit/rtc/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ async def connect(
if options.e2ee:
req.connect.options.e2ee.encryption_type = options.e2ee.encryption_type
req.connect.options.e2ee.key_provider_options.shared_key = (
options.e2ee.key_provider_options.shared_key
) # type: ignore
options.e2ee.key_provider_options.shared_key # type: ignore
)
req.connect.options.e2ee.key_provider_options.ratchet_salt = (
options.e2ee.key_provider_options.ratchet_salt
)
Expand Down
41 changes: 27 additions & 14 deletions livekit-rtc/livekit/rtc/video_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ def __init__(
class VideoFrameBuffer(ABC):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
buffer_type: VideoFrameBufferType.ValueType,
) -> None:
self._data = data
view = memoryview(data)
if not view.c_contiguous:
raise ValueError("data must be contiguous")

self._data = bytearray(data)
self._width = width
self._height = height
self._buffer_type = buffer_type
Expand Down Expand Up @@ -145,7 +149,7 @@ def to_argb(self, dst: "ArgbFrame") -> None:
class PlanarYuvBuffer(VideoFrameBuffer, ABC):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
buffer_type: VideoFrameBufferType.ValueType,
Expand Down Expand Up @@ -198,7 +202,7 @@ def stride_v(self) -> int:
class PlanarYuv8Buffer(PlanarYuvBuffer, ABC):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
buffer_type: VideoFrameBufferType.ValueType,
Expand Down Expand Up @@ -251,7 +255,7 @@ def data_v(self) -> memoryview:
class PlanarYuv16Buffer(PlanarYuvBuffer, ABC):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
buffer_type: VideoFrameBufferType.ValueType,
Expand Down Expand Up @@ -304,7 +308,7 @@ def data_v(self) -> memoryview:
class BiplanaraYuv8Buffer(VideoFrameBuffer, ABC):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
buffer_type: VideoFrameBufferType.ValueType,
Expand Down Expand Up @@ -363,7 +367,7 @@ def data_uv(self) -> memoryview:
class I420Buffer(PlanarYuv8Buffer):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
stride_y: int,
Expand Down Expand Up @@ -423,7 +427,7 @@ def create(width: int, height: int) -> "I420Buffer":
class I420ABuffer(PlanarYuv8Buffer):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
stride_y: int,
Expand Down Expand Up @@ -506,7 +510,7 @@ def data_a(self) -> memoryview:
class I422Buffer(PlanarYuv8Buffer):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
stride_y: int,
Expand All @@ -521,6 +525,10 @@ def __init__(
)
)

view = memoryview(data)
if not view.c_contiguous:
raise ValueError("data must be contiguous")

chroma_width = (width + 1) // 2
chroma_height = height
super().__init__(
Expand Down Expand Up @@ -559,7 +567,7 @@ def calc_data_size(height: int, stride_y: int, stride_u: int, stride_v: int) ->
class I444Buffer(PlanarYuv8Buffer):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
stride_y: int,
Expand Down Expand Up @@ -612,7 +620,7 @@ def calc_data_size(height: int, stride_y: int, stride_u: int, stride_v: int) ->
class I010Buffer(PlanarYuv16Buffer):
def __init__(
self,
data: bytearray,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
stride_y: int,
Expand Down Expand Up @@ -668,7 +676,12 @@ def calc_data_size(height: int, stride_y: int, stride_u: int, stride_v: int) ->

class NV12Buffer(BiplanaraYuv8Buffer):
def __init__(
self, data: bytearray, width: int, height: int, stride_y: int, stride_uv: int
self,
data: Union[bytes, bytearray, memoryview],
width: int,
height: int,
stride_y: int,
stride_uv: int,
) -> None:
if len(data) < NV12Buffer.calc_data_size(height, stride_y, stride_uv):
raise ValueError(
Expand Down Expand Up @@ -759,8 +772,8 @@ def to_i420(self) -> I420Buffer:
return I420Buffer._from_owned_info(res.to_i420.buffer)

@property
def data(self) -> bytearray:
return self._data
def data(self) -> memoryview:
return memoryview(self._data)

@property
def width(self) -> int:
Expand Down

0 comments on commit 26ea002

Please sign in to comment.