Skip to content

Commit

Permalink
refactor: simplify frame capture code
Browse files Browse the repository at this point in the history
  • Loading branch information
cs50victor committed Jan 13, 2024
1 parent c944963 commit db9ddec
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 119 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions lkgpt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ ezsockets = { version = "0.6.1", features = ["client"] }
futures = "0.3.29"
futures-intrusive = "0.5.0"
image = { version = "0.24.7", default-features = false, features = ["png"] }
livekit = { git = "https://github.com/livekit/rust-sdks", rev = "b9ba46e", features = [
livekit = { git = "https://github.com/cs50victor/livekit-rust-sdks.git", branch = "update-proto", features = [
"native-tls",
] }
livekit-api = { git = "https://github.com/livekit/rust-sdks", rev = "b9ba46e" }
livekit-protocol = { git = "https://github.com/livekit/rust-sdks", rev = "b9ba46e" }
livekit-api = { git = "https://github.com/cs50victor/livekit-rust-sdks.git", branch = "update-proto" }
livekit-protocol = { git = "https://github.com/cs50victor/livekit-rust-sdks.git", branch = "update-proto" }
log = "0.4.20"
parking_lot = "0.12.1"
pollster = "0.3.0"
Expand Down
19 changes: 5 additions & 14 deletions lkgpt/src/frame_capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,8 @@ pub mod image_copy {
}
}
pub mod scene {
use std::path::PathBuf;

use bevy::{
app::AppExit,
prelude::*,
render::{camera::RenderTarget, renderer::RenderDevice},
};
Expand Down Expand Up @@ -325,31 +323,24 @@ pub mod scene {
for image in images_to_save.iter() {
let img_bytes = images.get_mut(image.id()).unwrap();

let img = match img_bytes.clone().try_into_dynamic() {
let rgba_img = match img_bytes.clone().try_into_dynamic() {
Ok(img) => img.to_rgba8(),
Err(e) => panic!("Failed to create image buffer {e:?}"),
};

single_frame_data.frame_data.image = img;

if let Err(e) = async_runtime
.rt
.spawn_blocking({
let frame_data = single_frame_data.frame_data.clone();
let source = single_frame_data.video_src.clone();
move || {
let img_vec = frame_data.image.as_raw();
let mut framebuffer = frame_data.framebuffer.lock();
// VIDEO FRAME BUFFER (i420_buffer)
let mut video_frame = frame_data.video_frame.lock();
let i420_buffer = &mut video_frame.buffer;

let (stride_y, stride_u, stride_v) = i420_buffer.strides();
let (data_y, data_u, data_v) = i420_buffer.data_mut();

framebuffer.clone_from_slice(img_vec);
let (stride_y, stride_u, stride_v) = video_frame.buffer.strides();
let (data_y, data_u, data_v) = video_frame.buffer.data_mut();

livekit::webrtc::native::yuv_helper::abgr_to_i420(
&framebuffer,
rgba_img.as_raw(),
w * pixel_size,
data_y,
stride_y,
Expand Down
94 changes: 54 additions & 40 deletions lkgpt/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,14 @@ mod stt;
mod tts;
mod video;

use std::{
borrow::BorrowMut,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use async_openai::{
config::OpenAIConfig,
types::{
ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
CreateChatCompletionRequestArgs,
},
Client,
};
use chrono::Utc;
use frame_capture::scene::SceneController;
use livekit::{publication::LocalTrackPublication, Room};
use livekit::{publication::LocalTrackPublication, webrtc::video_frame::VideoBuffer, Room};
// use actix_web::{middleware, web::Data, App, HttpServer};
use log::info;

Expand All @@ -40,18 +29,13 @@ use livekit::{
RoomError,
};

use async_openai::Client as OPENAI_CLIENT;

use bevy::{
app::ScheduleRunnerPlugin, core::Name, core_pipeline::tonemapping::Tonemapping, log::LogPlugin,
prelude::*, render::renderer::RenderDevice, time::common_conditions::on_timer,
};
use bevy_panorbit_camera::{PanOrbitCamera, PanOrbitCameraPlugin};

use bevy_gaussian_splatting::{
random_gaussians, utils::get_arg, GaussianCloud, GaussianSplattingBundle,
GaussianSplattingPlugin,
};
use bevy_gaussian_splatting::{GaussianCloud, GaussianSplattingBundle, GaussianSplattingPlugin};

use stt::{receive_and_process_audio, STT};

Expand Down Expand Up @@ -90,8 +74,6 @@ impl FromWorld for AsyncRuntime {

#[derive(Clone)]
struct FrameData {
image: image::ImageBuffer<image::Rgba<u8>, Vec<u8>>,
framebuffer: std::sync::Arc<parking_lot::Mutex<Vec<u8>>>,
video_frame: std::sync::Arc<
parking_lot::Mutex<
livekit::webrtc::video_frame::VideoFrame<livekit::webrtc::video_frame::I420Buffer>,
Expand Down Expand Up @@ -161,7 +143,7 @@ pub fn handle_room_events(
async_runtime: Res<AsyncRuntime>,
llm_channel: Res<llm::LLMChannel>,
audio_channel: Res<stt::AudioInputChannel>,
video_channel: Res<video::VideoChannel>,
_video_channel: Res<video::VideoChannel>,
audio_syncer: ResMut<AudioSync>,
mut room_events: ResMut<LivekitRoom>,
) {
Expand Down Expand Up @@ -198,19 +180,41 @@ pub fn handle_room_events(
let mut video_stream = NativeVideoStream::new(video_rtc_track);
async_runtime.rt.spawn(async move {
while let Some(frame) = video_stream.next().await {
let c = frame.buffer.to_i420();
// frame.buffer.to_argb(format, dst, dst_stride, dst_width, dst_height)
let _video_frame_buffer = frame.buffer.to_i420();
// let width: u32 = video_frame_buffer.width();
// let height: u32 = video_frame_buffer.height();
// let rgba_stride = video_frame_buffer.width() * 4;

// let (stride_y, stride_u, stride_v) = video_frame_buffer.strides();
// let (data_y, data_u, data_v) = video_frame_buffer.data();

// livekit::webrtc::native::yuv_helper::i420_to_rgba(
// c,
// frame.width,
// frame.height,
// frame.width,
// frame.height,
// video_track.rotation,
// video_track.flip,
// data_y,
// stride_y,
// data_u,
// stride_u,
// data_v,
// stride_v,
// rgba_ptr,
// rgba_stride,
// video_frame_buffer.width() as i32,
// video_frame_buffer.height() as i32,
// );

// if let Err(e)= audio_channel_tx.send(frame.data.to_vec()){
// log::error!("Couldn't send audio frame to stt {e}");
// };

/*
image: image::ImageBuffer<image::Rgba<u8>, Vec<u8>>,
framebuffer: std::sync::Arc<parking_lot::Mutex<Vec<u8>>>,
video_frame: std::sync::Arc<
parking_lot::Mutex<
livekit::webrtc::video_frame::VideoFrame<livekit::webrtc::video_frame::I420Buffer>,
>,
>,
*/
}
});
},
Expand Down Expand Up @@ -239,10 +243,10 @@ pub fn handle_room_events(
}
},
// ignoring the participant for now, currently assuming there is only one participant
RoomEvent::TrackMuted { participant, publication } => {
RoomEvent::TrackMuted { participant: _, publication: _ } => {
audio_syncer.should_stop.store(true, Ordering::Relaxed);
},
RoomEvent::TrackUnmuted { participant, publication } => {
RoomEvent::TrackUnmuted { participant: _, publication: _ } => {
audio_syncer.should_stop.store(false, Ordering::Relaxed);
},
_ => info!("received room event {:?}", event),
Expand All @@ -260,6 +264,7 @@ pub struct TracksPublicationData {
pub async fn publish_tracks(
room: std::sync::Arc<Room>,
bot_name: &str,
video_frame_dimension: (u32, u32),
) -> Result<TracksPublicationData, RoomError> {
let audio_src = NativeAudioSource::new(
AudioSourceOptions::default(),
Expand All @@ -270,7 +275,7 @@ pub async fn publish_tracks(
let audio_track =
LocalAudioTrack::create_audio_track(bot_name, RtcAudioSource::Native(audio_src.clone()));

let (width, height) = (1920, 1080);
let (width, height) = video_frame_dimension;
let video_src =
NativeVideoSource::new(livekit::webrtc::video_source::VideoResolution { width, height });
let video_track =
Expand Down Expand Up @@ -303,7 +308,7 @@ pub async fn publish_tracks(
fn setup_gaussian_cloud(
mut commands: Commands,
asset_server: Res<AssetServer>,
mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
_gaussian_assets: ResMut<Assets<GaussianCloud>>,
mut scene_controller: ResMut<frame_capture::scene::SceneController>,
mut images: ResMut<Assets<Image>>,
render_device: Res<RenderDevice>,
Expand Down Expand Up @@ -349,7 +354,7 @@ pub fn sync_bevy_and_server_resources(
async_runtime: Res<AsyncRuntime>,
mut server_state_clone: ResMut<AppStateSync>,
mut set_app_state: ResMut<NextState<AppState>>,
mut scene_controller: Res<SceneController>,
scene_controller: Res<SceneController>,
) {
if !server_state_clone.dirty {
let participant_room_name = &(server_state_clone.state.lock().0).clone();
Expand All @@ -361,14 +366,14 @@ pub fn sync_bevy_and_server_resources(
));
match status {
Ok(room_data) => {
info!("connected to livekit room");
info!("🎉connected to livekit room");

let RoomData {
livekit_room,
stream_frame_data,
audio_src,
video_pub,
audio_pub,
video_pub: _,
audio_pub: _,
} = room_data;

info!("initializing required bevy resources");
Expand Down Expand Up @@ -458,7 +463,10 @@ fn main() {
app.init_resource::<frame_capture::scene::SceneController>();
app.add_event::<frame_capture::scene::SceneController>();

// app.add_systems(Update, move_camera);

app.add_systems(Update, server::shutdown_bevy_remotely);

app.add_systems(
Update,
handle_room_events
Expand Down Expand Up @@ -491,3 +499,9 @@ fn main() {

app.run();
}

/// Bevy system: pans every entity that has a `Camera` component 5.0 units
/// along the +X axis each time the system runs.
///
/// Currently unused (the `app.add_systems(Update, move_camera)` registration
/// is commented out in `main`), but kept for manual camera-motion testing.
fn move_camera(mut camera: Query<&mut Transform, With<Camera>>) {
    camera.iter_mut().for_each(|mut transform| {
        transform.translation.x += 5.0;
    });
}
27 changes: 9 additions & 18 deletions lkgpt/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,20 @@ use livekit::{
webrtc::{
audio_source::native::NativeAudioSource,
video_frame::{I420Buffer, VideoFrame, VideoRotation},
video_source::native::NativeVideoSource,
},
RoomEvent,
};
use log::info;

use actix_web::web;
use parking_lot::Mutex;
use tokio::sync::mpsc::UnboundedReceiver;

use std::{borrow::Borrow, fmt::Debug, ops::DerefMut, sync::Arc};
use std::{fmt::Debug, sync::Arc};

use bevy::{ecs::world, prelude::*};
use bevy::prelude::*;
use livekit_api::access_token::{AccessToken, VideoGrants};
use serde::{Deserialize, Serialize};

use crate::{LivekitRoom, LIVEKIT_API_KEY, LIVEKIT_API_SECRET, LIVEKIT_WS_URL, OPENAI_ORG_ID};
use crate::{LivekitRoom, LIVEKIT_API_KEY, LIVEKIT_API_SECRET, LIVEKIT_WS_URL};

#[derive(Debug, Serialize, Deserialize)]
pub struct ServerMsg<T> {
Expand Down Expand Up @@ -115,29 +112,21 @@ pub async fn setup_and_connect_to_livekit(

// ************** SETUP OPENAI, TTS, & STT **************
let crate::TracksPublicationData { video_pub, video_src, audio_src, audio_pub } =
crate::publish_tracks(room.clone(), bot_name).await?;
crate::publish_tracks(room.clone(), bot_name, video_frame_dimension).await?;

let pixel_size = 4;
let pixel_size = 4_u32;

let (w, h) = (video_frame_dimension.0 as usize, video_frame_dimension.1 as usize);

let frame_data = crate::FrameData {
image: ImageBuffer::<Rgba<u8>, Vec<u8>>::from_raw(
w as u32,
h as u32,
vec![0u8; w * h * pixel_size],
)
.unwrap(),
framebuffer: Arc::new(Mutex::new(vec![0u8; w * h * pixel_size])),
video_frame: Arc::new(Mutex::new(VideoFrame {
rotation: VideoRotation::VideoRotation0,
buffer: I420Buffer::new(w as u32, h as u32),
timestamp_us: 0,
})),
};

let stream_frame_data =
crate::StreamingFrameData { pixel_size: pixel_size as u32, video_src, frame_data };
let stream_frame_data = crate::StreamingFrameData { pixel_size, video_src, frame_data };

let livekit_room = LivekitRoom { room, room_events };

Expand Down Expand Up @@ -188,12 +177,14 @@ mod lsdk_webhook {
.unwrap_or_default()
.to_string();

let jwt = jwt.trim();

let body = match std::str::from_utf8(&body) {
Ok(i) => i,
Err(e) => return Resp::BadRequest().json(ServerMsg::error(e.to_string())),
};

let event = match webhook_receiver.receive(body, &jwt) {
let event = match webhook_receiver.receive(body, jwt) {
Ok(i) => i,
Err(e) => return Resp::InternalServerError().json(ServerMsg::error(e.to_string())),
};
Expand Down
Loading

0 comments on commit db9ddec

Please sign in to comment.