diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 22a7fcdec9..41803f06d7 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -1,16 +1,21 @@ use std::{ future::Future, panic::AssertUnwindSafe, + pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Arc, OnceLock, }, + task::{Context, Poll}, }; use common_error::{DaftError, DaftResult}; use futures::FutureExt; use lazy_static::lazy_static; -use tokio::{runtime::RuntimeFlavor, task::JoinHandle}; +use tokio::{ + runtime::{Handle, RuntimeFlavor}, + task::JoinSet, +}; lazy_static! { static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); @@ -31,6 +36,38 @@ enum PoolType { IO, } +// A spawned task on a Runtime that can be awaited +// This is a wrapper around a JoinSet that allows us to cancel the task by dropping it +pub struct RuntimeTask { + joinset: JoinSet, +} + +impl RuntimeTask { + pub fn new(handle: &Handle, future: F) -> Self + where + F: Future + Send + 'static, + T: Send + 'static, + { + let mut joinset = JoinSet::new(); + joinset.spawn_on(future, handle); + Self { joinset } + } +} + +impl Future for RuntimeTask { + type Output = DaftResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.joinset).poll_join_next(cx) { + Poll::Ready(Some(result)) => { + Poll::Ready(result.map_err(|e| DaftError::External(e.into()))) + } + Poll::Ready(None) => panic!("JoinSet unexpectedly empty"), + Poll::Pending => Poll::Pending, + } + } +} + pub struct Runtime { runtime: tokio::runtime::Runtime, pool_type: PoolType, @@ -41,7 +78,6 @@ impl Runtime { Arc::new(Self { runtime, pool_type }) } - // TODO: figure out a way to cancel the Future if this output is dropped. async fn execute_task(future: F, pool_type: PoolType) -> DaftResult where F: Future + Send + 'static, @@ -81,36 +117,19 @@ impl Runtime { rx.recv().expect("Spawned task transmitter dropped") } - /// Spawn a task on the runtime and await on it. - /// You should use this when you are spawning compute or IO tasks from the Executor. - pub async fn await_on(&self, future: F) -> DaftResult - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - let (tx, rx) = oneshot::channel(); - let pool_type = self.pool_type; - let _join_handle = self.spawn(async move { - let task_output = Self::execute_task(future, pool_type).await; - if tx.send(task_output).is_err() { - log::warn!("Spawned task output ignored: receiver dropped"); - } - }); - rx.await.expect("Spawned task transmitter dropped") - } - /// Blocks current thread to compute future. Can not be called in tokio runtime context /// pub fn block_on_current_thread(&self, future: F) -> F::Output { self.runtime.block_on(future) } - pub fn spawn(&self, future: F) -> JoinHandle + // Spawn a task on the runtime + pub fn spawn(&self, future: F) -> RuntimeTask where F: Future + Send + 'static, F::Output: Send + 'static, { - self.runtime.spawn(future) + RuntimeTask::new(self.runtime.handle(), future) } } @@ -183,3 +202,34 @@ pub fn get_io_pool_num_threads() -> Option { Err(_) => None, } } + +mod tests { + + #[tokio::test] + async fn test_spawned_task_cancelled_when_dropped() { + use super::*; + + let runtime = get_compute_runtime(); + let ptr = Arc::new(AtomicUsize::new(0)); + let ptr_clone = ptr.clone(); + + // Spawn a task that just does work in a loop + // The task should own a reference to the Arc, so the strong count should be 2 + let task = async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + ptr_clone.fetch_add(1, Ordering::SeqCst); + } + }; + let fut = runtime.spawn(task); + assert!(Arc::strong_count(&ptr) == 2); + + // Drop the future, which should cancel the task + drop(fut); + + // Wait for a while so that the task can be aborted + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + // The strong count should be 1 now + assert!(Arc::strong_count(&ptr) == 1); + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index d7503fa33c..d268120c06 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -129,7 +129,7 @@ impl IntermediateNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper)) }; - let result = compute_runtime.await_on(fut).await??; + let result = compute_runtime.spawn(fut).await??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 7835fbe138..51ef590d09 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -80,7 +80,7 @@ impl BlockingSinkNode { let span = span.clone(); let rt_context = rt_context.clone(); let fut = async move { rt_context.in_span(&span, || op.sink(morsel.as_data(), state)) }; - let result = compute_runtime.await_on(fut).await??; + let result = compute_runtime.spawn(fut).await??; match result { BlockingSinkStatus::NeedMoreInput(new_state) => { state = new_state; @@ -179,7 +179,7 @@ impl PipelineNode for BlockingSinkNode { let compute_runtime = get_compute_runtime(); let finalized_result = compute_runtime - .await_on(async move { + .spawn(async move { runtime_stats.in_span(&info_span!("BlockingSinkNode::finalize"), || { op.finalize(finished_states) }) diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 102fd39618..9ebab6f1a3 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -125,7 +125,7 @@ impl StreamingSinkNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref())) }; - let result = compute_runtime.await_on(fut).await??; + let result = compute_runtime.spawn(fut).await??; match result { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { @@ -281,7 +281,7 @@ impl PipelineNode for StreamingSinkNode { let compute_runtime = get_compute_runtime(); let finalized_result = compute_runtime - .await_on(async move { + .spawn(async move { runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || { op.finalize(finished_states) }) diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 3be2f61691..1bd4bac0d7 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -139,7 +139,7 @@ async fn get_delete_map( .get_io_client_and_runtime()?; let scan_tasks = scan_tasks.to_vec(); runtime - .await_on(async move { + .spawn(async move { let mut delete_map = scan_tasks .iter() .flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string())) diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 883de475eb..f92ab78b1e 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -111,14 +111,13 @@ fn run_glob_parallel( let stream = io_client .glob(glob_input, None, None, None, io_stats, Some(file_format)) .await?; - let results = stream.collect::>().await; - Result::<_, daft_io::Error>::Ok(futures::stream::iter(results)) + let results = stream.map_err(|e| e.into()).collect::>().await; + DaftResult::Ok(futures::stream::iter(results)) }) })) .buffered(num_parallel_tasks) - .map(|v| v.map_err(|e| daft_io::Error::JoinError { source: e })?) + .map(|stream| stream?) .try_flatten() - .map(|v| Ok(v?)) .boxed(); // Construct a static-lifetime BoxStreamIterator