Skip to content

Commit

Permalink
[ENH] Add system scheduler to allow tasks to run with a schedule (chr…
Browse files Browse the repository at this point in the history
…oma-core#1839)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - ...
 - New functionality
	 - This PR adds the compatibility to allow components to schedule tasks

## Test plan
*How are these changes tested?*

- [ ] cargo test

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Ishiihara authored Mar 8, 2024
1 parent 2cbbb2c commit 369eddd
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 40 deletions.
48 changes: 25 additions & 23 deletions rust/worker/src/system/executor.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use std::sync::Arc;

use tokio::select;

use crate::system::ComponentContext;

use super::{
scheduler::Scheduler,
sender::{Sender, Wrapper},
system::System,
Component,
};
use crate::system::ComponentContext;
use std::sync::Arc;
use tokio::select;

struct Inner<C>
where
Expand All @@ -17,6 +15,7 @@ where
pub(super) sender: Sender<C>,
pub(super) cancellation_token: tokio_util::sync::CancellationToken,
pub(super) system: System,
pub(super) scheduler: Scheduler,
}

#[derive(Clone)]
Expand All @@ -40,12 +39,14 @@ where
cancellation_token: tokio_util::sync::CancellationToken,
handler: C,
system: System,
scheduler: Scheduler,
) -> Self {
ComponentExecutor {
inner: Arc::new(Inner {
sender,
cancellation_token,
system,
scheduler,
}),
handler,
}
Expand All @@ -54,24 +55,25 @@ where
pub(super) async fn run(&mut self, mut channel: tokio::sync::mpsc::Receiver<Wrapper<C>>) {
loop {
select! {
_ = self.inner.cancellation_token.cancelled() => {
break;
}
message = channel.recv() => {
match message {
Some(mut message) => {
message.handle(&mut self.handler,
&ComponentContext{
system: self.inner.system.clone(),
sender: self.inner.sender.clone(),
cancellation_token: self.inner.cancellation_token.clone(),
}
).await;
}
None => {
// TODO: Log error
}
_ = self.inner.cancellation_token.cancelled() => {
break;
}
message = channel.recv() => {
match message {
Some(mut message) => {
message.handle(&mut self.handler,
&ComponentContext{
system: self.inner.system.clone(),
sender: self.inner.sender.clone(),
cancellation_token: self.inner.cancellation_token.clone(),
scheduler: self.inner.scheduler.clone(),
}
).await;
}
None => {
// TODO: Log error
}
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/system/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod executor;
mod scheduler;
mod sender;
mod system;
mod types;
Expand Down
211 changes: 211 additions & 0 deletions rust/worker/src/system/scheduler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use parking_lot::RwLock;
use std::fmt::Debug;
use std::num;
use std::sync::Arc;
use std::time::Duration;
use tokio::select;

use super::{executor, Component, ComponentContext, ComponentHandle, Handler, StreamHandler};
use super::{
executor::ComponentExecutor, sender::Sender, system::System, Receiver, ReceiverImpl, Wrapper,
};

pub(crate) struct SchedulerTaskHandle {
join_handle: Option<tokio::task::JoinHandle<()>>,
cancel: tokio_util::sync::CancellationToken,
}

#[derive(Clone)]
pub(crate) struct Scheduler {
handles: Arc<RwLock<Vec<SchedulerTaskHandle>>>,
}

impl Scheduler {
pub(crate) fn new() -> Scheduler {
Scheduler {
handles: Arc::new(RwLock::new(Vec::new())),
}
}

pub(crate) fn schedule<C, M>(
&self,
sender: Sender<C>,
message: M,
duration: Duration,
ctx: &ComponentContext<C>,
) where
C: Component + Handler<M>,
M: Debug + Send + 'static,
{
let cancel = ctx.cancellation_token.clone();
let handle = tokio::spawn(async move {
select! {
_ = cancel.cancelled() => {
return;
}
_ = tokio::time::sleep(duration) => {
match sender.send(message).await {
Ok(_) => {
return;
},
Err(e) => {
// TODO: log error
println!("Error: {:?}", e);
return;
}
}
}
}
});
let handle = SchedulerTaskHandle {
join_handle: Some(handle),
cancel: ctx.cancellation_token.clone(),
};
self.handles.write().push(handle);
}

pub(crate) fn schedule_interval<C, M>(
&self,
sender: Sender<C>,
message: M,
duration: Duration,
num_times: Option<usize>,
ctx: &ComponentContext<C>,
) where
C: Component + Handler<M>,
M: Debug + Send + Clone + 'static,
{
let cancel = ctx.cancellation_token.clone();

let handle = tokio::spawn(async move {
let mut counter = 0;
while Self::should_continue(num_times, counter) {
select! {
_ = cancel.cancelled() => {
return;
}
_ = tokio::time::sleep(duration) => {
match sender.send(message.clone()).await {
Ok(_) => {
},
Err(e) => {
// TODO: log error
println!("Error: {:?}", e);
}
}
}
}
counter += 1;
}
});
let handle = SchedulerTaskHandle {
join_handle: Some(handle),
cancel: ctx.cancellation_token.clone(),
};
self.handles.write().push(handle);
}

fn should_continue(num_times: Option<usize>, counter: usize) -> bool {
if num_times.is_some() {
let num_times = num_times.unwrap();
if counter >= num_times {
return false;
}
}
true
}

// Note: this method holds the lock on the handles, should call it only after stop is
// called.
pub(crate) async fn join(&self) {
let mut handles = self.handles.write();
for handle in handles.iter_mut() {
if let Some(join_handle) = handle.join_handle.take() {
match join_handle.await {
Ok(_) => {}
Err(e) => {
println!("Error: {:?}", e);
}
}
}
}
}

pub(crate) fn stop(&self) {
let handles = self.handles.read();
for handle in handles.iter() {
handle.cancel.cancel();
}
}
}

mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;

use std::sync::atomic::{AtomicUsize, Ordering};

#[derive(Debug)]
struct TestComponent {
queue_size: usize,
counter: Arc<AtomicUsize>,
}

#[derive(Clone, Debug)]
struct ScheduleMessage {}

impl TestComponent {
fn new(queue_size: usize, counter: Arc<AtomicUsize>) -> Self {
TestComponent {
queue_size,
counter,
}
}
}
#[async_trait]
impl Handler<ScheduleMessage> for TestComponent {
async fn handle(
&mut self,
_message: ScheduleMessage,
_ctx: &ComponentContext<TestComponent>,
) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
}

impl Component for TestComponent {
fn queue_size(&self) -> usize {
self.queue_size
}

fn on_start(&mut self, ctx: &ComponentContext<TestComponent>) -> () {
let duration = Duration::from_millis(100);
ctx.scheduler
.schedule(ctx.sender.clone(), ScheduleMessage {}, duration, ctx);

let num_times = 4;
ctx.scheduler.schedule_interval(
ctx.sender.clone(),
ScheduleMessage {},
duration,
Some(num_times),
ctx,
);
}
}

#[tokio::test]
async fn test_schedule() {
let mut system = System::new();
let counter = Arc::new(AtomicUsize::new(0));
let component = TestComponent::new(10, counter.clone());
let _handle = system.start_component(component);
// yield to allow the component to process the messages
tokio::task::yield_now().await;
// We should have scheduled the message once
system.join().await;
assert_eq!(counter.load(Ordering::SeqCst), 5);
}
}
25 changes: 20 additions & 5 deletions rust/worker/src/system/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,26 @@ use tokio::{pin, select};

use super::ComponentRuntime;
// use super::executor::StreamComponentExecutor;
use super::scheduler::Scheduler;
use super::sender::{self, Sender, Wrapper};
use super::{executor, ComponentContext};
use super::{executor::ComponentExecutor, Component, ComponentHandle, Handler, StreamHandler};
use std::sync::Mutex;

#[derive(Clone)]
pub(crate) struct System {
inner: Arc<Mutex<Inner>>,
inner: Arc<Inner>,
}

struct Inner {}
struct Inner {
scheduler: Scheduler,
}

impl System {
pub(crate) fn new() -> System {
System {
inner: Arc::new(Mutex::new(Inner {})),
inner: Arc::new(Inner {
scheduler: Scheduler::new(),
}),
}
}

Expand All @@ -34,16 +38,18 @@ impl System {
let (tx, rx) = tokio::sync::mpsc::channel(component.queue_size());
let sender = Sender::new(tx);
let cancel_token = tokio_util::sync::CancellationToken::new();
let _ = component.on_start(&ComponentContext {
let _ = component.on_start(&mut ComponentContext {
system: self.clone(),
sender: sender.clone(),
cancellation_token: cancel_token.clone(),
scheduler: self.inner.scheduler.clone(),
});
let mut executor = ComponentExecutor::new(
sender.clone(),
cancel_token.clone(),
component,
self.clone(),
self.inner.scheduler.clone(),
);

match C::runtime() {
Expand Down Expand Up @@ -74,9 +80,18 @@ impl System {
system: self.clone(),
sender: ctx.sender.clone(),
cancellation_token: ctx.cancellation_token.clone(),
scheduler: ctx.scheduler.clone(),
};
tokio::spawn(async move { stream_loop(stream, &ctx).await });
}

pub(crate) async fn stop(&self) {
self.inner.scheduler.stop();
}

pub(crate) async fn join(&self) {
self.inner.scheduler.join().await;
}
}

async fn stream_loop<C, S, M>(stream: S, ctx: &ComponentContext<C>)
Expand Down
Loading

0 comments on commit 369eddd

Please sign in to comment.