diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 5121895d42..41803f06d7 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -1,16 +1,21 @@ use std::{ future::Future, panic::AssertUnwindSafe, + pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Arc, OnceLock, }, + task::{Context, Poll}, }; use common_error::{DaftError, DaftResult}; use futures::FutureExt; use lazy_static::lazy_static; -use tokio::{runtime::RuntimeFlavor, task::JoinError}; +use tokio::{ + runtime::{Handle, RuntimeFlavor}, + task::JoinSet, +}; lazy_static! { static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); @@ -31,6 +36,38 @@ enum PoolType { IO, } +// A spawned task on a Runtime that can be awaited +// This is a wrapper around a JoinSet that allows us to cancel the task by dropping it +pub struct RuntimeTask { + joinset: JoinSet, +} + +impl RuntimeTask { + pub fn new(handle: &Handle, future: F) -> Self + where + F: Future + Send + 'static, + T: Send + 'static, + { + let mut joinset = JoinSet::new(); + joinset.spawn_on(future, handle); + Self { joinset } + } +} + +impl Future for RuntimeTask { + type Output = DaftResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.joinset).poll_join_next(cx) { + Poll::Ready(Some(result)) => { + Poll::Ready(result.map_err(|e| DaftError::External(e.into()))) + } + Poll::Ready(None) => panic!("JoinSet unexpectedly empty"), + Poll::Pending => Poll::Pending, + } + } +} + pub struct Runtime { runtime: tokio::runtime::Runtime, pool_type: PoolType, @@ -87,18 +124,12 @@ impl Runtime { } // Spawn a task on the runtime - pub fn spawn( - &self, - future: F, - ) -> impl Future> + Send + 'static + pub fn spawn(&self, future: F) -> RuntimeTask where F: Future + Send + 'static, F::Output: Send + 'static, { - // Spawn it on a joinset on the runtime, such that if the future gets dropped, the task is cancelled - let mut joinset = tokio::task::JoinSet::new(); - joinset.spawn_on(future, self.runtime.handle()); - async move { joinset.join_next().await.expect("just spawned task") }.boxed() + RuntimeTask::new(self.runtime.handle(), future) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 155b121066..de9d41d740 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -4,7 +4,6 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use snafu::ResultExt; use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; @@ -12,7 +11,7 @@ use crate::{ channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, + ExecutionRuntimeHandle, NUM_CPUS, }; pub(crate) trait DynIntermediateOpState: Send + Sync { @@ -120,7 +119,7 @@ impl IntermediateNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper)) }; - let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; + let result = compute_runtime.spawn(fut).await??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index b161206fed..00ffb04a50 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -4,14 +4,13 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use snafu::ResultExt; use tracing::info_span; use crate::{ channel::PipelineChannel, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, JoinSnafu, + ExecutionRuntimeHandle, }; pub enum BlockingSinkStatus { NeedMoreInput, @@ -103,7 +102,7 @@ impl PipelineNode for BlockingSinkNode { let mut guard = op.lock().await; rt_context.in_span(&span, || guard.sink(val.as_data())) }; - let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; + let result = compute_runtime.spawn(fut).await??; if matches!(result, BlockingSinkStatus::Finished) { break; } @@ -115,8 +114,7 @@ impl PipelineNode for BlockingSinkNode { guard.finalize() }) }) - .await - .context(JoinSnafu)??; + .await??; if let Some(part) = finalized_result { let _ = destination_sender.send(part).await; } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 81fd7a47d5..9ebab6f1a3 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -125,7 +125,7 @@ impl StreamingSinkNode { let fut = async move { rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref())) }; - let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??; + let result = compute_runtime.spawn(fut).await??; match result { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { @@ -286,8 +286,7 @@ impl PipelineNode for StreamingSinkNode { op.finalize(finished_states) }) }) - .await - .context(JoinSnafu)??; + .await??; if let Some(res) = finalized_result { let _ = destination_sender.send(res.into()).await; } diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index a4d30e6cf3..1bd4bac0d7 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -183,8 +183,7 @@ async fn get_delete_map( } Ok(Some(delete_map)) }) - .await - .context(JoinSnafu)? + .await? } async fn stream_scan_task( diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 883de475eb..f92ab78b1e 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -111,14 +111,13 @@ fn run_glob_parallel( let stream = io_client .glob(glob_input, None, None, None, io_stats, Some(file_format)) .await?; - let results = stream.collect::>().await; - Result::<_, daft_io::Error>::Ok(futures::stream::iter(results)) + let results = stream.map_err(|e| e.into()).collect::>().await; + DaftResult::Ok(futures::stream::iter(results)) }) })) .buffered(num_parallel_tasks) - .map(|v| v.map_err(|e| daft_io::Error::JoinError { source: e })?) + .map(|stream| stream?) .try_flatten() - .map(|v| Ok(v?)) .boxed(); // Construct a static-lifetime BoxStreamIterator