Skip to content

Commit

Permalink
encapsulate the task in a struct, map_err to dafterror
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 31, 2024
1 parent 3aef8eb commit ab26801
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 26 deletions.
49 changes: 40 additions & 9 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use std::{
future::Future,
panic::AssertUnwindSafe,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, OnceLock,
},
task::{Context, Poll},
};

use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::{runtime::RuntimeFlavor, task::JoinError};
use tokio::{
runtime::{Handle, RuntimeFlavor},
task::JoinSet,
};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
Expand All @@ -31,6 +36,38 @@ enum PoolType {
IO,
}

/// A task spawned on a `Runtime` that can be awaited for its result.
///
/// This is a wrapper around a `JoinSet`, which allows the task to be cancelled
/// simply by dropping the `RuntimeTask`.
pub struct RuntimeTask<T> {
// Set holding the spawned task; `RuntimeTask::new` spawns exactly one future into it.
joinset: JoinSet<T>,
}

impl<T> RuntimeTask<T> {
    /// Spawns `future` on the runtime referred to by `handle` and wraps it in a
    /// `RuntimeTask`.
    ///
    /// The future is spawned into a freshly created `JoinSet` owned by the
    /// returned value, so that set contains exactly this one task.
    pub fn new<F>(handle: &Handle, future: F) -> Self
    where
        T: Send + 'static,
        F: Future<Output = T> + Send + 'static,
    {
        let mut tasks = JoinSet::new();
        // The value returned by `spawn_on` is not needed: the task is tracked
        // through the set itself.
        let _ = tasks.spawn_on(future, handle);
        Self { joinset: tasks }
    }
}

// Awaiting a `RuntimeTask` resolves to the spawned future's output wrapped in a
// `DaftResult`; an error reported when joining the task is surfaced as
// `DaftError::External`.
impl<T: Send + 'static> Future for RuntimeTask<T> {
type Output = DaftResult<T>;

// Polls the inner `JoinSet` for its next finished task and translates the outcome.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.joinset).poll_join_next(cx) {
Poll::Ready(Some(result)) => {
// Convert the join error (if any) into the crate-wide error type.
Poll::Ready(result.map_err(|e| DaftError::External(e.into())))
}
// The set reported that it has no tasks. `RuntimeTask::new` always spawns one,
// so this is treated as a broken invariant.
// NOTE(review): presumably this arm is also reached if the task is polled again
// after it already returned `Ready` — confirm against tokio's `JoinSet` docs.
Poll::Ready(None) => panic!("JoinSet unexpectedly empty"),

Check warning on line 65 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L65

Added line #L65 was not covered by tests
// The task has not finished yet.
Poll::Pending => Poll::Pending,
}
}
}

pub struct Runtime {
runtime: tokio::runtime::Runtime,
pool_type: PoolType,
Expand Down Expand Up @@ -87,18 +124,12 @@ impl Runtime {
}

// Spawn a task on the runtime
pub fn spawn<F>(
&self,
future: F,
) -> impl Future<Output = Result<F::Output, JoinError>> + Send + 'static
pub fn spawn<F>(&self, future: F) -> RuntimeTask<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
// Spawn it on a joinset on the runtime, such that if the future gets dropped, the task is cancelled
let mut joinset = tokio::task::JoinSet::new();
joinset.spawn_on(future, self.runtime.handle());
async move { joinset.join_next().await.expect("just spawned task") }.boxed()
RuntimeTask::new(self.runtime.handle(), future)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use common_runtime::get_compute_runtime;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::{info_span, instrument};

use super::buffer::OperatorBuffer;
use crate::{
channel::{create_channel, PipelineChannel, Receiver, Sender},
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS,
ExecutionRuntimeHandle, NUM_CPUS,
};

pub(crate) trait DynIntermediateOpState: Send + Sync {
Expand Down Expand Up @@ -120,7 +119,7 @@ impl IntermediateNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper))
};
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
let result = compute_runtime.spawn(fut).await??;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = sender.send(mp.into()).await;
Expand Down
8 changes: 3 additions & 5 deletions src/daft-local-execution/src/sinks/blocking_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use common_runtime::get_compute_runtime;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::info_span;

use crate::{
channel::PipelineChannel,
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::RuntimeStatsContext,
ExecutionRuntimeHandle, JoinSnafu,
ExecutionRuntimeHandle,
};
pub enum BlockingSinkStatus {
NeedMoreInput,
Expand Down Expand Up @@ -103,7 +102,7 @@ impl PipelineNode for BlockingSinkNode {
let mut guard = op.lock().await;
rt_context.in_span(&span, || guard.sink(val.as_data()))
};
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
let result = compute_runtime.spawn(fut).await??;
if matches!(result, BlockingSinkStatus::Finished) {
break;
}
Expand All @@ -115,8 +114,7 @@ impl PipelineNode for BlockingSinkNode {
guard.finalize()
})
})
.await
.context(JoinSnafu)??;
.await??;
if let Some(part) = finalized_result {
let _ = destination_sender.send(part).await;
}
Expand Down
5 changes: 2 additions & 3 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl StreamingSinkNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref()))
};
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
let result = compute_runtime.spawn(fut).await??;
match result {
StreamingSinkOutput::NeedMoreInput(mp) => {
if let Some(mp) = mp {
Expand Down Expand Up @@ -286,8 +286,7 @@ impl PipelineNode for StreamingSinkNode {
op.finalize(finished_states)
})
})
.await
.context(JoinSnafu)??;
.await??;
if let Some(res) = finalized_result {
let _ = destination_sender.send(res.into()).await;
}
Expand Down
3 changes: 1 addition & 2 deletions src/daft-local-execution/src/sources/scan_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ async fn get_delete_map(
}
Ok(Some(delete_map))
})
.await
.context(JoinSnafu)?
.await?
}

async fn stream_scan_task(
Expand Down
7 changes: 3 additions & 4 deletions src/daft-scan/src/glob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,13 @@ fn run_glob_parallel(
let stream = io_client
.glob(glob_input, None, None, None, io_stats, Some(file_format))
.await?;
let results = stream.collect::<Vec<_>>().await;
Result::<_, daft_io::Error>::Ok(futures::stream::iter(results))
let results = stream.map_err(|e| e.into()).collect::<Vec<_>>().await;
DaftResult::Ok(futures::stream::iter(results))
})
}))
.buffered(num_parallel_tasks)
.map(|v| v.map_err(|e| daft_io::Error::JoinError { source: e })?)
.map(|stream| stream?)
.try_flatten()
.map(|v| Ok(v?))
.boxed();

// Construct a static-lifetime BoxStreamIterator
Expand Down

0 comments on commit ab26801

Please sign in to comment.