Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add AVSynchronizer #1168

Closed
wants to merge 11 commits into from
172 changes: 172 additions & 0 deletions examples/video-stream/av_sync_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import asyncio
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncIterable

import av
import numpy as np
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import JobContext, WorkerOptions, cli, utils
from livekit.agents.utils.av_sync import AVSynchronizer

# Load environment variables (e.g. LiveKit credentials) from a local .env file
load_dotenv()

logger = logging.getLogger(__name__)


@dataclass
class MediaInfo:
    """Static properties of a media file's first video and audio streams."""

    video_width: int  # frame width in pixels
    video_height: int  # frame height in pixels
    video_fps: float  # average video frame rate of the source stream
    audio_sample_rate: int  # audio sample rate in Hz
    audio_channels: int  # number of audio channels


class MediaFileStreamer:
    """Streams video and audio frames from a media file.

    The constructor opens the file once to probe stream metadata; each
    ``stream_video``/``stream_audio`` call opens its own container so the two
    streams can be decoded concurrently and restarted independently.
    """

    def __init__(self, media_file: str | Path) -> None:
        self._media_file = str(media_file)
        # Probe-only container: used for metadata, not for decoding.
        self._container = av.open(self._media_file)

        self._video_stream = self._container.streams.video[0]
        self._audio_stream = self._container.streams.audio[0]

        # Cache media info so callers can size sources before streaming.
        self._info = MediaInfo(
            video_width=self._video_stream.width,
            video_height=self._video_stream.height,
            video_fps=float(self._video_stream.average_rate),
            audio_sample_rate=self._audio_stream.sample_rate,
            audio_channels=self._audio_stream.channels,
        )

    @property
    def info(self) -> MediaInfo:
        """Cached stream metadata probed at construction time."""
        return self._info

    async def stream_video(self) -> AsyncIterable[rtc.VideoFrame]:
        """Streams video frames from the media file as RGBA VideoFrames."""
        container = av.open(self._media_file)
        try:
            for frame in container.decode(video=0):
                rgb = frame.to_rgb().to_ndarray()
                # BUGFIX: the alpha channel must be 255 (opaque). The previous
                # np.ones(...) left alpha at 1, an almost fully transparent frame.
                frame_rgba = np.full(
                    (rgb.shape[0], rgb.shape[1], 4), 255, dtype=np.uint8
                )
                frame_rgba[:, :, :3] = rgb
                yield rtc.VideoFrame(
                    width=rgb.shape[1],
                    height=rgb.shape[0],
                    type=rtc.VideoBufferType.RGBA,
                    data=frame_rgba.tobytes(),
                )
        finally:
            container.close()

    async def stream_audio(self) -> AsyncIterable[rtc.AudioFrame]:
        """Streams audio frames from the media file as int16 PCM AudioFrames."""
        container = av.open(self._media_file)
        try:
            for frame in container.decode(audio=0):
                samples: np.ndarray = frame.to_ndarray(format="s16")
                # BUGFIX: only rescale when the decoder hands us float PCM.
                # Multiplying int16 samples by 32768 (as before) overflows and
                # corrupts the audio. Clip floats to avoid overflow at +1.0.
                if samples.dtype != np.int16:
                    samples = np.clip(samples, -1.0, 1.0)
                    samples = (samples * 32767).astype(np.int16)
                # NOTE(review): shape (channels, samples) assumes a planar
                # ndarray layout — confirm against the decoder's output format.
                yield rtc.AudioFrame(
                    data=samples.tobytes(),
                    sample_rate=self.info.audio_sample_rate,
                    num_channels=samples.shape[0],
                    samples_per_channel=samples.shape[1],
                )
        finally:
            container.close()

    async def aclose(self) -> None:
        """Closes the metadata container opened by the constructor."""
        self._container.close()


async def entrypoint(job: JobContext):
    """Publish a media file's video and audio to the room in sync.

    Decodes the file (one container per stream), pushes frames through an
    AVSynchronizer, and loops the file until the job is shut down.
    """
    await job.connect()
    room = job.room

    # Create media streamer
    # NOTE(review): placeholder path — point this at a real media file.
    media_path = "/path/to/sample.mp4"
    streamer = MediaFileStreamer(media_path)
    media_info = streamer.info

    # Create video and audio sources/tracks. A short queue keeps the audio
    # buffer small so audio and video stay closely aligned at startup.
    queue_size_ms = 100
    video_source = rtc.VideoSource(
        width=media_info.video_width,
        height=media_info.video_height,
    )
    audio_source = rtc.AudioSource(
        sample_rate=media_info.audio_sample_rate,
        num_channels=media_info.audio_channels,
        queue_size_ms=queue_size_ms,
    )

    video_track = rtc.LocalVideoTrack.create_video_track("video", video_source)
    audio_track = rtc.LocalAudioTrack.create_audio_track("audio", audio_source)

    # Publish tracks
    video_options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_CAMERA)
    audio_options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)

    await room.local_participant.publish_track(video_track, video_options)
    await room.local_participant.publish_track(audio_track, audio_options)

    # Create AV synchronizer
    av_sync = AVSynchronizer(
        audio_source=audio_source,
        video_source=video_source,
        video_fps=media_info.video_fps,
        video_queue_size_ms=queue_size_ms,
    )

    @utils.log_exceptions(logger=logger)
    async def _push_video_frames(
        video_stream: AsyncIterable[rtc.VideoFrame], av_sync: AVSynchronizer
    ) -> None:
        """Task to push video frames to the AV synchronizer."""
        async for frame in video_stream:
            await av_sync.push(frame)

    @utils.log_exceptions(logger=logger)
    async def _push_audio_frames(
        audio_stream: AsyncIterable[rtc.AudioFrame], av_sync: AVSynchronizer
    ) -> None:
        """Task to push audio frames to the AV synchronizer."""
        async for frame in audio_stream:
            await av_sync.push(frame)

    try:
        while True:
            # Loop the media file: restart both decode streams each pass.
            video_stream = streamer.stream_video()
            audio_stream = streamer.stream_audio()

            video_task = asyncio.create_task(_push_video_frames(video_stream, av_sync))
            audio_task = asyncio.create_task(_push_audio_frames(audio_stream, av_sync))

            try:
                # TODO: wait for frames still buffered in av_sync to be processed
                await asyncio.gather(video_task, audio_task)
            finally:
                # BUGFIX: gather() does not cancel the sibling task when one
                # fails (or when this coroutine is cancelled). Without this,
                # a leftover task keeps pushing frames while the sources are
                # being torn down below.
                for task in (video_task, audio_task):
                    if not task.done():
                        task.cancel()
    finally:
        await av_sync.aclose()
        await streamer.aclose()


if __name__ == "__main__":
    # Warn when a job exceeds 400 MB — media decoding is memory-hungry.
    worker_options = WorkerOptions(
        entrypoint_fnc=entrypoint,
        job_memory_warn_mb=400,
    )
    cli.run_app(worker_options)
158 changes: 158 additions & 0 deletions livekit-agents/livekit/agents/utils/av_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import asyncio
import logging
import time
from collections import deque
from typing import Optional, Union

import livekit.agents.utils as utils
from livekit import rtc

logger = logging.getLogger(__name__)


class AVSynchronizer:
    """Synchronize audio and video capture.

    Usage:
        av_sync = AVSynchronizer(
            audio_source=audio_source,
            video_source=video_source,
            video_fps=video_fps,
        )

        async for video_frame, audio_frame in video_generator:
            await av_sync.push(video_frame)
            await av_sync.push(audio_frame)
    """

    def __init__(
        self,
        *,
        audio_source: rtc.AudioSource,
        video_source: rtc.VideoSource,
        video_fps: float,
        video_queue_size_ms: float = 1000,
        _max_delay_tolerance_ms: float = 300,
    ):
        # Args:
        #   audio_source: rtc source that audio frames are captured into.
        #   video_source: rtc source that paced video frames are captured into.
        #   video_fps: target video frame rate used to pace the capture loop.
        #   video_queue_size_ms: video buffer length; a small value keeps
        #       audio and video closely aligned when a stream starts.
        #   _max_delay_tolerance_ms: how far behind schedule the video pacer
        #       may fall before it resets its timing (see _FPSController).
        self._audio_source = audio_source
        self._video_source = video_source
        self._video_fps = video_fps
        self._video_queue_size_ms = video_queue_size_ms
        self._max_delay_tolerance_ms = _max_delay_tolerance_ms

        self._stopped = False

        # Queue capacity expressed in frames at the target fps.
        self._video_queue_max_size = int(
            self._video_fps * self._video_queue_size_ms / 1000
        )
        self._video_queue = asyncio.Queue[rtc.VideoFrame](
            maxsize=self._video_queue_max_size
        )
        # NOTE: requires a running event loop at construction time.
        self._capture_video_task = asyncio.create_task(self._capture_video())
async def push(self, frame: Union[rtc.VideoFrame, rtc.AudioFrame]) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this wait to have at least one AudioFrame and one VideoFrame to start the capture? So we synchronize the start of the stream as well

Copy link
Collaborator Author

@longcw longcw Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think no, the agent can stream only video frames if there is no audio at the beginning. The user only need to make sure when there is a audio for playing, the first audio frame is pushed at the same time with the corresponding video frame.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to make the queue_size smaller enough for that case, e.g. 100ms or even less.

if isinstance(frame, rtc.AudioFrame):
# TODO: test if frame duration is too long
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we need to keep track of the pushed audio.

await self._audio_source.capture_frame(frame)
return

await self._video_queue.put(frame)

async def _capture_video(self) -> None:
fps_controller = _FPSController(
expected_fps=self._video_fps,
max_delay_tolerance_ms=self._max_delay_tolerance_ms,
)
while not self._stopped:
frame = await self._video_queue.get()
async with fps_controller:
self._video_source.capture_frame(frame)

async def aclose(self) -> None:
self._stopped = True
if self._capture_video_task:
await utils.aio.gracefully_cancel(self._capture_video_task)


class _FPSController:
def __init__(
self, *, expected_fps: float, max_delay_tolerance_ms: float = 300
) -> None:
"""Controls frame rate by adjusting sleep time based on actual FPS.

Usage:
async with _FPSController(expected_fps=30):
# process frame
pass

Args:
expected_fps: Target frames per second
max_delay_tolerance_ms: Maximum delay tolerance in milliseconds
"""
self._expected_fps = expected_fps
self._frame_interval = 1.0 / expected_fps
self._max_delay_tolerance_secs = max_delay_tolerance_ms / 1000

self._next_frame_time: Optional[float] = None
self._fps_calc_winsize = max(2, int(0.5 * expected_fps))
self._send_timestamps: deque[float] = deque(maxlen=self._fps_calc_winsize)

async def __aenter__(self) -> None:
await self.wait_next_process()

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
self.after_process()

async def wait_next_process(self) -> None:
"""Wait until it's time for the next frame.

Adjusts sleep time based on actual FPS to maintain target rate.
"""
current_time = time.perf_counter()

# initialize the next frame time
if self._next_frame_time is None:
self._next_frame_time = current_time

# calculate sleep time
sleep_time = self._next_frame_time - current_time
if sleep_time > 0:
await asyncio.sleep(sleep_time)
else:
logger.debug(
"Sync state",
extra={"sleep_time": sleep_time, "fps": self.actual_fps},
)
# check if significantly behind schedule
if -sleep_time > self._max_delay_tolerance_secs:
logger.warning(
f"Frame capture was behind schedule for "
f"{-sleep_time * 1000:.2f} ms"
)
self._next_frame_time = time.perf_counter()

def after_process(self) -> None:
"""Update timing information after processing a frame."""
assert (
self._next_frame_time is not None
), "wait_next_process must be called first"

# update timing information
self._send_timestamps.append(time.perf_counter())

# calculate next frame time
self._next_frame_time += self._frame_interval

@property
def expected_fps(self) -> float:
return self._expected_fps

@property
def actual_fps(self) -> float:
"""Get current average FPS."""
if len(self._send_timestamps) < 2:
return 0

return (len(self._send_timestamps) - 1) / (
self._send_timestamps[-1] - self._send_timestamps[0]
)
Loading