From 369eddde01d94ad2f4af4ad5c6fd2733ce0a7ecd Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Thu, 7 Mar 2024 16:45:26 -0800 Subject: [PATCH] [ENH] Add system scheduler to allow tasks to run with a schedule (#1839) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - ... - New functionality - This PR adds the capability for components to schedule tasks ## Test plan *How are these changes tested?* - [ ] cargo test ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- rust/worker/src/system/executor.rs | 48 ++++--- rust/worker/src/system/mod.rs | 1 + rust/worker/src/system/scheduler.rs | 211 ++++++++++++++++++++++++++++ rust/worker/src/system/system.rs | 25 +++- rust/worker/src/system/types.rs | 24 ++-- 5 files changed, 269 insertions(+), 40 deletions(-) create mode 100644 rust/worker/src/system/scheduler.rs diff --git a/rust/worker/src/system/executor.rs b/rust/worker/src/system/executor.rs index c50ac4a56fa..f5b38779183 100644 --- a/rust/worker/src/system/executor.rs +++ b/rust/worker/src/system/executor.rs @@ -1,14 +1,12 @@ -use std::sync::Arc; - -use tokio::select; - -use crate::system::ComponentContext; - use super::{ + scheduler::Scheduler, sender::{Sender, Wrapper}, system::System, Component, }; +use crate::system::ComponentContext; +use std::sync::Arc; +use tokio::select; struct Inner where @@ -17,6 +15,7 @@ where pub(super) sender: Sender, pub(super) cancellation_token: tokio_util::sync::CancellationToken, pub(super) system: System, + pub(super) scheduler: Scheduler, } #[derive(Clone)] @@ -40,12 +39,14 @@ where cancellation_token: tokio_util::sync::CancellationToken, handler: C, system: System, + scheduler: Scheduler, ) -> Self { ComponentExecutor { inner: Arc::new(Inner { sender, cancellation_token, system, + scheduler, }), handler, } @@ -54,24 +55,25 @@ where 
pub(super) async fn run(&mut self, mut channel: tokio::sync::mpsc::Receiver>) { loop { select! { - _ = self.inner.cancellation_token.cancelled() => { - break; - } - message = channel.recv() => { - match message { - Some(mut message) => { - message.handle(&mut self.handler, - &ComponentContext{ - system: self.inner.system.clone(), - sender: self.inner.sender.clone(), - cancellation_token: self.inner.cancellation_token.clone(), - } - ).await; - } - None => { - // TODO: Log error - } + _ = self.inner.cancellation_token.cancelled() => { + break; + } + message = channel.recv() => { + match message { + Some(mut message) => { + message.handle(&mut self.handler, + &ComponentContext{ + system: self.inner.system.clone(), + sender: self.inner.sender.clone(), + cancellation_token: self.inner.cancellation_token.clone(), + scheduler: self.inner.scheduler.clone(), + } + ).await; + } + None => { + // TODO: Log error } + } } } } diff --git a/rust/worker/src/system/mod.rs b/rust/worker/src/system/mod.rs index 32ad862f768..9656b0291c2 100644 --- a/rust/worker/src/system/mod.rs +++ b/rust/worker/src/system/mod.rs @@ -1,4 +1,5 @@ mod executor; +mod scheduler; mod sender; mod system; mod types; diff --git a/rust/worker/src/system/scheduler.rs b/rust/worker/src/system/scheduler.rs new file mode 100644 index 00000000000..0168f05ff2a --- /dev/null +++ b/rust/worker/src/system/scheduler.rs @@ -0,0 +1,211 @@ +use parking_lot::RwLock; +use std::fmt::Debug; +use std::num; +use std::sync::Arc; +use std::time::Duration; +use tokio::select; + +use super::{executor, Component, ComponentContext, ComponentHandle, Handler, StreamHandler}; +use super::{ + executor::ComponentExecutor, sender::Sender, system::System, Receiver, ReceiverImpl, Wrapper, +}; + +pub(crate) struct SchedulerTaskHandle { + join_handle: Option>, + cancel: tokio_util::sync::CancellationToken, +} + +#[derive(Clone)] +pub(crate) struct Scheduler { + handles: Arc>>, +} + +impl Scheduler { + pub(crate) fn new() -> Scheduler { + 
Scheduler { + handles: Arc::new(RwLock::new(Vec::new())), + } + } + + pub(crate) fn schedule( + &self, + sender: Sender, + message: M, + duration: Duration, + ctx: &ComponentContext, + ) where + C: Component + Handler, + M: Debug + Send + 'static, + { + let cancel = ctx.cancellation_token.clone(); + let handle = tokio::spawn(async move { + select! { + _ = cancel.cancelled() => { + return; + } + _ = tokio::time::sleep(duration) => { + match sender.send(message).await { + Ok(_) => { + return; + }, + Err(e) => { + // TODO: log error + println!("Error: {:?}", e); + return; + } + } + } + } + }); + let handle = SchedulerTaskHandle { + join_handle: Some(handle), + cancel: ctx.cancellation_token.clone(), + }; + self.handles.write().push(handle); + } + + pub(crate) fn schedule_interval( + &self, + sender: Sender, + message: M, + duration: Duration, + num_times: Option, + ctx: &ComponentContext, + ) where + C: Component + Handler, + M: Debug + Send + Clone + 'static, + { + let cancel = ctx.cancellation_token.clone(); + + let handle = tokio::spawn(async move { + let mut counter = 0; + while Self::should_continue(num_times, counter) { + select! { + _ = cancel.cancelled() => { + return; + } + _ = tokio::time::sleep(duration) => { + match sender.send(message.clone()).await { + Ok(_) => { + }, + Err(e) => { + // TODO: log error + println!("Error: {:?}", e); + } + } + } + } + counter += 1; + } + }); + let handle = SchedulerTaskHandle { + join_handle: Some(handle), + cancel: ctx.cancellation_token.clone(), + }; + self.handles.write().push(handle); + } + + fn should_continue(num_times: Option, counter: usize) -> bool { + if num_times.is_some() { + let num_times = num_times.unwrap(); + if counter >= num_times { + return false; + } + } + true + } + + // Note: this method holds the lock on the handles, should call it only after stop is + // called. 
+ pub(crate) async fn join(&self) { + let mut handles = self.handles.write(); + for handle in handles.iter_mut() { + if let Some(join_handle) = handle.join_handle.take() { + match join_handle.await { + Ok(_) => {} + Err(e) => { + println!("Error: {:?}", e); + } + } + } + } + } + + pub(crate) fn stop(&self) { + let handles = self.handles.read(); + for handle in handles.iter() { + handle.cancel.cancel(); + } + } +} + +mod tests { + use super::*; + use async_trait::async_trait; + use std::sync::Arc; + use std::time::Duration; + + use std::sync::atomic::{AtomicUsize, Ordering}; + + #[derive(Debug)] + struct TestComponent { + queue_size: usize, + counter: Arc, + } + + #[derive(Clone, Debug)] + struct ScheduleMessage {} + + impl TestComponent { + fn new(queue_size: usize, counter: Arc) -> Self { + TestComponent { + queue_size, + counter, + } + } + } + #[async_trait] + impl Handler for TestComponent { + async fn handle( + &mut self, + _message: ScheduleMessage, + _ctx: &ComponentContext, + ) { + self.counter.fetch_add(1, Ordering::SeqCst); + } + } + + impl Component for TestComponent { + fn queue_size(&self) -> usize { + self.queue_size + } + + fn on_start(&mut self, ctx: &ComponentContext) -> () { + let duration = Duration::from_millis(100); + ctx.scheduler + .schedule(ctx.sender.clone(), ScheduleMessage {}, duration, ctx); + + let num_times = 4; + ctx.scheduler.schedule_interval( + ctx.sender.clone(), + ScheduleMessage {}, + duration, + Some(num_times), + ctx, + ); + } + } + + #[tokio::test] + async fn test_schedule() { + let mut system = System::new(); + let counter = Arc::new(AtomicUsize::new(0)); + let component = TestComponent::new(10, counter.clone()); + let _handle = system.start_component(component); + // yield to allow the component to process the messages + tokio::task::yield_now().await; + // We should have scheduled the message once + system.join().await; + assert_eq!(counter.load(Ordering::SeqCst), 5); + } +} diff --git a/rust/worker/src/system/system.rs 
b/rust/worker/src/system/system.rs index 238da52a0ee..d2f573e98a9 100644 --- a/rust/worker/src/system/system.rs +++ b/rust/worker/src/system/system.rs @@ -8,22 +8,26 @@ use tokio::{pin, select}; use super::ComponentRuntime; // use super::executor::StreamComponentExecutor; +use super::scheduler::Scheduler; use super::sender::{self, Sender, Wrapper}; use super::{executor, ComponentContext}; use super::{executor::ComponentExecutor, Component, ComponentHandle, Handler, StreamHandler}; -use std::sync::Mutex; #[derive(Clone)] pub(crate) struct System { - inner: Arc>, + inner: Arc, } -struct Inner {} +struct Inner { + scheduler: Scheduler, +} impl System { pub(crate) fn new() -> System { System { - inner: Arc::new(Mutex::new(Inner {})), + inner: Arc::new(Inner { + scheduler: Scheduler::new(), + }), } } @@ -34,16 +38,18 @@ impl System { let (tx, rx) = tokio::sync::mpsc::channel(component.queue_size()); let sender = Sender::new(tx); let cancel_token = tokio_util::sync::CancellationToken::new(); - let _ = component.on_start(&ComponentContext { + let _ = component.on_start(&mut ComponentContext { system: self.clone(), sender: sender.clone(), cancellation_token: cancel_token.clone(), + scheduler: self.inner.scheduler.clone(), }); let mut executor = ComponentExecutor::new( sender.clone(), cancel_token.clone(), component, self.clone(), + self.inner.scheduler.clone(), ); match C::runtime() { @@ -74,9 +80,18 @@ impl System { system: self.clone(), sender: ctx.sender.clone(), cancellation_token: ctx.cancellation_token.clone(), + scheduler: ctx.scheduler.clone(), }; tokio::spawn(async move { stream_loop(stream, &ctx).await }); } + + pub(crate) async fn stop(&self) { + self.inner.scheduler.stop(); + } + + pub(crate) async fn join(&self) { + self.inner.scheduler.join().await; + } } async fn stream_loop(stream: S, ctx: &ComponentContext) diff --git a/rust/worker/src/system/types.rs b/rust/worker/src/system/types.rs index 9c2cd463561..76552287982 100644 --- 
a/rust/worker/src/system/types.rs +++ b/rust/worker/src/system/types.rs @@ -1,12 +1,9 @@ -use std::{fmt::Debug, sync::Arc}; - +use super::scheduler::Scheduler; use async_trait::async_trait; use futures::Stream; -use tokio::select; +use std::fmt::Debug; -use super::{ - executor::ComponentExecutor, sender::Sender, system::System, Receiver, ReceiverImpl, Wrapper, -}; +use super::{sender::Sender, system::System, Receiver, ReceiverImpl}; #[derive(Debug, PartialEq)] /// The state of a component @@ -135,6 +132,7 @@ where pub(crate) system: System, pub(crate) sender: Sender, pub(crate) cancellation_token: tokio_util::sync::CancellationToken, + pub(crate) scheduler: Scheduler, } #[cfg(test)] @@ -142,6 +140,8 @@ mod tests { use super::*; use async_trait::async_trait; use futures::stream; + use std::sync::Arc; + use std::time::Duration; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -154,8 +154,8 @@ mod tests { impl TestComponent { fn new(queue_size: usize, counter: Arc) -> Self { TestComponent { - queue_size: queue_size, - counter: counter, + queue_size, + counter, } } } @@ -166,12 +166,11 @@ mod tests { self.counter.fetch_add(message, Ordering::SeqCst); } } - impl StreamHandler for TestComponent {} impl Component for TestComponent { fn queue_size(&self) -> usize { - return self.queue_size; + self.queue_size } fn on_start(&mut self, ctx: &ComponentContext) -> () { @@ -191,12 +190,13 @@ mod tests { handle.sender.send(3).await.unwrap(); // yield to allow the component to process the messages tokio::task::yield_now().await; + // With the streaming data and the messages we should have 12 + assert_eq!(counter.load(Ordering::SeqCst), 12); handle.stop(); // Yield to allow the component to stop tokio::task::yield_now().await; + // Expect the component to be stopped assert_eq!(*handle.state(), ComponentState::Stopped); - // With the streaming data and the messages we should have 12 - assert_eq!(counter.load(Ordering::SeqCst), 12); let res = handle.sender.send(4).await; // Expect 
an error because the component is stopped assert!(res.is_err());