Skip to content

Commit

Permalink
compute pool
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 2, 2024
1 parent b2dabf6 commit c9a0af6
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 99 deletions.
51 changes: 37 additions & 14 deletions src/daft-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,25 @@ impl Runtime {
Arc::new(Self { runtime })
}

async fn execute_io_task<F>(future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
let s = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else {
"unknown internal error".to_string()
};
DaftError::ComputeError(format!(
"Caught panic when spawning blocking task in io pool {s})"
))
})
}

/// Similar to tokio's Runtime::block_on but requires static lifetime + Send
/// You should use this when you are spawning IO tasks from an Expression Evaluator or in the Executor
pub fn block_on_io_pool<F>(&self, future: F) -> DaftResult<F::Output>
Expand All @@ -474,26 +493,30 @@ impl Runtime {
{
let (tx, rx) = oneshot::channel();
let _join_handle = self.spawn(async move {
let task_output = AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
let s = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else {
"unknown internal error".to_string()
};
DaftError::ComputeError(format!(
"Caught panic when spawning blocking task in io pool {s})"
))
});

let task_output = Self::execute_io_task(future).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped")
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.recv().expect("Spawned task transmitter dropped")
}

/// Similar to block_on_io_pool, but is async and can be awaited
pub async fn block_on_io_pool_async<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let _join_handle = self.spawn(async move {
let task_output = Self::execute_io_task(future).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-local-execution/src/intermediate_ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl IntermediateOperator for AggregateOperator {
&self,
_idx: usize,
input: &PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
_state: &mut dyn IntermediateOperatorState,
) -> DaftResult<IntermediateOperatorResult> {
let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ impl IntermediateOperator for AntiSemiProbeOperator {
&self,
idx: usize,
input: &PipelineResultType,
state: Option<&mut Box<dyn IntermediateOperatorState>>,
state: &mut dyn IntermediateOperatorState,
) -> DaftResult<IntermediateOperatorResult> {
match idx {
0 => {
let state = state
.expect("AntiSemiProbeOperator should have state")
.as_any_mut()
.downcast_mut::<AntiSemiProbeState>()
.expect("AntiSemiProbeOperator state should be AntiSemiProbeState");
Expand All @@ -115,7 +114,6 @@ impl IntermediateOperator for AntiSemiProbeOperator {
}
_ => {
let state = state
.expect("AntiSemiProbeOperator should have state")
.as_any_mut()
.downcast_mut::<AntiSemiProbeState>()
.expect("AntiSemiProbeOperator state should be AntiSemiProbeState");
Expand All @@ -133,7 +131,7 @@ impl IntermediateOperator for AntiSemiProbeOperator {
"AntiSemiProbeOperator"
}

fn make_state(&self) -> Option<Box<dyn IntermediateOperatorState>> {
Some(Box::new(AntiSemiProbeState::Building))
fn make_state(&self) -> Box<dyn IntermediateOperatorState> {
Box::new(AntiSemiProbeState::Building)
}
}
2 changes: 1 addition & 1 deletion src/daft-local-execution/src/intermediate_ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl IntermediateOperator for FilterOperator {
&self,
_idx: usize,
input: &PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
_state: &mut dyn IntermediateOperatorState,
) -> DaftResult<IntermediateOperatorResult> {
let out = input.as_data().filter(&[self.predicate.clone()])?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,11 @@ impl IntermediateOperator for HashJoinProbeOperator {
&self,
idx: usize,
input: &PipelineResultType,
state: Option<&mut Box<dyn IntermediateOperatorState>>,
state: &mut dyn IntermediateOperatorState,
) -> DaftResult<IntermediateOperatorResult> {
match idx {
0 => {
let state = state
.expect("HashJoinProbeOperator should have state")
.as_any_mut()
.downcast_mut::<HashJoinProbeState>()
.expect("HashJoinProbeOperator state should be HashJoinProbeState");
Expand All @@ -246,7 +245,6 @@ impl IntermediateOperator for HashJoinProbeOperator {
}
_ => {
let state = state
.expect("HashJoinProbeOperator should have state")
.as_any_mut()
.downcast_mut::<HashJoinProbeState>()
.expect("HashJoinProbeOperator state should be HashJoinProbeState");
Expand All @@ -267,7 +265,7 @@ impl IntermediateOperator for HashJoinProbeOperator {
"HashJoinProbeOperator"
}

fn make_state(&self) -> Option<Box<dyn IntermediateOperatorState>> {
Some(Box::new(HashJoinProbeState::Building))
fn make_state(&self) -> Box<dyn IntermediateOperatorState> {
Box::new(HashJoinProbeState::Building)
}
}
28 changes: 22 additions & 6 deletions src/daft-local-execution/src/intermediate_ops/intermediate_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use daft_micropartition::MicroPartition;
use tokio::sync::Mutex;
use tracing::{info_span, instrument};

use super::buffer::OperatorBuffer;
Expand All @@ -16,6 +17,12 @@ use crate::{
pub trait IntermediateOperatorState: Send + Sync {
fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
pub struct DefaultIntermediateOperatorState {}
impl IntermediateOperatorState for DefaultIntermediateOperatorState {
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}

pub enum IntermediateOperatorResult {
NeedMoreInput(Option<Arc<MicroPartition>>),
Expand All @@ -28,11 +35,11 @@ pub trait IntermediateOperator: Send + Sync {
&self,
idx: usize,
input: &PipelineResultType,
state: Option<&mut Box<dyn IntermediateOperatorState>>,
state: &mut dyn IntermediateOperatorState,
) -> DaftResult<IntermediateOperatorResult>;
fn name(&self) -> &'static str;
fn make_state(&self) -> Option<Box<dyn IntermediateOperatorState>> {
None
fn make_state(&self) -> Box<dyn IntermediateOperatorState> {
Box::new(DefaultIntermediateOperatorState {})
}
}

Expand Down Expand Up @@ -75,11 +82,20 @@ impl IntermediateNode {
rt_context: Arc<RuntimeStatsContext>,
) -> DaftResult<()> {
let span = info_span!("IntermediateOp::execute");
let mut state = op.make_state();
let compute_runtime = crate::get_compute_runtime()?;
let state = Arc::new(Mutex::new(op.make_state()));
while let Some((idx, morsel)) = receiver.recv().await {
loop {
let result =
rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?;
let state = state.clone();
let op = op.clone();
let morsel = morsel.clone();
let span = span.clone();
let rt_context = rt_context.clone();
let fut = async move {
let mut state_guard = state.lock().await;
rt_context.in_span(&span, || op.execute(idx, &morsel, state_guard.as_mut()))
};
let result = compute_runtime.block_on_compute_pool(fut).await??;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = sender.send(mp.into()).await;
Expand Down
2 changes: 1 addition & 1 deletion src/daft-local-execution/src/intermediate_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl IntermediateOperator for ProjectOperator {
&self,
_idx: usize,
input: &PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
_state: &mut dyn IntermediateOperatorState,
) -> DaftResult<IntermediateOperatorResult> {
let out = input.as_data().eval_expression_list(&self.projection)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
Expand Down
86 changes: 86 additions & 0 deletions src/daft-local-execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,99 @@ mod run;
mod runtime_stats;
mod sinks;
mod sources;
use std::{
future::Future,
panic::AssertUnwindSafe,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, OnceLock,
},
};

use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
pub use run::NativeExecutor;
use snafu::{futures::TryFutureExt, Snafu};
lazy_static! {
pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
}

static COMPUTE_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();

pub type RuntimeRef = Arc<Runtime>;

pub struct Runtime {
runtime: tokio::runtime::Runtime,
}

impl Runtime {
fn new(runtime: tokio::runtime::Runtime) -> RuntimeRef {
Arc::new(Self { runtime })
}

pub async fn block_on_compute_pool<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = tokio::sync::oneshot::channel();
let _join_handle = self.spawn(async move {
let task_output = AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
let s = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else {
"unknown internal error".to_string()
};
DaftError::ComputeError(format!(
"Caught panic when spawning blocking task in compute pool {s})"
))
});

if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped")
}
});
rx.await.expect("Compute pool receiver dropped")
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
}
}

fn init_runtime(num_threads: usize) -> Arc<Runtime> {
std::thread::spawn(move || {
Runtime::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.thread_name_fn(|| {
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("Executor-Worker-{}", id)
})
.build()
.unwrap(),
)
})
.join()
.unwrap()
}

pub fn get_compute_runtime() -> DaftResult<RuntimeRef> {
let runtime = COMPUTE_RUNTIME
.get_or_init(|| init_runtime(*NUM_CPUS))
.clone();
Ok(runtime)
}

pub struct ExecutionRuntimeHandle {
worker_set: tokio::task::JoinSet<crate::Result<()>>,
default_morsel_size: usize,
Expand Down Expand Up @@ -51,6 +136,7 @@ impl ExecutionRuntimeHandle {

#[cfg(feature = "python")]
use pyo3::prelude::*;
use tokio::task::JoinHandle;

#[derive(Debug, Snafu)]
pub enum Error {
Expand Down
13 changes: 2 additions & 11 deletions src/daft-local-execution/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use std::{
collections::HashMap,
fs::File,
io::Write,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};

Expand Down Expand Up @@ -124,14 +121,8 @@ pub fn run_local(
let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?;
let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1));
let handle = std::thread::spawn(move || {
let runtime = tokio::runtime::Builder::new_multi_thread()
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.max_blocking_threads(10)
.thread_name_fn(|| {
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("Executor-Worker-{}", id)
})
.build()
.expect("Failed to create tokio runtime");
runtime.block_on(async {
Expand Down
Loading

0 comments on commit c9a0af6

Please sign in to comment.