diff --git a/infer_server/src/bin/hour_glass.rs b/infer_server/src/bin/infer_server.rs similarity index 88% rename from infer_server/src/bin/hour_glass.rs rename to infer_server/src/bin/infer_server.rs index 1e809b0..47ff4fc 100644 --- a/infer_server/src/bin/hour_glass.rs +++ b/infer_server/src/bin/infer_server.rs @@ -7,14 +7,12 @@ use axum::{routing::get, Extension, Router}; use clap::Parser; use env_logger::TimestampPrecision; use infer_server::{ - hour_glass::{ - data_socket::spawn_data_socket, - endpoints::{faces_stream, healthcheck, named_stream}, - inferer::Inferer, - router::FrameRouter, - INCOMING_FRAMES_CHANNEL, INFER_IMAGES_CHANNEL, - }, + data_socket::spawn_data_socket, + endpoints::{faces_stream, healthcheck, named_stream}, + inferer::Inferer, meter::spawn_meter_logger, + router::FrameRouter, + INCOMING_FRAMES_CHANNEL, INFER_IMAGES_CHANNEL, }; #[derive(Parser, Debug)] diff --git a/infer_server/src/bin/msg_passing.rs b/infer_server/src/bin/msg_passing.rs deleted file mode 100644 index 26d0b84..0000000 --- a/infer_server/src/bin/msg_passing.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! Infer server binary. -//! -use std::{net::SocketAddr, sync::Arc}; - -use anyhow::Result; -use axum::{routing::get, Extension, Router}; -use clap::Parser; -use env_logger::TimestampPrecision; -use infer_server::{ - meter::spawn_meter_logger, - msg_passing::{ - data_socket::spawn_data_socket, - endpoints::{healthcheck, named_stream}, - router::Registry, - }, -}; - -#[derive(Parser, Debug)] -#[clap(author, version)] -struct Args { - /// Address of the infer server to connect to - #[clap(long, default_value = "127.0.0.1:3000")] - server_address: String, - - /// Address of the infer server to connect to - #[clap(long, default_value = "127.0.0.1:3001")] - socket_address: String, -} - -#[tokio::main] -async fn main() -> Result<()> { - let args = Args::parse(); - - // Setup logger - env_logger::builder() - .format_timestamp(Some(TimestampPrecision::Millis)) - .init(); - - let mut registry = Registry::new(); - let comm = registry.get_comm(); - - tokio::spawn(async move { registry.run().await }); - - // Create socket to receive image streams via network - spawn_data_socket(comm.tcp_tasks_comm_tx.clone(), &args.socket_address).await?; - - spawn_meter_logger(); - - // Build HTTP server with endpoints - let app = Router::new() - .route("/healthcheck", get(healthcheck)) - .route("/stream", get(named_stream)) - .layer(Extension(Arc::new(comm))); - - // Serve HTTP server - let addr: SocketAddr = args.server_address.parse()?; - axum::Server::bind(&addr) - .serve(app.into_make_service()) - .await?; - - Ok(()) -} diff --git a/infer_server/src/hour_glass/data_socket.rs b/infer_server/src/data_socket.rs similarity index 96% rename from infer_server/src/hour_glass/data_socket.rs rename to infer_server/src/data_socket.rs index 469253f..b88e3bb 100644 --- a/infer_server/src/hour_glass/data_socket.rs +++ b/infer_server/src/data_socket.rs @@ -10,7 +10,7 @@ use tokio::{ }; use tokio_util::codec::{Framed, LengthDelimitedCodec}; -use crate::hour_glass::StaticFrameSender; +use crate::StaticFrameSender; pub async fn spawn_data_socket( tx: StaticFrameSender, diff --git a/infer_server/src/hour_glass/endpoints.rs b/infer_server/src/endpoints.rs similarity index 97% rename from infer_server/src/hour_glass/endpoints.rs rename to infer_server/src/endpoints.rs index c6e21dc..7519173 100644 --- a/infer_server/src/hour_glass/endpoints.rs +++ b/infer_server/src/endpoints.rs @@ -7,7 +7,7 @@ use futures::StreamExt; use serde::Deserialize; use tokio_stream::wrappers::BroadcastStream; -use crate::{hour_glass::router::FrameRouter, meter::METER}; +use crate::{meter::METER, router::FrameRouter}; /// Search parameters available to streams. #[derive(Debug, Deserialize)] diff --git a/infer_server/src/hour_glass/mod.rs b/infer_server/src/hour_glass/mod.rs deleted file mode 100644 index 06ee201..0000000 --- a/infer_server/src/hour_glass/mod.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, -}; - -use bytes::{Bytes, BytesMut}; -use thingbuf::mpsc::{StaticChannel, StaticReceiver, StaticSender}; - -pub mod data_socket; -pub mod endpoints; -pub mod inferer; -pub mod router; - -pub type StaticFrameSender = StaticSender; -pub type StaticFrameReceiver = StaticReceiver; - -pub static INCOMING_FRAMES_CHANNEL: StaticChannel = StaticChannel::new(); - -pub type BroadcastSender = tokio::sync::broadcast::Sender; -pub type BroadcastReceiver = tokio::sync::broadcast::Receiver; - -pub fn broadcast_channel() -> (BroadcastSender, BroadcastReceiver) { - tokio::sync::broadcast::channel(20) -} - -pub type StaticImage = (u32, u32, Vec, Option); - -pub type StaticImageSender = StaticSender; -pub type StaticImageReceiver = StaticReceiver; - -pub static INFER_IMAGES_CHANNEL: StaticChannel = StaticChannel::new(); - -fn hashed(data: T) -> u64 -where - T: Hash, -{ - let mut hasher = DefaultHasher::new(); - data.hash(&mut hasher); - hasher.finish() -} - -fn as_jpeg_stream_item(data: &[u8]) -> Bytes { - Bytes::copy_from_slice( - &[ - "--frame\r\nContent-Type: image/jpeg\r\n\r\n".as_bytes(), - &data[..], - "\r\n\r\n".as_bytes(), - ] - .concat(), - ) -} diff --git a/infer_server/src/hour_glass/inferer.rs b/infer_server/src/inferer.rs similarity index 96% rename from infer_server/src/hour_glass/inferer.rs rename to infer_server/src/inferer.rs index eaa92c0..a1ed400 100644 --- a/infer_server/src/hour_glass/inferer.rs +++ b/infer_server/src/inferer.rs @@ -7,8 +7,8 @@ use imageproc::{ use lazy_static::lazy_static; use crate::{ - hour_glass::StaticImageReceiver, nn::{Bbox, InferModel, UltrafaceModel}, + StaticImageReceiver, }; use super::as_jpeg_stream_item; @@ -93,7 +93,7 @@ fn draw_bboxes_on_image( lazy_static! { static ref DEJAVU_MONO: rusttype::Font<'static> = { - let font_data: &[u8] = include_bytes!("../../../resources/DejaVuSansMono.ttf"); + let font_data: &[u8] = include_bytes!("../../resources/DejaVuSansMono.ttf"); let font: rusttype::Font<'static> = rusttype::Font::try_from_bytes(font_data).expect("failed to load font"); font diff --git a/infer_server/src/lib.rs b/infer_server/src/lib.rs index a85a8b3..5870745 100644 --- a/infer_server/src/lib.rs +++ b/infer_server/src/lib.rs @@ -6,12 +6,36 @@ use std::{ hash::{Hash, Hasher}, }; -pub mod hour_glass; +use bytes::{Bytes, BytesMut}; +use thingbuf::mpsc::{StaticChannel, StaticReceiver, StaticSender}; + +pub mod data_socket; +pub mod endpoints; +pub mod inferer; pub mod meter; -pub mod msg_passing; pub mod nn; +pub mod router; pub mod utils; +pub type StaticFrameSender = StaticSender; +pub type StaticFrameReceiver = StaticReceiver; + +pub static INCOMING_FRAMES_CHANNEL: StaticChannel = StaticChannel::new(); + +pub type BroadcastSender = tokio::sync::broadcast::Sender; +pub type BroadcastReceiver = tokio::sync::broadcast::Receiver; + +pub fn broadcast_channel() -> (BroadcastSender, BroadcastReceiver) { + tokio::sync::broadcast::channel(20) +} + +pub type StaticImage = (u32, u32, Vec, Option); + +pub type StaticImageSender = StaticSender; +pub type StaticImageReceiver = StaticReceiver; + +pub static INFER_IMAGES_CHANNEL: StaticChannel = StaticChannel::new(); + fn hashed(data: T) -> u64 where T: Hash, @@ -20,3 +44,14 @@ where data.hash(&mut hasher); hasher.finish() } + +fn as_jpeg_stream_item(data: &[u8]) -> Bytes { + Bytes::copy_from_slice( + &[ + "--frame\r\nContent-Type: image/jpeg\r\n\r\n".as_bytes(), + &data[..], + "\r\n\r\n".as_bytes(), + ] + .concat(), + ) +} diff --git a/infer_server/src/msg_passing/data_socket.rs b/infer_server/src/msg_passing/data_socket.rs deleted file mode 100644 index a3272d9..0000000 --- a/infer_server/src/msg_passing/data_socket.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Data socket module to receive image streams via network. -//! -use std::net::SocketAddr; - -use anyhow::{bail, Result}; -use bytes::Bytes; -use common::protocol::ProtoMsg; -use futures::StreamExt; -use tokio::{ - net::{TcpListener, TcpStream}, - sync::mpsc, - task::JoinHandle, -}; -use tokio_util::codec::{Framed, LengthDelimitedCodec}; - -use crate::msg_passing::router::TcpTaskDataSender; - -/// Spawn a data socket and register the stream with the Pub/Sub-Engine. -pub async fn spawn_data_socket( - registry_tx: TcpTaskDataSender, - addr: &str, -) -> Result>> { - let socket: SocketAddr = addr.parse()?; - Ok(tokio::spawn(async move { - let listener = TcpListener::bind(socket).await?; - - loop { - let (socket, _peer_addr) = listener.accept().await?; - let registry_tx = registry_tx.clone(); - tokio::spawn(async move { - handle_incoming(registry_tx, socket).await?; - Ok::<_, anyhow::Error>(()) - }); - } - })) -} - -async fn handle_incoming(registry_tx: TcpTaskDataSender, stream: TcpStream) -> Result<()> { - let addr = stream.peer_addr()?; - log::info!("{}: New TCP connection", &addr); - - let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); - - let channel_name = { - if let Some(Ok(data)) = transport.next().await { - if let Ok(ProtoMsg::ConnectReq(channel)) = ProtoMsg::deserialize(&data) { - channel - } else { - bail!("no channel name"); - } - } else { - bail!("no channel name"); - } - }; - - let (senders_tx, mut senders_rx) = mpsc::channel(20); - registry_tx.send((channel_name, senders_tx)).await?; - - let mut listeners = Vec::new(); - let mut failed_senders = Vec::new(); - - loop { - tokio::select! { - res = senders_rx.recv() => { - match res { - None => panic!("registry closed"), - Some(new_listeners) => { - listeners.extend(new_listeners.into_iter()); - } - } - } - res = transport.next() => match res { - None => { - log::info!("TCP stream ended"); - } - Some(Ok(data)) => { - if let Ok(ProtoMsg::FrameMsg(msg)) = ProtoMsg::deserialize(&data) { - let data = as_jpeg_stream_item(&msg.data); - // Send - for (idx, sender) in listeners.iter().enumerate() { - if sender.send(data.clone()).await.is_err() { - failed_senders.push(idx); - } - } - - for idx in failed_senders.iter().rev() { - listeners.swap_remove(*idx); - } - } - } - Some(Err(e)) => { - log::warn!("Error in TCP codec: {e}"); - } - } - } - } -} - -fn as_jpeg_stream_item(data: &[u8]) -> Bytes { - Bytes::copy_from_slice( - &[ - "--frame\r\nContent-Type: image/jpeg\r\n\r\n".as_bytes(), - &data[..], - "\r\n\r\n".as_bytes(), - ] - .concat(), - ) -} diff --git a/infer_server/src/msg_passing/endpoints.rs b/infer_server/src/msg_passing/endpoints.rs deleted file mode 100644 index 4731a19..0000000 --- a/infer_server/src/msg_passing/endpoints.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! Endpoints of HTTP server. -//! -use std::sync::Arc; - -use axum::{body::StreamBody, extract::Query, http::header, response::IntoResponse, Extension}; -use futures::StreamExt; -use serde::Deserialize; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; - -use crate::{meter::METER, msg_passing::router::RegistryComm}; - -/// Search parameters available to streams. -#[derive(Debug, Deserialize)] -pub struct StreamParams { - #[serde(default)] - name: Option, -} - -/// Health check endpoint. -pub async fn healthcheck() -> &'static str { - "healthy" -} - -// Endpoint of received image streams. -pub async fn named_stream( - Extension(registry): Extension>, - Query(params): Query, -) -> impl IntoResponse { - let name = params.name.unwrap_or_else(|| "unknown".into()); - log::info!("Stream for {} requested", &name); - - let (tx, rx) = mpsc::channel(20); - registry - .frame_stream_listener_tx - .send((name, tx)) - .await - .ok(); - - let stream = ReceiverStream::new(rx).map(|x| { - METER.tick_raw(); - Ok::<_, String>(x) - }); - - // Set body and headers for multipart streaming - let body = StreamBody::new(stream); - let headers = [( - header::CONTENT_TYPE, - "multipart/x-mixed-replace; boundary=frame", - )]; - - (headers, body) -} diff --git a/infer_server/src/msg_passing/mod.rs b/infer_server/src/msg_passing/mod.rs deleted file mode 100644 index f92aa10..0000000 --- a/infer_server/src/msg_passing/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod data_socket; -pub mod endpoints; -pub mod router; diff --git a/infer_server/src/msg_passing/router.rs b/infer_server/src/msg_passing/router.rs deleted file mode 100644 index 93a5101..0000000 --- a/infer_server/src/msg_passing/router.rs +++ /dev/null @@ -1,192 +0,0 @@ -use std::{collections::HashMap, time::Duration}; - -use bytes::Bytes; -use futures::{stream::FuturesUnordered, StreamExt}; -use tokio::sync::mpsc::{self, Receiver, Sender}; - -use crate::hashed; - -pub struct RegistryComm { - pub frame_stream_listener_tx: Sender<(String, Sender)>, - pub infered_stream_listener_tx: Sender<(String, Sender)>, - pub tcp_tasks_comm_tx: TcpTaskDataSender, -} - -pub type TcpTaskDataSender = Sender<(String, Sender>>)>; - -/// Router registry. -/// -/// In this approach, every HTTP stream creates its own tx/rx channel pair. The -/// tx end is stored in a map. Infer streams also generate a tx/rx channel pair -/// with the tx end stored in a separate map. -/// When a new data socket is connected, it creates a channel to receive updates -/// from the registry. Initially, it gets a clone of all the tx ends of streams -/// which want to receive updates of this data socket. -/// For every tx/rx channel pair, we spawn a future with a clone of the tx -/// end that completes when the rx end is dropped. This way, we notice if a -/// channel is closed and we should remove the sender from the map. -pub struct Registry { - frames_sender_map: HashMap>>, - infered_frames_sender_map: HashMap>>, - tcp_tasks_comm: HashMap>>>, - - frame_stream_listener_tx: Sender<(String, Sender)>, - frame_stream_listener_rx: Option)>>, - - infered_stream_listener_tx: Sender<(String, Sender)>, - infered_stream_listener_rx: Option)>>, - - tcp_tasks_comm_tx: Sender<(String, Sender>>)>, - tcp_tasks_comm_rx: Option>>)>>, -} - -impl Registry { - pub fn new() -> Self { - let (frame_stream_listener_tx, frame_stream_listener_rx) = mpsc::channel(20); - let (infered_stream_listener_tx, infered_stream_listener_rx) = mpsc::channel(20); - let (tcp_tasks_comm_tx, tcp_tasks_comm_rx) = mpsc::channel(20); - - Self { - frames_sender_map: HashMap::new(), - infered_frames_sender_map: HashMap::new(), - tcp_tasks_comm: HashMap::new(), - frame_stream_listener_tx, - frame_stream_listener_rx: Some(frame_stream_listener_rx), - infered_stream_listener_tx, - infered_stream_listener_rx: Some(infered_stream_listener_rx), - tcp_tasks_comm_tx, - tcp_tasks_comm_rx: Some(tcp_tasks_comm_rx), - } - } - - pub fn get_comm(&self) -> RegistryComm { - RegistryComm { - frame_stream_listener_tx: self.frame_stream_listener_tx.clone(), - infered_stream_listener_tx: self.infered_stream_listener_tx.clone(), - tcp_tasks_comm_tx: self.tcp_tasks_comm_tx.clone(), - } - } - - pub async fn run(&mut self) { - if let ( - Some(mut frame_stream_listener_rx), - Some(mut infered_stream_listener_rx), - Some(mut tcp_tasks_comm_rx), - ) = ( - self.frame_stream_listener_rx.take(), - self.infered_stream_listener_rx.take(), - self.tcp_tasks_comm_rx.take(), - ) { - let mut expired_sender_ids = FuturesUnordered::new(); - - loop { - tokio::select! { - res = expired_sender_ids.next() => { - match res { - None => { - tokio::time::sleep(Duration::from_millis(500)).await; - }, - Some(map_id) => match map_id { - MapId::Frame(id) => { - log::info!("Removing expired frame sender for ID {}", id); - self.frames_sender_map.remove(&id); - } - MapId::Infered(id) => { - log::info!("Removing expired infered frame sender for ID {}", id); - self.infered_frames_sender_map.remove(&id); - } - MapId::Tcp(id) => { - log::info!("Removing expired TCP communicator for ID {}", id); - self.tcp_tasks_comm.remove(&id); - } - } - } - - } - - res = frame_stream_listener_rx.recv() => { - match res { - None => panic!("stream ended"), - Some((name, tx)) => { - let id = hashed(&name); - { - let tx = tx.clone(); - expired_sender_ids.push( - Box::pin(async move { - tx.closed().await; - MapId::Frame(id) - }) as Pin + Send + 'static>> - ); - } - self.frames_sender_map.entry(id).or_insert_with(|| Vec::with_capacity(1)).push(tx.clone()); - if let Some(tcp_task) = self.tcp_tasks_comm.get(&id) { - // TODO: handle error - tcp_task.send(vec![tx]).await.ok(); - } - } - } - - }, - - res = infered_stream_listener_rx.recv() => { - match res { - None => panic!("stream ended"), - Some((name, tx)) => { - let id = hashed(&name); - { - let tx = tx.clone(); - expired_sender_ids.push( - Box::pin(async move { - tx.closed().await; - MapId::Infered(id) - }) as Pin + Send + 'static>> - ); - } - self.infered_frames_sender_map.entry(id).or_insert_with(|| Vec::with_capacity(1)).push(tx.clone()); - if let Some(tcp_task) = self.tcp_tasks_comm.get(&id) { - // TODO: handle error - tcp_task.send(vec![tx]).await.ok(); - } - } - } - - }, - - res = tcp_tasks_comm_rx.recv() => { - match res { - None => panic!("stream ended"), - Some((name, tx)) => { - let id = hashed(&name); - { - let tx = tx.clone(); - expired_sender_ids.push( - Box::pin(async move { - tx.closed().await; - MapId::Tcp(id) - }) as Pin + Send + 'static>> - ); - } - self.tcp_tasks_comm.insert(id, tx.clone()); - if let Some(senders) = self.frames_sender_map.get(&id) { - tx.send(senders.clone()).await.ok(); - } - if let Some(senders) = self.infered_frames_sender_map.get(&id) { - tx.send(senders.clone()).await.ok(); - } - } - } - - }, - } - } - } else { - panic!("trying to run uninitialized registry"); - } - } -} - -enum MapId { - Frame(u64), - Infered(u64), - Tcp(u64), -} diff --git a/infer_server/src/nn.rs b/infer_server/src/nn.rs index 07361e9..cb50e3a 100644 --- a/infer_server/src/nn.rs +++ b/infer_server/src/nn.rs @@ -119,12 +119,6 @@ impl UltrafaceModel { .chunks(4) .map(|x| Bbox::try_from(x).unwrap()); - // TODO: - // - BorrowedBbox<'_> - // - Work with non-sorted data for non_maximum_suppression - // - Preallocate vec in non-max-supp - // - Impl GenericImgView trait for different buffer - // Fuse bounding boxes with confidence scores // Filter out bounding boxes with a confidence score below the threshold let mut bboxes_with_confidences: Vec<_> = bboxes diff --git a/infer_server/src/hour_glass/router.rs b/infer_server/src/router.rs similarity index 99% rename from infer_server/src/hour_glass/router.rs rename to infer_server/src/router.rs index c1e7d3c..ffc0275 100644 --- a/infer_server/src/hour_glass/router.rs +++ b/infer_server/src/router.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Mutex}; use anyhow::{bail, Result}; use common::protocol::ProtoMsg; -use crate::hour_glass::{ +use crate::{ broadcast_channel, hashed, BroadcastReceiver, BroadcastSender, StaticFrameReceiver, StaticImageSender, };