Skip to content

Commit

Permalink
cancel on drop
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 25, 2024
1 parent 5b450fb commit 3aef8eb
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 33 deletions.
63 changes: 41 additions & 22 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::{runtime::RuntimeFlavor, task::JoinHandle};
use tokio::{runtime::RuntimeFlavor, task::JoinError};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
Expand Down Expand Up @@ -41,7 +41,6 @@ impl Runtime {
Arc::new(Self { runtime, pool_type })
}

// TODO: figure out a way to cancel the Future if this output is dropped.
async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
Expand Down Expand Up @@ -81,36 +80,25 @@ impl Runtime {
rx.recv().expect("Spawned task transmitter dropped")
}

/// Spawn a task on the runtime and await on it.
/// You should use this when you are spawning compute or IO tasks from the Executor.
pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
self.runtime.block_on(future)
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
// Spawn a task on the runtime
pub fn spawn<F>(
&self,
future: F,
) -> impl Future<Output = Result<F::Output, JoinError>> + Send + 'static
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
// Spawn it on a joinset on the runtime, such that if the future gets dropped, the task is cancelled
let mut joinset = tokio::task::JoinSet::new();
joinset.spawn_on(future, self.runtime.handle());
async move { joinset.join_next().await.expect("just spawned task") }.boxed()
}
}

Expand Down Expand Up @@ -183,3 +171,34 @@ pub fn get_io_pool_num_threads() -> Option<usize> {
Err(_) => None,
}
}

mod tests {

#[tokio::test]
async fn test_spawned_task_cancelled_when_dropped() {
use super::*;

let runtime = get_compute_runtime();
let ptr = Arc::new(AtomicUsize::new(0));
let ptr_clone = ptr.clone();

// Spawn a task that just does work in a loop
// The task should own a reference to the Arc, so the strong count should be 2
let task = async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
ptr_clone.fetch_add(1, Ordering::SeqCst);
}
};
let fut = runtime.spawn(task);
assert!(Arc::strong_count(&ptr) == 2);

// Drop the future, which should cancel the task
drop(fut);

// Wait for a while so that the task can be aborted
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
// The strong count should be 1 now
assert!(Arc::strong_count(&ptr) == 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use common_runtime::get_compute_runtime;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::{info_span, instrument};

use super::buffer::OperatorBuffer;
use crate::{
channel::{create_channel, PipelineChannel, Receiver, Sender},
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
ExecutionRuntimeHandle, NUM_CPUS,
ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS,
};

pub(crate) trait DynIntermediateOpState: Send + Sync {
Expand Down Expand Up @@ -119,7 +120,7 @@ impl IntermediateNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = sender.send(mp.into()).await;
Expand Down
10 changes: 6 additions & 4 deletions src/daft-local-execution/src/sinks/blocking_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use common_runtime::get_compute_runtime;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::info_span;

use crate::{
channel::PipelineChannel,
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::RuntimeStatsContext,
ExecutionRuntimeHandle,
ExecutionRuntimeHandle, JoinSnafu,
};
pub enum BlockingSinkStatus {
NeedMoreInput,
Expand Down Expand Up @@ -102,19 +103,20 @@ impl PipelineNode for BlockingSinkNode {
let mut guard = op.lock().await;
rt_context.in_span(&span, || guard.sink(val.as_data()))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
if matches!(result, BlockingSinkStatus::Finished) {
break;
}
}
let finalized_result = compute_runtime
.await_on(async move {
.spawn(async move {
let mut guard = op.lock().await;
rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || {
guard.finalize()
})
})
.await??;
.await
.context(JoinSnafu)??;
if let Some(part) = finalized_result {
let _ = destination_sender.send(part).await;
}
Expand Down
7 changes: 4 additions & 3 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl StreamingSinkNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref()))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
match result {
StreamingSinkOutput::NeedMoreInput(mp) => {
if let Some(mp) = mp {
Expand Down Expand Up @@ -281,12 +281,13 @@ impl PipelineNode for StreamingSinkNode {

let compute_runtime = get_compute_runtime();
let finalized_result = compute_runtime
.await_on(async move {
.spawn(async move {
runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || {
op.finalize(finished_states)
})
})
.await??;
.await
.context(JoinSnafu)??;
if let Some(res) = finalized_result {
let _ = destination_sender.send(res.into()).await;
}
Expand Down
5 changes: 3 additions & 2 deletions src/daft-local-execution/src/sources/scan_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async fn get_delete_map(
.get_io_client_and_runtime()?;
let scan_tasks = scan_tasks.to_vec();
runtime
.await_on(async move {
.spawn(async move {

Check warning on line 142 in src/daft-local-execution/src/sources/scan_task.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sources/scan_task.rs#L142

Added line #L142 was not covered by tests
let mut delete_map = scan_tasks
.iter()
.flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string()))
Expand Down Expand Up @@ -183,7 +183,8 @@ async fn get_delete_map(
}
Ok(Some(delete_map))
})
.await?
.await
.context(JoinSnafu)?

Check warning on line 187 in src/daft-local-execution/src/sources/scan_task.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sources/scan_task.rs#L186-L187

Added lines #L186 - L187 were not covered by tests
}

async fn stream_scan_task(
Expand Down

0 comments on commit 3aef8eb

Please sign in to comment.