From 3aef8eb7ba1b3b18d120c1cccbfb83aeedb70540 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 25 Oct 2024 14:07:55 -0700 Subject: [PATCH 1/2] cancel on drop --- src/common/runtime/src/lib.rs | 63 ++++++++++++------- .../src/intermediate_ops/intermediate_op.rs | 5 +- .../src/sinks/blocking_sink.rs | 10 +-- .../src/sinks/streaming_sink.rs | 7 ++- .../src/sources/scan_task.rs | 5 +- 5 files changed, 57 insertions(+), 33 deletions(-) diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 22a7fcdec9..5121895d42 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -10,7 +10,7 @@ use std::{ use common_error::{DaftError, DaftResult}; use futures::FutureExt; use lazy_static::lazy_static; -use tokio::{runtime::RuntimeFlavor, task::JoinHandle}; +use tokio::{runtime::RuntimeFlavor, task::JoinError}; lazy_static! { static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); @@ -41,7 +41,6 @@ impl Runtime { Arc::new(Self { runtime, pool_type }) } - // TODO: figure out a way to cancel the Future if this output is dropped. async fn execute_task(future: F, pool_type: PoolType) -> DaftResult where F: Future + Send + 'static, @@ -81,36 +80,25 @@ impl Runtime { rx.recv().expect("Spawned task transmitter dropped") } - /// Spawn a task on the runtime and await on it. - /// You should use this when you are spawning compute or IO tasks from the Executor. - pub async fn await_on(&self, future: F) -> DaftResult - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - let (tx, rx) = oneshot::channel(); - let pool_type = self.pool_type; - let _join_handle = self.spawn(async move { - let task_output = Self::execute_task(future, pool_type).await; - if tx.send(task_output).is_err() { - log::warn!("Spawned task output ignored: receiver dropped"); - } - }); - rx.await.expect("Spawned task transmitter dropped") - } - /// Blocks current thread to compute future. Can not be called in tokio runtime context /// pub fn block_on_current_thread(&self, future: F) -> F::Output { self.runtime.block_on(future) } - pub fn spawn(&self, future: F) -> JoinHandle + // Spawn a task on the runtime + pub fn spawn( + &self, + future: F, + ) -> impl Future> + Send + 'static where F: Future + Send + 'static, F::Output: Send + 'static, { - self.runtime.spawn(future) + // Spawn it on a joinset on the runtime, such that if the future gets dropped, the task is cancelled + let mut joinset = tokio::task::JoinSet::new(); + joinset.spawn_on(future, self.runtime.handle()); + async move { joinset.join_next().await.expect("just spawned task") }.boxed() } } @@ -183,3 +171,34 @@ pub fn get_io_pool_num_threads() -> Option { Err(_) => None, } } + +mod tests { + + #[tokio::test] + async fn test_spawned_task_cancelled_when_dropped() { + use super::*; + + let runtime = get_compute_runtime(); + let ptr = Arc::new(AtomicUsize::new(0)); + let ptr_clone = ptr.clone(); + + // Spawn a task that just does work in a loop + // The task should own a reference to the Arc, so the strong count should be 2 + let task = async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + ptr_clone.fetch_add(1, Ordering::SeqCst); + } + }; + let fut = runtime.spawn(task); + assert!(Arc::strong_count(&ptr) == 2); + + // Drop the future, which should cancel the task + drop(fut); + + // Wait for a while so that the task can be aborted + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + // The strong count should be 1 now + assert!(Arc::strong_count(&ptr) == 1); + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 412d7641a7..155b121066 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -4,6 +4,7 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; +use snafu::ResultExt; use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; @@ -11,7 +12,7 @@ use crate::{ channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, - ExecutionRuntimeHandle, NUM_CPUS, + ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, }; pub(crate) trait DynIntermediateOpState: Send + Sync { @@ -119,7 +120,7 @@ impl IntermediateNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper)) }; - let result = compute_runtime.await_on(fut).await??; + let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 3fcbf8d660..b161206fed 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -4,13 +4,14 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; +use snafu::ResultExt; use tracing::info_span; use crate::{ channel::PipelineChannel, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, + ExecutionRuntimeHandle, JoinSnafu, }; pub enum BlockingSinkStatus { NeedMoreInput, @@ -102,19 +103,20 @@ impl PipelineNode for BlockingSinkNode { let mut guard = op.lock().await; rt_context.in_span(&span, || guard.sink(val.as_data())) }; - let result = compute_runtime.await_on(fut).await??; + let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; if matches!(result, BlockingSinkStatus::Finished) { break; } } let finalized_result = compute_runtime - .await_on(async move { + .spawn(async move { let mut guard = op.lock().await; rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || { guard.finalize() }) }) - .await??; + .await + .context(JoinSnafu)??; if let Some(part) = finalized_result { let _ = destination_sender.send(part).await; } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 102fd39618..81fd7a47d5 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -125,7 +125,7 @@ impl StreamingSinkNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref())) }; - let result = compute_runtime.await_on(fut).await??; + let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; match result { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { @@ -281,12 +281,13 @@ impl PipelineNode for StreamingSinkNode { let compute_runtime = get_compute_runtime(); let finalized_result = compute_runtime - .await_on(async move { + .spawn(async move { runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || { op.finalize(finished_states) }) }) - .await??; + .await + .context(JoinSnafu)??; if let Some(res) = finalized_result { let _ = destination_sender.send(res.into()).await; } diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 3be2f61691..a4d30e6cf3 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -139,7 +139,7 @@ async fn get_delete_map( .get_io_client_and_runtime()?; let scan_tasks = scan_tasks.to_vec(); runtime - .await_on(async move { + .spawn(async move { let mut delete_map = scan_tasks .iter() .flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string())) @@ -183,7 +183,8 @@ async fn get_delete_map( } Ok(Some(delete_map)) }) - .await? + .await + .context(JoinSnafu)? } async fn stream_scan_task( From ab268013771c223bcea2a31c039a6d9b11e0ab22 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 31 Oct 2024 12:16:15 -0700 Subject: [PATCH 2/2] encapsulate the task in a struct, map_err to dafterror --- src/common/runtime/src/lib.rs | 49 +++++++++++++++---- .../src/intermediate_ops/intermediate_op.rs | 5 +- .../src/sinks/blocking_sink.rs | 8 ++- .../src/sinks/streaming_sink.rs | 5 +- .../src/sources/scan_task.rs | 3 +- src/daft-scan/src/glob.rs | 7 ++- 6 files changed, 51 insertions(+), 26 deletions(-) diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 5121895d42..41803f06d7 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -1,16 +1,21 @@ use std::{ future::Future, panic::AssertUnwindSafe, + pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Arc, OnceLock, }, + task::{Context, Poll}, }; use common_error::{DaftError, DaftResult}; use futures::FutureExt; use lazy_static::lazy_static; -use tokio::{runtime::RuntimeFlavor, task::JoinError}; +use tokio::{ + runtime::{Handle, RuntimeFlavor}, + task::JoinSet, +}; lazy_static! { static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); @@ -31,6 +36,38 @@ enum PoolType { IO, } +// A spawned task on a Runtime that can be awaited +// This is a wrapper around a JoinSet that allows us to cancel the task by dropping it +pub struct RuntimeTask { + joinset: JoinSet, +} + +impl RuntimeTask { + pub fn new(handle: &Handle, future: F) -> Self + where + F: Future + Send + 'static, + T: Send + 'static, + { + let mut joinset = JoinSet::new(); + joinset.spawn_on(future, handle); + Self { joinset } + } +} + +impl Future for RuntimeTask { + type Output = DaftResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.joinset).poll_join_next(cx) { + Poll::Ready(Some(result)) => { + Poll::Ready(result.map_err(|e| DaftError::External(e.into()))) + } + Poll::Ready(None) => panic!("JoinSet unexpectedly empty"), + Poll::Pending => Poll::Pending, + } + } +} + pub struct Runtime { runtime: tokio::runtime::Runtime, pool_type: PoolType, @@ -87,18 +124,12 @@ impl Runtime { } // Spawn a task on the runtime - pub fn spawn( - &self, - future: F, - ) -> impl Future> + Send + 'static + pub fn spawn(&self, future: F) -> RuntimeTask where F: Future + Send + 'static, F::Output: Send + 'static, { - // Spawn it on a joinset on the runtime, such that if the future gets dropped, the task is cancelled - let mut joinset = tokio::task::JoinSet::new(); - joinset.spawn_on(future, self.runtime.handle()); - async move { joinset.join_next().await.expect("just spawned task") }.boxed() + RuntimeTask::new(self.runtime.handle(), future) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 155b121066..de9d41d740 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -4,7 +4,6 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use snafu::ResultExt; use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; @@ -12,7 +11,7 @@ use crate::{ channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, + ExecutionRuntimeHandle, NUM_CPUS, }; pub(crate) trait DynIntermediateOpState: Send + Sync { @@ -120,7 +119,7 @@ impl IntermediateNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper)) }; - let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; + let result = compute_runtime.spawn(fut).await??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index b161206fed..00ffb04a50 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -4,14 +4,13 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use snafu::ResultExt; use tracing::info_span; use crate::{ channel::PipelineChannel, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, JoinSnafu, + ExecutionRuntimeHandle, }; pub enum BlockingSinkStatus { NeedMoreInput, @@ -103,7 +102,7 @@ impl PipelineNode for BlockingSinkNode { let mut guard = op.lock().await; rt_context.in_span(&span, || guard.sink(val.as_data())) }; - let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; + let result = compute_runtime.spawn(fut).await??; if matches!(result, BlockingSinkStatus::Finished) { break; } @@ -115,8 +114,7 @@ impl PipelineNode for BlockingSinkNode { guard.finalize() }) }) - .await - .context(JoinSnafu)??; + .await??; if let Some(part) = finalized_result { let _ = destination_sender.send(part).await; } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 81fd7a47d5..9ebab6f1a3 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -125,7 +125,7 @@ impl StreamingSinkNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref())) }; - let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; + let result = compute_runtime.spawn(fut).await??; match result { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { @@ -286,8 +286,7 @@ impl PipelineNode for StreamingSinkNode { op.finalize(finished_states) }) }) - .await - .context(JoinSnafu)??; + .await??; if let Some(res) = finalized_result { let _ = destination_sender.send(res.into()).await; } diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index a4d30e6cf3..1bd4bac0d7 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -183,8 +183,7 @@ async fn get_delete_map( } Ok(Some(delete_map)) }) - .await - .context(JoinSnafu)? + .await? } async fn stream_scan_task( diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 883de475eb..f92ab78b1e 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -111,14 +111,13 @@ fn run_glob_parallel( let stream = io_client .glob(glob_input, None, None, None, io_stats, Some(file_format)) .await?; - let results = stream.collect::>().await; - Result::<_, daft_io::Error>::Ok(futures::stream::iter(results)) + let results = stream.map_err(|e| e.into()).collect::>().await; + DaftResult::Ok(futures::stream::iter(results)) }) })) .buffered(num_parallel_tasks) - .map(|v| v.map_err(|e| daft_io::Error::JoinError { source: e })?) + .map(|stream| stream?) .try_flatten() - .map(|v| Ok(v?)) .boxed(); // Construct a static-lifetime BoxStreamIterator