forked from chroma-core/chroma
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Add system scheduler to allow tasks to run with a schedule (chr…
…oma-core#1839) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - ... - New functionality - This PR adds the compatibility to allow components to schedule tasks ## Test plan *How are these changes tested?* - [ ] cargo test ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
- Loading branch information
Showing
5 changed files
with
269 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
mod executor; | ||
mod scheduler; | ||
mod sender; | ||
mod system; | ||
mod types; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
use parking_lot::RwLock; | ||
use std::fmt::Debug; | ||
use std::num; | ||
use std::sync::Arc; | ||
use std::time::Duration; | ||
use tokio::select; | ||
|
||
use super::{executor, Component, ComponentContext, ComponentHandle, Handler, StreamHandler}; | ||
use super::{ | ||
executor::ComponentExecutor, sender::Sender, system::System, Receiver, ReceiverImpl, Wrapper, | ||
}; | ||
|
||
pub(crate) struct SchedulerTaskHandle { | ||
join_handle: Option<tokio::task::JoinHandle<()>>, | ||
cancel: tokio_util::sync::CancellationToken, | ||
} | ||
|
||
#[derive(Clone)] | ||
pub(crate) struct Scheduler { | ||
handles: Arc<RwLock<Vec<SchedulerTaskHandle>>>, | ||
} | ||
|
||
impl Scheduler { | ||
pub(crate) fn new() -> Scheduler { | ||
Scheduler { | ||
handles: Arc::new(RwLock::new(Vec::new())), | ||
} | ||
} | ||
|
||
pub(crate) fn schedule<C, M>( | ||
&self, | ||
sender: Sender<C>, | ||
message: M, | ||
duration: Duration, | ||
ctx: &ComponentContext<C>, | ||
) where | ||
C: Component + Handler<M>, | ||
M: Debug + Send + 'static, | ||
{ | ||
let cancel = ctx.cancellation_token.clone(); | ||
let handle = tokio::spawn(async move { | ||
select! { | ||
_ = cancel.cancelled() => { | ||
return; | ||
} | ||
_ = tokio::time::sleep(duration) => { | ||
match sender.send(message).await { | ||
Ok(_) => { | ||
return; | ||
}, | ||
Err(e) => { | ||
// TODO: log error | ||
println!("Error: {:?}", e); | ||
return; | ||
} | ||
} | ||
} | ||
} | ||
}); | ||
let handle = SchedulerTaskHandle { | ||
join_handle: Some(handle), | ||
cancel: ctx.cancellation_token.clone(), | ||
}; | ||
self.handles.write().push(handle); | ||
} | ||
|
||
pub(crate) fn schedule_interval<C, M>( | ||
&self, | ||
sender: Sender<C>, | ||
message: M, | ||
duration: Duration, | ||
num_times: Option<usize>, | ||
ctx: &ComponentContext<C>, | ||
) where | ||
C: Component + Handler<M>, | ||
M: Debug + Send + Clone + 'static, | ||
{ | ||
let cancel = ctx.cancellation_token.clone(); | ||
|
||
let handle = tokio::spawn(async move { | ||
let mut counter = 0; | ||
while Self::should_continue(num_times, counter) { | ||
select! { | ||
_ = cancel.cancelled() => { | ||
return; | ||
} | ||
_ = tokio::time::sleep(duration) => { | ||
match sender.send(message.clone()).await { | ||
Ok(_) => { | ||
}, | ||
Err(e) => { | ||
// TODO: log error | ||
println!("Error: {:?}", e); | ||
} | ||
} | ||
} | ||
} | ||
counter += 1; | ||
} | ||
}); | ||
let handle = SchedulerTaskHandle { | ||
join_handle: Some(handle), | ||
cancel: ctx.cancellation_token.clone(), | ||
}; | ||
self.handles.write().push(handle); | ||
} | ||
|
||
fn should_continue(num_times: Option<usize>, counter: usize) -> bool { | ||
if num_times.is_some() { | ||
let num_times = num_times.unwrap(); | ||
if counter >= num_times { | ||
return false; | ||
} | ||
} | ||
true | ||
} | ||
|
||
// Note: this method holds the lock on the handles, should call it only after stop is | ||
// called. | ||
pub(crate) async fn join(&self) { | ||
let mut handles = self.handles.write(); | ||
for handle in handles.iter_mut() { | ||
if let Some(join_handle) = handle.join_handle.take() { | ||
match join_handle.await { | ||
Ok(_) => {} | ||
Err(e) => { | ||
println!("Error: {:?}", e); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
pub(crate) fn stop(&self) { | ||
let handles = self.handles.read(); | ||
for handle in handles.iter() { | ||
handle.cancel.cancel(); | ||
} | ||
} | ||
} | ||
|
||
mod tests { | ||
use super::*; | ||
use async_trait::async_trait; | ||
use std::sync::Arc; | ||
use std::time::Duration; | ||
|
||
use std::sync::atomic::{AtomicUsize, Ordering}; | ||
|
||
#[derive(Debug)] | ||
struct TestComponent { | ||
queue_size: usize, | ||
counter: Arc<AtomicUsize>, | ||
} | ||
|
||
#[derive(Clone, Debug)] | ||
struct ScheduleMessage {} | ||
|
||
impl TestComponent { | ||
fn new(queue_size: usize, counter: Arc<AtomicUsize>) -> Self { | ||
TestComponent { | ||
queue_size, | ||
counter, | ||
} | ||
} | ||
} | ||
#[async_trait] | ||
impl Handler<ScheduleMessage> for TestComponent { | ||
async fn handle( | ||
&mut self, | ||
_message: ScheduleMessage, | ||
_ctx: &ComponentContext<TestComponent>, | ||
) { | ||
self.counter.fetch_add(1, Ordering::SeqCst); | ||
} | ||
} | ||
|
||
impl Component for TestComponent { | ||
fn queue_size(&self) -> usize { | ||
self.queue_size | ||
} | ||
|
||
fn on_start(&mut self, ctx: &ComponentContext<TestComponent>) -> () { | ||
let duration = Duration::from_millis(100); | ||
ctx.scheduler | ||
.schedule(ctx.sender.clone(), ScheduleMessage {}, duration, ctx); | ||
|
||
let num_times = 4; | ||
ctx.scheduler.schedule_interval( | ||
ctx.sender.clone(), | ||
ScheduleMessage {}, | ||
duration, | ||
Some(num_times), | ||
ctx, | ||
); | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_schedule() { | ||
let mut system = System::new(); | ||
let counter = Arc::new(AtomicUsize::new(0)); | ||
let component = TestComponent::new(10, counter.clone()); | ||
let _handle = system.start_component(component); | ||
// yield to allow the component to process the messages | ||
tokio::task::yield_now().await; | ||
// We should have scheduled the message once | ||
system.join().await; | ||
assert_eq!(counter.load(Ordering::SeqCst), 5); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.