Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Cancel tasks spawned on compute runtime #3128

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use std::{
future::Future,
panic::AssertUnwindSafe,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, OnceLock,
},
task::{Context, Poll},
};

use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::{runtime::RuntimeFlavor, task::JoinHandle};
use tokio::{
runtime::{Handle, RuntimeFlavor},
task::JoinSet,
};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
Expand All @@ -31,6 +36,38 @@
IO,
}

// A spawned task on a Runtime that can be awaited
// This is a wrapper around a JoinSet that allows us to cancel the task by dropping it
pub struct RuntimeTask<T> {
joinset: JoinSet<T>,
}

impl<T> RuntimeTask<T> {
pub fn new<F>(handle: &Handle, future: F) -> Self
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let mut joinset = JoinSet::new();
joinset.spawn_on(future, handle);
Self { joinset }
}
}

impl<T: Send + 'static> Future for RuntimeTask<T> {
type Output = DaftResult<T>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.joinset).poll_join_next(cx) {
Poll::Ready(Some(result)) => {
Poll::Ready(result.map_err(|e| DaftError::External(e.into())))
}
Poll::Ready(None) => panic!("JoinSet unexpectedly empty"),

Check warning on line 65 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L65

Added line #L65 was not covered by tests
Poll::Pending => Poll::Pending,
}
}
}

pub struct Runtime {
runtime: tokio::runtime::Runtime,
pool_type: PoolType,
Expand All @@ -41,7 +78,6 @@
Arc::new(Self { runtime, pool_type })
}

// TODO: figure out a way to cancel the Future if this output is dropped.
async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
Expand Down Expand Up @@ -81,36 +117,19 @@
rx.recv().expect("Spawned task transmitter dropped")
}

/// Spawn a task on the runtime and await on it.
/// You should use this when you are spawning compute or IO tasks from the Executor.
pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
self.runtime.block_on(future)
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
// Spawn a task on the runtime
pub fn spawn<F>(&self, future: F) -> RuntimeTask<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
RuntimeTask::new(self.runtime.handle(), future)
}
}

Expand Down Expand Up @@ -183,3 +202,34 @@
Err(_) => None,
}
}

mod tests {

#[tokio::test]
async fn test_spawned_task_cancelled_when_dropped() {
use super::*;

let runtime = get_compute_runtime();
let ptr = Arc::new(AtomicUsize::new(0));
let ptr_clone = ptr.clone();

// Spawn a task that just does work in a loop
// The task should own a reference to the Arc, so the strong count should be 2
let task = async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
ptr_clone.fetch_add(1, Ordering::SeqCst);
}
};
let fut = runtime.spawn(task);
assert!(Arc::strong_count(&ptr) == 2);

// Drop the future, which should cancel the task
drop(fut);

// Wait for a while so that the task can be aborted
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
// The strong count should be 1 now
assert!(Arc::strong_count(&ptr) == 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl IntermediateNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await??;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = sender.send(mp.into()).await;
Expand Down
4 changes: 2 additions & 2 deletions src/daft-local-execution/src/sinks/blocking_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ impl PipelineNode for BlockingSinkNode {
let mut guard = op.lock().await;
rt_context.in_span(&span, || guard.sink(val.as_data()))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await??;
if matches!(result, BlockingSinkStatus::Finished) {
break;
}
}
let finalized_result = compute_runtime
.await_on(async move {
.spawn(async move {
let mut guard = op.lock().await;
rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || {
guard.finalize()
Expand Down
4 changes: 2 additions & 2 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl StreamingSinkNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref()))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await??;
match result {
StreamingSinkOutput::NeedMoreInput(mp) => {
if let Some(mp) = mp {
Expand Down Expand Up @@ -281,7 +281,7 @@ impl PipelineNode for StreamingSinkNode {

let compute_runtime = get_compute_runtime();
let finalized_result = compute_runtime
.await_on(async move {
.spawn(async move {
runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || {
op.finalize(finished_states)
})
Expand Down
2 changes: 1 addition & 1 deletion src/daft-local-execution/src/sources/scan_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
.get_io_client_and_runtime()?;
let scan_tasks = scan_tasks.to_vec();
runtime
.await_on(async move {
.spawn(async move {

Check warning on line 142 in src/daft-local-execution/src/sources/scan_task.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sources/scan_task.rs#L142

Added line #L142 was not covered by tests
let mut delete_map = scan_tasks
.iter()
.flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string()))
Expand Down
7 changes: 3 additions & 4 deletions src/daft-scan/src/glob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,13 @@ fn run_glob_parallel(
let stream = io_client
.glob(glob_input, None, None, None, io_stats, Some(file_format))
.await?;
let results = stream.collect::<Vec<_>>().await;
Result::<_, daft_io::Error>::Ok(futures::stream::iter(results))
let results = stream.map_err(|e| e.into()).collect::<Vec<_>>().await;
DaftResult::Ok(futures::stream::iter(results))
})
}))
.buffered(num_parallel_tasks)
.map(|v| v.map_err(|e| daft_io::Error::JoinError { source: e })?)
.map(|stream| stream?)
.try_flatten()
.map(|v| Ok(v?))
.boxed();

// Construct a static-lifetime BoxStreamIterator
Expand Down
Loading