diff --git a/Cargo.lock b/Cargo.lock
index a4aafc4a76..bf9dfaeb47 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2213,10 +2213,13 @@ dependencies = [
  "indexmap 2.5.0",
  "lazy_static",
  "log",
+ "loole",
  "num-format",
+ "pin-project",
  "pyo3",
  "snafu",
  "tokio",
+ "tokio-util",
  "tracing",
 ]

@@ -4003,6 +4006,16 @@ version = "0.4.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"

+[[package]]
+name = "loole"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2998397c725c822c6b2ba605fd9eb4c6a7a0810f1629ba3cc232ef4f0308d96"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+]
+
 [[package]]
 name = "loom"
 version = "0.7.2"
diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs
index 41803f06d7..54e2555e61 100644
--- a/src/common/runtime/src/lib.rs
+++ b/src/common/runtime/src/lib.rs
@@ -58,7 +58,7 @@ impl<T> Future for RuntimeTask<T> {
     type Output = DaftResult<T>;

     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        match Pin::new(&mut self.joinset).poll_join_next(cx) {
+        match self.joinset.poll_join_next(cx) {
             Poll::Ready(Some(result)) => {
                 Poll::Ready(result.map_err(|e| DaftError::External(e.into())))
             }
diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml
index ff4f15e8b5..06ddd3efee 100644
--- a/src/daft-local-execution/Cargo.toml
+++ b/src/daft-local-execution/Cargo.toml
@@ -25,10 +25,13 @@ futures = {workspace = true}
 indexmap = {workspace = true}
 lazy_static = {workspace = true}
 log = {workspace = true}
+loole = "0.4.0"
 num-format = "0.4.4"
+pin-project = "1"
 pyo3 = {workspace = true, optional = true}
 snafu = {workspace = true}
 tokio = {workspace = true}
+tokio-util = {workspace = true}
 tracing = {workspace = true}

 [features]
diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs
index f16e3bd061..7a58e79ade 100644
--- a/src/daft-local-execution/src/channel.rs
+++ b/src/daft-local-execution/src/channel.rs
@@ -1,92 +1,60 @@
-use std::sync::Arc;
-
-use crate::{
-    pipeline::PipelineResultType,
-    runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
-};
-
-pub type Sender<T> = tokio::sync::mpsc::Sender<T>;
-pub type Receiver<T> = tokio::sync::mpsc::Receiver<T>;
-
-pub fn create_channel<T>(buffer_size: usize) -> (Sender<T>, Receiver<T>) {
-    tokio::sync::mpsc::channel(buffer_size)
-}
-
-pub struct PipelineChannel {
-    sender: PipelineSender,
-    receiver: PipelineReceiver,
-}
-
-impl PipelineChannel {
-    pub fn new(buffer_size: usize, in_order: bool) -> Self {
-        if in_order {
-            let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip();
-            let sender = PipelineSender::InOrder(RoundRobinSender::new(senders));
-            let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers));
-            Self { sender, receiver }
-        } else {
-            let (sender, receiver) = create_channel(buffer_size);
-            let sender = PipelineSender::OutOfOrder(sender);
-            let receiver = PipelineReceiver::OutOfOrder(receiver);
-            Self { sender, receiver }
-        }
-    }
-
-    fn get_next_sender(&mut self) -> Sender<PipelineResultType> {
-        match &mut self.sender {
-            PipelineSender::InOrder(rr) => rr.get_next_sender(),
-            PipelineSender::OutOfOrder(sender) => sender.clone(),
-        }
-    }
-
-    pub(crate) fn get_next_sender_with_stats(
-        &mut self,
-        rt: &Arc<RuntimeStatsContext>,
-    ) -> CountingSender {
-        CountingSender::new(self.get_next_sender(), rt.clone())
+#[derive(Clone)]
+pub(crate) struct Sender<T>(loole::Sender<T>);
+impl<T> Sender<T> {
+    pub(crate) async fn send(&self, val: T) -> Result<(), loole::SendError<T>> {
+        self.0.send_async(val).await
     }
+}

-    pub fn get_receiver(self) -> PipelineReceiver {
-        self.receiver
+#[derive(Clone)]
+pub(crate) struct Receiver<T>(loole::Receiver<T>);
+impl<T> Receiver<T> {
+    pub(crate) async fn recv(&self) -> Option<T> {
+        self.0.recv_async().await.ok()
     }

-    pub(crate) fn get_receiver_with_stats(self, rt: &Arc<RuntimeStatsContext>) -> CountingReceiver {
-        CountingReceiver::new(self.get_receiver(), rt.clone())
+    pub(crate) fn blocking_recv(&self) -> Option<T> {
+        self.0.recv().ok()
     }
 }

-pub enum PipelineSender {
-    InOrder(RoundRobinSender<PipelineResultType>),
-    OutOfOrder(Sender<PipelineResultType>),
-}
-
-pub struct RoundRobinSender<T> {
-    senders: Vec<Sender<T>>,
-    curr_sender_idx: usize,
+pub(crate) fn create_channel<T>(buffer_size: usize) -> (Sender<T>, Receiver<T>) {
+    let (tx, rx) = loole::bounded(buffer_size);
+    (Sender(tx), Receiver(rx))
 }

-impl<T> RoundRobinSender<T> {
-    pub fn new(senders: Vec<Sender<T>>) -> Self {
-        Self {
-            senders,
-            curr_sender_idx: 0,
+/// A multi-producer, single-consumer channel that is aware of the ordering of the senders.
+/// If `ordered` is true, the receiver will try to receive from each sender in a round-robin fashion.
+/// This is useful when collecting results from multiple workers in a specific order.
+pub(crate) fn create_ordering_aware_receiver_channel<T: Clone>(
+    ordered: bool,
+    buffer_size: usize,
+) -> (Vec<Sender<T>>, OrderingAwareReceiver<T>) {
+    match ordered {
+        true => {
+            let (senders, receiver) = (0..buffer_size).map(|_| create_channel::<T>(1)).unzip();
+            (
+                senders,
+                OrderingAwareReceiver::InOrder(RoundRobinReceiver::new(receiver)),
+            )
+        }
+        false => {
+            let (sender, receiver) = create_channel::<T>(buffer_size);
+            (
+                (0..buffer_size).map(|_| sender.clone()).collect(),
+                OrderingAwareReceiver::OutOfOrder(receiver),
+            )
         }
-    }
-
-    pub fn get_next_sender(&mut self) -> Sender<T> {
-        let next_idx = self.curr_sender_idx;
-        self.curr_sender_idx = (next_idx + 1) % self.senders.len();
-        self.senders[next_idx].clone()
     }
 }

-pub enum PipelineReceiver {
-    InOrder(RoundRobinReceiver<PipelineResultType>),
-    OutOfOrder(Receiver<PipelineResultType>),
+pub(crate) enum OrderingAwareReceiver<T> {
+    InOrder(RoundRobinReceiver<T>),
+    OutOfOrder(Receiver<T>),
 }

-impl PipelineReceiver {
-    pub async fn recv(&mut self) -> Option<PipelineResultType> {
+impl<T> OrderingAwareReceiver<T> {
+    pub(crate) async fn recv(&mut self) -> Option<T> {
         match self {
             Self::InOrder(rr) => rr.recv().await,
             Self::OutOfOrder(r) => r.recv().await,
@@ -94,14 +62,15 @@ impl PipelineReceiver {
     }
 }

-pub struct RoundRobinReceiver<T> {
+/// A round-robin receiver that tries to receive from each receiver in a round-robin fashion.
+pub(crate) struct RoundRobinReceiver<T> {
     receivers: Vec<Receiver<T>>,
     curr_receiver_idx: usize,
     is_done: bool,
 }

 impl<T> RoundRobinReceiver<T> {
-    pub fn new(receivers: Vec<Receiver<T>>) -> Self {
+    fn new(receivers: Vec<Receiver<T>>) -> Self {
         Self {
             receivers,
             curr_receiver_idx: 0,
@@ -109,7 +78,7 @@ impl<T> RoundRobinReceiver<T> {
         }
     }

-    pub async fn recv(&mut self) -> Option<T> {
+    async fn recv(&mut self) -> Option<T> {
         if self.is_done {
             return None;
         }
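For reference, a minimal standalone sketch of the `loole` behavior the new `Sender`/`Receiver` wrappers rely on: clonable endpoints, an async send that suspends at the channel's capacity, and a blocking receive on the same endpoint (what `blocking_recv` wraps). This is illustrative code, not part of the patch; it assumes a scratch binary with `loole = "0.4"` and `tokio` (rt-multi-thread + macros) as dependencies.

```rust
#[tokio::main]
async fn main() {
    // Bounded channel: `send_async` suspends once 2 values are in flight.
    let (tx, rx) = loole::bounded::<u64>(2);

    // Multiple producers, like the cloned senders in the unordered path above.
    for i in 0..4u64 {
        let tx = tx.clone();
        tokio::spawn(async move {
            let _ = tx.send_async(i).await;
        });
    }
    drop(tx); // once every sender is gone, receivers observe disconnection

    // Blocking `recv` coexists with `recv_async` on the same channel; mapping
    // the disconnect error to `None` is how the wrapper signals end-of-stream.
    let consumer = tokio::task::spawn_blocking(move || {
        while let Ok(v) = rx.recv() {
            println!("got {v}");
        }
    });
    consumer.await.unwrap();
}
```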
diff --git a/src/daft-local-execution/src/dispatcher.rs b/src/daft-local-execution/src/dispatcher.rs
index 08aeceb2eb..9e487028f1 100644
--- a/src/daft-local-execution/src/dispatcher.rs
+++ b/src/daft-local-execution/src/dispatcher.rs
@@ -1,74 +1,181 @@
 use std::sync::Arc;

-use async_trait::async_trait;
 use common_error::DaftResult;
 use daft_dsl::ExprRef;
+use daft_micropartition::MicroPartition;

 use crate::{
-    buffer::RowBasedBuffer, channel::Sender, pipeline::PipelineResultType,
+    buffer::RowBasedBuffer,
+    channel::{create_channel, Receiver, Sender},
     runtime_stats::CountingReceiver,
+    RuntimeHandle, SpawnedTask,
 };

-#[async_trait]
-pub(crate) trait Dispatcher {
-    async fn dispatch(
+/// The `DispatchSpawner` trait is implemented by types that can spawn a task that reads from
+/// input receivers and distributes morsels to worker receivers.
+///
+/// The `spawn_dispatch` method is called with the input receivers, the number of workers, the
+/// runtime handle, and the pipeline node that the dispatcher is associated with.
+///
+/// It returns a vector of receivers (one per worker) that will receive the morsels.
+///
+/// Implementations must spawn a task on the runtime handle that reads from the
+/// input receivers and distributes morsels to the worker receivers.
+pub(crate) trait DispatchSpawner {
+    fn spawn_dispatch(
         &self,
-        receiver: CountingReceiver,
-        worker_senders: Vec<Sender<PipelineResultType>>,
-    ) -> DaftResult<()>;
+        input_receivers: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut RuntimeHandle,
+    ) -> SpawnedDispatchResult;
 }

-pub(crate) struct RoundRobinBufferedDispatcher {
-    morsel_size: usize,
+pub(crate) struct SpawnedDispatchResult {
+    pub(crate) worker_receivers: Vec<Receiver<Arc<MicroPartition>>>,
+    pub(crate) spawned_dispatch_task: SpawnedTask<DaftResult<()>>,
 }

-impl RoundRobinBufferedDispatcher {
-    pub(crate) fn new(morsel_size: usize) -> Self {
+/// A dispatcher that distributes morsels to workers in a round-robin fashion.
+/// Used if the operator requires maintaining the order of the input.
+pub(crate) struct RoundRobinDispatcher {
+    morsel_size: Option<usize>,
+}
+
+impl RoundRobinDispatcher {
+    pub(crate) fn new(morsel_size: Option<usize>) -> Self {
         Self { morsel_size }
     }
-}

-#[async_trait]
-impl Dispatcher for RoundRobinBufferedDispatcher {
-    async fn dispatch(
-        &self,
-        mut receiver: CountingReceiver,
-        worker_senders: Vec<Sender<PipelineResultType>>,
+    async fn dispatch_inner(
+        worker_senders: Vec<Sender<Arc<MicroPartition>>>,
+        input_receivers: Vec<CountingReceiver>,
+        morsel_size: Option<usize>,
     ) -> DaftResult<()> {
         let mut next_worker_idx = 0;
-        let mut send_to_next_worker = |data: PipelineResultType| {
+        let mut send_to_next_worker = |data: Arc<MicroPartition>| {
             let next_worker_sender = worker_senders.get(next_worker_idx).unwrap();
             next_worker_idx = (next_worker_idx + 1) % worker_senders.len();
             next_worker_sender.send(data)
         };

-        let mut buffer = RowBasedBuffer::new(self.morsel_size);
-        while let Some(morsel) = receiver.recv().await {
-            if morsel.should_broadcast() {
-                for worker_sender in &worker_senders {
-                    if worker_sender.send(morsel.clone()).await.is_err() {
+        for receiver in input_receivers {
+            let mut buffer = morsel_size.map(RowBasedBuffer::new);
+            while let Some(morsel) = receiver.recv().await {
+                if let Some(buffer) = &mut buffer {
+                    buffer.push(&morsel);
+                    if let Some(ready) = buffer.pop_enough()? {
+                        for r in ready {
+                            if send_to_next_worker(r).await.is_err() {
+                                return Ok(());
+                            }
+                        }
+                    }
+                } else if send_to_next_worker(morsel).await.is_err() {
+                    return Ok(());
+                }
+            }
+            // Clear all remaining morsels
+            if let Some(buffer) = &mut buffer {
+                if let Some(last_morsel) = buffer.pop_all()? {
+                    if send_to_next_worker(last_morsel).await.is_err() {
                         return Ok(());
                     }
                 }
-            } else {
-                buffer.push(morsel.as_data());
-                if let Some(ready) = buffer.pop_enough()? {
-                    for r in ready {
-                        if send_to_next_worker(r.into()).await.is_err() {
-                            return Ok(());
-                        }
-                    }
-                }
             }
         }
-        // Clear all remaining morsels
-        if let Some(last_morsel) = buffer.pop_all()? {
-            let _ = send_to_next_worker(last_morsel.into()).await;
-        }
         Ok(())
     }
 }

+impl DispatchSpawner for RoundRobinDispatcher {
+    fn spawn_dispatch(
+        &self,
+        input_receivers: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut RuntimeHandle,
+    ) -> SpawnedDispatchResult {
+        let (worker_senders, worker_receivers): (Vec<_>, Vec<_>) =
+            (0..num_workers).map(|_| create_channel(1)).unzip();
+        let morsel_size = self.morsel_size;
+        let task = runtime_handle.spawn(async move {
+            Self::dispatch_inner(worker_senders, input_receivers, morsel_size).await
+        });
+
+        SpawnedDispatchResult {
+            worker_receivers,
+            spawned_dispatch_task: task,
+        }
+    }
+}
+
+/// A dispatcher that distributes morsels to workers in an unordered fashion.
+/// Used if the operator does not require maintaining the order of the input.
+pub(crate) struct UnorderedDispatcher {
+    morsel_size: Option<usize>,
+}
+
+impl UnorderedDispatcher {
+    pub(crate) fn new(morsel_size: Option<usize>) -> Self {
+        Self { morsel_size }
+    }
+
+    async fn dispatch_inner(
+        worker_sender: Sender<Arc<MicroPartition>>,
+        input_receivers: Vec<CountingReceiver>,
+        morsel_size: Option<usize>,
+    ) -> DaftResult<()> {
+        for receiver in input_receivers {
+            let mut buffer = morsel_size.map(RowBasedBuffer::new);
+            while let Some(morsel) = receiver.recv().await {
+                if let Some(buffer) = &mut buffer {
+                    buffer.push(&morsel);
+                    if let Some(ready) = buffer.pop_enough()? {
+                        for r in ready {
+                            if worker_sender.send(r).await.is_err() {
+                                return Ok(());
+                            }
+                        }
+                    }
+                } else if worker_sender.send(morsel).await.is_err() {
+                    return Ok(());
+                }
+            }
+            // Clear all remaining morsels
+            if let Some(buffer) = &mut buffer {
+                if let Some(last_morsel) = buffer.pop_all()? {
+                    if worker_sender.send(last_morsel).await.is_err() {
+                        return Ok(());
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+impl DispatchSpawner for UnorderedDispatcher {
+    fn spawn_dispatch(
+        &self,
+        receiver: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut RuntimeHandle,
+    ) -> SpawnedDispatchResult {
+        let (worker_sender, worker_receiver) = create_channel(num_workers);
+        let worker_receivers = vec![worker_receiver; num_workers];
+        let morsel_size = self.morsel_size;
+
+        let dispatch_task = runtime_handle
+            .spawn(async move { Self::dispatch_inner(worker_sender, receiver, morsel_size).await });
+
+        SpawnedDispatchResult {
+            worker_receivers,
+            spawned_dispatch_task: dispatch_task,
+        }
+    }
+}
+
+/// A dispatcher that distributes morsels to workers based on a partitioning expression.
+/// Used if the operator requires partitioning the input, i.e. partitioned writes.
 pub(crate) struct PartitionedDispatcher {
     partition_by: Vec<ExprRef>,
 }
@@ -77,33 +184,18 @@ impl PartitionedDispatcher {
     pub(crate) fn new(partition_by: Vec<ExprRef>) -> Self {
         Self { partition_by }
     }
-}

-#[async_trait]
-impl Dispatcher for PartitionedDispatcher {
-    async fn dispatch(
-        &self,
-        mut receiver: CountingReceiver,
-        worker_senders: Vec<Sender<PipelineResultType>>,
+    async fn dispatch_inner(
+        worker_senders: Vec<Sender<Arc<MicroPartition>>>,
+        input_receivers: Vec<CountingReceiver>,
+        partition_by: Vec<ExprRef>,
     ) -> DaftResult<()> {
-        while let Some(morsel) = receiver.recv().await {
-            if morsel.should_broadcast() {
-                for worker_sender in &worker_senders {
-                    if worker_sender.send(morsel.clone()).await.is_err() {
-                        return Ok(());
-                    }
-                }
-            } else {
-                let partitions = morsel
-                    .as_data()
-                    .partition_by_hash(&self.partition_by, worker_senders.len())?;
+        for receiver in input_receivers {
+            while let Some(morsel) = receiver.recv().await {
+                let partitions = morsel.partition_by_hash(&partition_by, worker_senders.len())?;
                 for (partition, worker_sender) in
                     partitions.into_iter().zip(worker_senders.iter())
                 {
-                    if worker_sender
-                        .send(Arc::new(partition).into())
-                        .await
-                        .is_err()
-                    {
+                    if worker_sender.send(Arc::new(partition)).await.is_err() {
                         return Ok(());
                     }
                 }
@@ -112,3 +204,24 @@ impl Dispatcher for PartitionedDispatcher {
         Ok(())
     }
 }
+
+impl DispatchSpawner for PartitionedDispatcher {
+    fn spawn_dispatch(
+        &self,
+        input_receivers: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut RuntimeHandle,
+    ) -> SpawnedDispatchResult {
+        let (worker_senders, worker_receivers): (Vec<_>, Vec<_>) =
+            (0..num_workers).map(|_| create_channel(1)).unzip();
+        let partition_by = self.partition_by.clone();
+        let dispatch_task = runtime_handle.spawn(async move {
+            Self::dispatch_inner(worker_senders, input_receivers, partition_by).await
+        });
+
+        SpawnedDispatchResult {
+            worker_receivers,
+            spawned_dispatch_task: dispatch_task,
+        }
+    }
+}
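A toy standalone model of the round-robin dispatch/drain contract above, for intuition only: plain `tokio` mpsc channels stand in for the crate's `loole`-backed ones and for `CountingReceiver`, and there is no `RowBasedBuffer` rechunking. Because worker k only ever receives morsels i with i % num_workers == k, a reader that drains the workers in the same round-robin order reproduces the input order.

```rust
use tokio::sync::mpsc;

// One input stream fanned out to bounded worker channels in arrival order.
async fn dispatch(mut input: mpsc::Receiver<String>, workers: Vec<mpsc::Sender<String>>) {
    let mut next = 0;
    while let Some(morsel) = input.recv().await {
        // A send error means the worker hung up; stop dispatching.
        if workers[next].send(morsel).await.is_err() {
            return;
        }
        next = (next + 1) % workers.len();
    }
}

#[tokio::main]
async fn main() {
    let (in_tx, in_rx) = mpsc::channel(8);
    let (w_txs, mut w_rxs): (Vec<_>, Vec<_>) = (0..3).map(|_| mpsc::channel(1)).unzip();
    tokio::spawn(dispatch(in_rx, w_txs));

    for i in 0..6 {
        in_tx.send(format!("morsel-{i}")).await.unwrap();
    }
    drop(in_tx); // close the input so the dispatcher task finishes

    // Round-robin drain restores the global order, as in RoundRobinReceiver.
    'outer: loop {
        for rx in &mut w_rxs {
            match rx.recv().await {
                Some(m) => println!("{m}"),
                None => break 'outer,
            }
        }
    }
}
```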
diff --git a/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs b/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
index 58a785e62b..810c6c7560 100644
--- a/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
+++ b/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;

 use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 #[cfg(feature = "python")]
 use daft_dsl::python::PyExpr;
 use daft_dsl::{functions::python::extract_stateful_udf_exprs, ExprRef};
@@ -12,10 +13,13 @@ use pyo3::prelude::*;
 use tracing::instrument;

 use super::intermediate_op::{
-    DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult,
-    IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
+};
+use crate::{
+    dispatcher::{DispatchSpawner, RoundRobinDispatcher, UnorderedDispatcher},
+    ExecutionRuntimeContext,
 };
-use crate::pipeline::PipelineResultType;

 struct ActorHandle {
     #[cfg(feature = "python")]
@@ -108,7 +112,7 @@ struct ActorPoolProjectState {
     pub actor_handle: ActorHandle,
 }

-impl DynIntermediateOpState for ActorPoolProjectState {
+impl IntermediateOpState for ActorPoolProjectState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
@@ -145,23 +149,29 @@ impl IntermediateOperator for ActorPoolProjectOperator {
     #[instrument(skip_all, name = "ActorPoolProjectOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        state.with_state_mut::<ActorPoolProjectState, _, _>(|state| {
-            state
+        input: Arc<MicroPartition>,
+        mut state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let fut = runtime.spawn(async move {
+            let actor_pool_project_state = state
+                .as_any_mut()
+                .downcast_mut::<ActorPoolProjectState>()
+                .expect("ActorPoolProjectState");
+            let res = actor_pool_project_state
                 .actor_handle
-                .eval_input(input.as_data().clone())
-                .map(|result| IntermediateOperatorResult::NeedMoreInput(Some(result)))
-        })
+                .eval_input(input)
+                .map(|result| IntermediateOperatorResult::NeedMoreInput(Some(result)))?;
+            Ok((state, res))
+        });
+        fut.into()
     }

     fn name(&self) -> &'static str {
         "ActorPoolProject"
     }

-    fn make_state(&self) -> DaftResult<Box<dyn DynIntermediateOpState>> {
+    fn make_state(&self) -> DaftResult<Box<dyn IntermediateOpState>> {
         // TODO: Pass relevant CUDA_VISIBLE_DEVICES to the actor
         Ok(Box::new(ActorPoolProjectState {
             actor_handle: ActorHandle::try_new(&self.projection)?,
@@ -172,7 +182,21 @@ impl IntermediateOperator for ActorPoolProjectOperator {
         self.concurrency
     }

-    fn morsel_size(&self) -> Option<usize> {
-        self.batch_size
+    fn dispatch_spawner(
+        &self,
+        runtime_handle: &ExecutionRuntimeContext,
+        maintain_order: bool,
+    ) -> Arc<dyn DispatchSpawner> {
+        if maintain_order {
+            Arc::new(RoundRobinDispatcher::new(Some(
+                self.batch_size
+                    .unwrap_or_else(|| runtime_handle.default_morsel_size()),
+            )))
+        } else {
+            Arc::new(UnorderedDispatcher::new(Some(
+                self.batch_size
+                    .unwrap_or_else(|| runtime_handle.default_morsel_size()),
+            )))
+        }
     }
 }
diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs
index 4b8fa7bbb6..cb9344b160 100644
--- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs
+++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs
@@ -1,24 +1,31 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_dsl::ExprRef;
+use daft_micropartition::MicroPartition;
 use tracing::instrument;

 use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;

-pub struct AggregateOperator {
+struct AggParams {
     agg_exprs: Vec<ExprRef>,
     group_by: Vec<ExprRef>,
 }

+pub struct AggregateOperator {
+    params: Arc<AggParams>,
+}
+
 impl AggregateOperator {
     pub fn new(agg_exprs: Vec<ExprRef>, group_by: Vec<ExprRef>) -> Self {
         Self {
-            agg_exprs,
-            group_by,
+            params: Arc::new(AggParams {
+                agg_exprs,
+                group_by,
+            }),
         }
     }
 }
@@ -27,14 +34,20 @@ impl IntermediateOperator for AggregateOperator {
     #[instrument(skip_all, name = "AggregateOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let params = self.params.clone();
+        runtime
+            .spawn(async move {
+                let out = input.agg(&params.agg_exprs, &params.group_by)?;
+                Ok((
+                    state,
+                    IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))),
+                ))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs
index 544bba1906..765039651e 100644
--- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs
+++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;

 use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_core::prelude::SchemaRef;
 use daft_dsl::ExprRef;
 use daft_logical_plan::JoinType;
@@ -9,64 +10,72 @@ use daft_table::{GrowableTable, Probeable};
 use tracing::{info_span, instrument};

 use super::intermediate_op::{
-    DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult,
-    IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;
+use crate::sinks::hash_join_build::ProbeStateBridgeRef;

 enum AntiSemiProbeState {
-    Building,
-    ReadyToProbe(Arc<dyn Probeable>),
+    Building(ProbeStateBridgeRef),
+    Probing(Arc<dyn Probeable>),
 }

 impl AntiSemiProbeState {
-    fn set_table(&mut self, table: &Arc<dyn Probeable>) {
-        if matches!(self, Self::Building) {
-            *self = Self::ReadyToProbe(table.clone());
-        } else {
-            panic!("AntiSemiProbeState should only be in Building state when setting table")
-        }
-    }
-
-    fn get_probeable(&self) -> &Arc<dyn Probeable> {
-        if let Self::ReadyToProbe(probeable) = self {
-            probeable
-        } else {
-            panic!("AntiSemiProbeState should only be in ReadyToProbe state when getting probeable")
+    async fn get_or_await_probeable(&mut self) -> Arc<dyn Probeable> {
+        match self {
+            Self::Building(bridge) => {
+                let probe_state = bridge.get_probe_state().await;
+                let probeable = probe_state.get_probeable();
+                *self = Self::Probing(probeable.clone());
+                probeable.clone()
+            }
+            Self::Probing(probeable) => probeable.clone(),
         }
     }
 }

-impl DynIntermediateOpState for AntiSemiProbeState {
+impl IntermediateOpState for AntiSemiProbeState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
 }

-pub struct AntiSemiProbeOperator {
+struct AntiSemiJoinParams {
     probe_on: Vec<ExprRef>,
     is_semi: bool,
+}
+
+pub(crate) struct AntiSemiProbeOperator {
+    params: Arc<AntiSemiJoinParams>,
     output_schema: SchemaRef,
+    probe_state_bridge: ProbeStateBridgeRef,
 }

 impl AntiSemiProbeOperator {
     const DEFAULT_GROWABLE_SIZE: usize = 20;

-    pub fn new(probe_on: Vec<ExprRef>, join_type: &JoinType, output_schema: &SchemaRef) -> Self {
+    pub fn new(
+        probe_on: Vec<ExprRef>,
+        join_type: &JoinType,
+        output_schema: &SchemaRef,
+        probe_state_bridge: ProbeStateBridgeRef,
+    ) -> Self {
         Self {
-            probe_on,
-            is_semi: *join_type == JoinType::Semi,
+            params: Arc::new(AntiSemiJoinParams {
+                probe_on,
+                is_semi: *join_type == JoinType::Semi,
+            }),
             output_schema: output_schema.clone(),
+            probe_state_bridge,
         }
     }

     fn probe_anti_semi(
-        &self,
+        probe_on: &[ExprRef],
+        probe_set: &Arc<dyn Probeable>,
         input: &Arc<MicroPartition>,
-        state: &AntiSemiProbeState,
+        is_semi: bool,
     ) -> DaftResult<Arc<MicroPartition>> {
-        let probe_set = state.get_probeable();
-
         let _growables = info_span!("AntiSemiOperator::build_growables").entered();

         let input_tables = input.get_tables()?;
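`ProbeStateBridgeRef` comes from `hash_join_build.rs`, which is not among the hunks shown here. As a rough mental model only (hypothetical code, not the actual Daft implementation), the bridge behaves like a write-once cell: the build side publishes the finished probe state exactly once, and any number of probe workers can await it, which is what `get_or_await_probeable` leans on. A sketch using `tokio::sync::watch`:

```rust
use std::sync::Arc;
use tokio::sync::watch;

// Illustrative stand-in for the probe-state bridge; names are made up.
struct Bridge<T> {
    tx: watch::Sender<Option<Arc<T>>>,
}

impl<T> Bridge<T> {
    fn new() -> Arc<Self> {
        Arc::new(Self { tx: watch::channel(None).0 })
    }

    // Build side: publish once.
    fn publish(&self, value: T) {
        let _ = self.tx.send(Some(Arc::new(value)));
    }

    // Probe side: return immediately if published, otherwise await it.
    async fn get_or_await(&self) -> Arc<T> {
        let mut rx = self.tx.subscribe();
        loop {
            if let Some(v) = rx.borrow_and_update().clone() {
                return v;
            }
            // Wakes when the build side publishes.
            rx.changed().await.expect("build side dropped the bridge");
        }
    }
}

#[tokio::main]
async fn main() {
    let bridge = Bridge::new();
    let waiter = {
        let bridge = bridge.clone();
        tokio::spawn(async move { *bridge.get_or_await().await })
    };
    bridge.publish(42u64);
    println!("probe side saw {}", waiter.await.unwrap());
}
```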
@@ -81,11 +90,11 @@ impl AntiSemiProbeOperator {
         {
             let _loop = info_span!("AntiSemiOperator::eval_and_probe").entered();
             for (probe_side_table_idx, table) in input_tables.iter().enumerate() {
-                let join_keys = table.eval_expression_list(&self.probe_on)?;
+                let join_keys = table.eval_expression_list(probe_on)?;
                 let iter = probe_set.probe_exists(&join_keys)?;

                 for (probe_row_idx, matched) in iter.enumerate() {
-                    match (self.is_semi, matched) {
+                    match (is_semi, matched) {
                         (true, true) | (false, false) => {
                             probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1);
                         }
@@ -107,32 +116,41 @@ impl IntermediateOperator for AntiSemiProbeOperator {
     #[instrument(skip_all, name = "AntiSemiOperator::execute")]
     fn execute(
         &self,
-        idx: usize,
-        input: &PipelineResultType,
-        state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        state.with_state_mut::<AntiSemiProbeState, _, _>(|state| {
-            if idx == 0 {
-                let probe_state = input.as_probe_state();
-                state.set_table(probe_state.get_probeable());
-                Ok(IntermediateOperatorResult::NeedMoreInput(None))
-            } else {
-                let input = input.as_data();
-                if input.is_empty() {
-                    let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone())));
-                    return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty)));
-                }
-                let out = self.probe_anti_semi(input, state)?;
-                Ok(IntermediateOperatorResult::NeedMoreInput(Some(out)))
-            }
-        })
+        input: Arc<MicroPartition>,
+        mut state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        if input.is_empty() {
+            let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone())));
+            return Ok((
+                state,
+                IntermediateOperatorResult::NeedMoreInput(Some(empty)),
+            ))
+            .into();
+        }
+
+        let params = self.params.clone();
+        runtime
+            .spawn(async move {
+                let probe_state = state
+                    .as_any_mut()
+                    .downcast_mut::<AntiSemiProbeState>()
+                    .expect("AntiSemiProbeState should be used with AntiSemiProbeOperator");
+                let probeable = probe_state.get_or_await_probeable().await;
+                let res =
+                    Self::probe_anti_semi(&params.probe_on, &probeable, &input, params.is_semi);
+                Ok((state, IntermediateOperatorResult::NeedMoreInput(Some(res?))))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
         "AntiSemiProbeOperator"
     }

-    fn make_state(&self) -> DaftResult<Box<dyn DynIntermediateOpState>> {
-        Ok(Box::new(AntiSemiProbeState::Building))
+    fn make_state(&self) -> DaftResult<Box<dyn IntermediateOpState>> {
+        Ok(Box::new(AntiSemiProbeState::Building(
+            self.probe_state_bridge.clone(),
+        )))
     }
 }
diff --git a/src/daft-local-execution/src/intermediate_ops/explode.rs b/src/daft-local-execution/src/intermediate_ops/explode.rs
index 30ec8b02b5..ce50cfb436 100644
--- a/src/daft-local-execution/src/intermediate_ops/explode.rs
+++ b/src/daft-local-execution/src/intermediate_ops/explode.rs
@@ -1,23 +1,24 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_dsl::ExprRef;
 use daft_functions::list::explode;
+use daft_micropartition::MicroPartition;
 use tracing::instrument;

 use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;

 pub struct ExplodeOperator {
-    to_explode: Vec<ExprRef>,
+    to_explode: Arc<Vec<ExprRef>>,
 }

 impl ExplodeOperator {
     pub fn new(to_explode: Vec<ExprRef>) -> Self {
         Self {
-            to_explode: to_explode.into_iter().map(explode).collect(),
+            to_explode: Arc::new(to_explode.into_iter().map(explode).collect()),
         }
     }
 }
@@ -26,14 +27,20 @@ impl IntermediateOperator for ExplodeOperator {
     #[instrument(skip_all, name = "ExplodeOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out = input.as_data().explode(&self.to_explode)?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let to_explode = self.to_explode.clone();
+        runtime
+            .spawn(async move {
+                let out = input.explode(&to_explode)?;
+                Ok((
+                    state,
+                    IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))),
+                ))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs
index aad3bd7e7d..33126940b7 100644
--- a/src/daft-local-execution/src/intermediate_ops/filter.rs
+++ b/src/daft-local-execution/src/intermediate_ops/filter.rs
@@ -1,13 +1,14 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_dsl::ExprRef;
+use daft_micropartition::MicroPartition;
 use tracing::instrument;

 use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;

 pub struct FilterOperator {
     predicate: ExprRef,
@@ -23,14 +24,20 @@ impl IntermediateOperator for FilterOperator {
     #[instrument(skip_all, name = "FilterOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out = input.as_data().filter(&[self.predicate.clone()])?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let predicate = self.predicate.clone();
+        runtime
+            .spawn(async move {
+                let out = input.filter(&[predicate])?;
+                Ok((
+                    state,
+                    IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))),
+                ))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs
index ce0da48089..ada98dac23 100644
--- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs
+++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;

 use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_core::prelude::SchemaRef;
 use daft_dsl::ExprRef;
 use daft_micropartition::MicroPartition;
@@ -9,47 +10,47 @@ use indexmap::IndexSet;
 use tracing::{info_span, instrument};

 use super::intermediate_op::{
-    DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult,
-    IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;
+use crate::sinks::hash_join_build::ProbeStateBridgeRef;

 enum InnerHashJoinProbeState {
-    Building,
-    ReadyToProbe(Arc<ProbeState>),
+    Building(ProbeStateBridgeRef),
+    Probing(Arc<ProbeState>),
 }

 impl InnerHashJoinProbeState {
-    fn set_probe_state(&mut self, probe_state: Arc<ProbeState>) {
-        if matches!(self, Self::Building) {
-            *self = Self::ReadyToProbe(probe_state);
-        } else {
-            panic!("InnerHashJoinProbeState should only be in Building state when setting table")
-        }
-    }
-
-    fn get_probe_state(&self) -> &Arc<ProbeState> {
-        if let Self::ReadyToProbe(probe_state) = self {
-            probe_state
-        } else {
-            panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase")
+    async fn get_or_await_probe_state(&mut self) -> Arc<ProbeState> {
+        match self {
+            Self::Building(bridge) => {
+                let probe_state = bridge.get_probe_state().await;
+                *self = Self::Probing(probe_state.clone());
+                probe_state
+            }
+            Self::Probing(probeable) => probeable.clone(),
         }
     }
 }

-impl DynIntermediateOpState for InnerHashJoinProbeState {
+impl IntermediateOpState for InnerHashJoinProbeState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
 }

-pub struct InnerHashJoinProbeOperator {
+struct InnerHashJoinParams {
     probe_on: Vec<ExprRef>,
     common_join_keys: Vec<String>,
     left_non_join_columns: Vec<String>,
     right_non_join_columns: Vec<String>,
     build_on_left: bool,
+}
+
+pub struct InnerHashJoinProbeOperator {
+    params: Arc<InnerHashJoinParams>,
     output_schema: SchemaRef,
+    probe_state_bridge: ProbeStateBridgeRef,
 }

 impl InnerHashJoinProbeOperator {
@@ -62,6 +63,7 @@ impl InnerHashJoinProbeOperator {
         build_on_left: bool,
         common_join_keys: IndexSet<String>,
         output_schema: &SchemaRef,
+        probe_state_bridge: ProbeStateBridgeRef,
     ) -> Self {
         let left_non_join_columns = left_schema
             .fields
@@ -77,24 +79,29 @@ impl InnerHashJoinProbeOperator {
             .collect();
         let common_join_keys = common_join_keys.into_iter().collect();
         Self {
-            probe_on,
-            common_join_keys,
-            left_non_join_columns,
-            right_non_join_columns,
-            build_on_left,
+            params: Arc::new(InnerHashJoinParams {
+                probe_on,
+                common_join_keys,
+                left_non_join_columns,
+                right_non_join_columns,
+                build_on_left,
+            }),
             output_schema: output_schema.clone(),
+            probe_state_bridge,
         }
     }

     fn probe_inner(
-        &self,
         input: &Arc<MicroPartition>,
-        state: &InnerHashJoinProbeState,
+        probe_state: &Arc<ProbeState>,
+        probe_on: &[ExprRef],
+        common_join_keys: &[String],
+        left_non_join_columns: &[String],
+        right_non_join_columns: &[String],
+        build_on_left: bool,
     ) -> DaftResult<Arc<MicroPartition>> {
-        let (probe_table, tables) = {
-            let probe_state = state.get_probe_state();
-            (probe_state.get_probeable(), probe_state.get_tables())
-        };
+        let probe_table = probe_state.get_probeable();
+        let tables = probe_state.get_tables();

         let _growables = info_span!("InnerHashJoinOperator::build_growables").entered();

@@ -117,7 +124,7 @@ impl InnerHashJoinProbeOperator {
             let _loop = info_span!("InnerHashJoinOperator::eval_and_probe").entered();
             for (probe_side_table_idx, table) in input_tables.iter().enumerate() {
                 // we should emit one table at a time when this is streaming
-                let join_keys = table.eval_expression_list(&self.probe_on)?;
+                let join_keys = table.eval_expression_list(probe_on)?;
                 let idx_mapper = probe_table.probe_indices(&join_keys)?;

                 for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() {
@@ -138,15 +145,15 @@ impl InnerHashJoinProbeOperator {
         let build_side_table = build_side_growable.build()?;
         let probe_side_table = probe_side_growable.build()?;

-        let (left_table, right_table) = if self.build_on_left {
+        let (left_table, right_table) = if build_on_left {
             (build_side_table, probe_side_table)
         } else {
             (probe_side_table, build_side_table)
         };

-        let join_keys_table = left_table.get_columns(&self.common_join_keys)?;
-        let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?;
-        let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?;
+        let join_keys_table = left_table.get_columns(common_join_keys)?;
+        let left_non_join_columns = left_table.get_columns(left_non_join_columns)?;
+        let right_non_join_columns = right_table.get_columns(right_non_join_columns)?;
         let final_table = join_keys_table
             .union(&left_non_join_columns)?
             .union(&right_non_join_columns)?;
@@ -163,33 +170,50 @@ impl IntermediateOperator for InnerHashJoinProbeOperator {
     #[instrument(skip_all, name = "InnerHashJoinOperator::execute")]
     fn execute(
         &self,
-        idx: usize,
-        input: &PipelineResultType,
-        state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        state.with_state_mut::<InnerHashJoinProbeState, _, _>(|state| match idx {
-            0 => {
-                let probe_state = input.as_probe_state();
-                state.set_probe_state(probe_state.clone());
-                Ok(IntermediateOperatorResult::NeedMoreInput(None))
-            }
-            _ => {
-                let input = input.as_data();
-                if input.is_empty() {
-                    let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone())));
-                    return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty)));
-                }
-                let out = self.probe_inner(input, state)?;
-                Ok(IntermediateOperatorResult::NeedMoreInput(Some(out)))
-            }
-        })
+        input: Arc<MicroPartition>,
+        mut state: Box<dyn IntermediateOpState>,
+        runtime_ref: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        if input.is_empty() {
+            let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone())));
+            return Ok((
+                state,
+                IntermediateOperatorResult::NeedMoreInput(Some(empty)),
+            ))
+            .into();
+        }
+
+        let params = self.params.clone();
+        runtime_ref
+            .spawn(async move {
+                let inner_join_state = state
+                    .as_any_mut()
+                    .downcast_mut::<InnerHashJoinProbeState>()
+                    .expect(
+                        "InnerHashJoinProbeState should be used with InnerHashJoinProbeOperator",
+                    );
+                let probe_state = inner_join_state.get_or_await_probe_state().await;
+                let res = Self::probe_inner(
+                    &input,
+                    &probe_state,
+                    &params.probe_on,
+                    &params.common_join_keys,
+                    &params.left_non_join_columns,
+                    &params.right_non_join_columns,
+                    params.build_on_left,
+                );
+                Ok((state, IntermediateOperatorResult::NeedMoreInput(Some(res?))))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
         "InnerHashJoinProbeOperator"
     }

-    fn make_state(&self) -> DaftResult<Box<dyn DynIntermediateOpState>> {
-        Ok(Box::new(InnerHashJoinProbeState::Building))
+    fn make_state(&self) -> DaftResult<Box<dyn IntermediateOpState>> {
+        Ok(Box::new(InnerHashJoinProbeState::Building(
+            self.probe_state_bridge.clone(),
+        )))
     }
 }
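A side note on the recurring `Arc<...Params>` refactor in these operators (illustrative sketch, not Daft code): a future handed to `spawn` must be `'static`, so it cannot capture `&self`. Moving the operator's parameters behind an `Arc` and cloning the handle gives the spawned task owned access without copying the expressions themselves.

```rust
use std::sync::Arc;

struct Params {
    factor: u64,
}

struct Op {
    params: Arc<Params>,
}

impl Op {
    fn execute(&self, input: Vec<u64>) -> tokio::task::JoinHandle<Vec<u64>> {
        // Owned handle moves into the 'static task; `&self` never does.
        let params = self.params.clone();
        tokio::spawn(async move { input.into_iter().map(|x| x * params.factor).collect() })
    }
}

#[tokio::main]
async fn main() {
    let op = Op { params: Arc::new(Params { factor: 2 }) };
    println!("{:?}", op.execute(vec![1, 2, 3]).await.unwrap());
}
```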
diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs
index 65b3455f57..91dba3de0a 100644
--- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs
+++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs
@@ -1,69 +1,50 @@
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;

 use common_display::tree::TreeDisplay;
 use common_error::DaftResult;
-use common_runtime::get_compute_runtime;
+use common_runtime::{get_compute_runtime, RuntimeRef};
 use daft_micropartition::MicroPartition;
 use tracing::{info_span, instrument};

 use crate::{
-    buffer::RowBasedBuffer,
-    channel::{create_channel, PipelineChannel, Receiver, Sender},
-    pipeline::{PipelineNode, PipelineResultType},
+    channel::{
+        create_channel, create_ordering_aware_receiver_channel, OrderingAwareReceiver, Receiver,
+        Sender,
+    },
+    dispatcher::{DispatchSpawner, RoundRobinDispatcher, UnorderedDispatcher},
+    pipeline::PipelineNode,
     runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
-    ExecutionRuntimeHandle, NUM_CPUS,
+    ExecutionRuntimeContext, OperatorOutput, NUM_CPUS,
 };

-pub(crate) trait DynIntermediateOpState: Send + Sync {
+pub(crate) trait IntermediateOpState: Send + Sync {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
 }

 struct DefaultIntermediateOperatorState {}
-impl DynIntermediateOpState for DefaultIntermediateOperatorState {
+impl IntermediateOpState for DefaultIntermediateOperatorState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
 }

-pub(crate) struct IntermediateOperatorState {
-    inner: Mutex<Box<dyn DynIntermediateOpState>>,
-}
-
-impl IntermediateOperatorState {
-    fn new(inner: Box<dyn DynIntermediateOpState>) -> Arc<Self> {
-        Arc::new(Self {
-            inner: Mutex::new(inner),
-        })
-    }
-
-    pub(crate) fn with_state_mut<T: DynIntermediateOpState + 'static, F, R>(&self, f: F) -> R
-    where
-        F: FnOnce(&mut T) -> R,
-    {
-        let mut guard = self.inner.lock().unwrap();
-        let state = guard
-            .as_any_mut()
-            .downcast_mut::<T>()
-            .expect("State type mismatch");
-        f(state)
-    }
-}
-
 pub enum IntermediateOperatorResult {
     NeedMoreInput(Option<Arc<MicroPartition>>),
     #[allow(dead_code)]
     HasMoreOutput(Arc<MicroPartition>),
 }

+pub(crate) type IntermediateOpExecuteResult =
+    OperatorOutput<DaftResult<(Box<dyn IntermediateOpState>, IntermediateOperatorResult)>>;
 pub trait IntermediateOperator: Send + Sync {
     fn execute(
         &self,
-        idx: usize,
-        input: &PipelineResultType,
-        state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult>;
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult;
     fn name(&self) -> &'static str;
-    fn make_state(&self) -> DaftResult<Box<dyn DynIntermediateOpState>> {
+    fn make_state(&self) -> DaftResult<Box<dyn IntermediateOpState>> {
         Ok(Box::new(DefaultIntermediateOperatorState {}))
     }
     /// The maximum number of concurrent workers that can be spawned for this operator.
@@ -72,9 +53,20 @@ pub trait IntermediateOperator: Send + Sync {
         *NUM_CPUS
     }

-    /// The input morsel size expected by this operator. If None, use the default size.
-    fn morsel_size(&self) -> Option<usize> {
-        None
+    fn dispatch_spawner(
+        &self,
+        runtime_handle: &ExecutionRuntimeContext,
+        maintain_order: bool,
+    ) -> Arc<dyn DispatchSpawner> {
+        if maintain_order {
+            Arc::new(RoundRobinDispatcher::new(Some(
+                runtime_handle.default_morsel_size(),
+            )))
+        } else {
+            Arc::new(UnorderedDispatcher::new(Some(
+                runtime_handle.default_morsel_size(),
+            )))
+        }
     }
 }

@@ -112,27 +104,24 @@ impl IntermediateNode {
     #[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")]
     pub async fn run_worker(
         op: Arc<dyn IntermediateOperator>,
-        mut receiver: Receiver<(usize, PipelineResultType)>,
-        sender: CountingSender,
+        receiver: Receiver<Arc<MicroPartition>>,
+        sender: Sender<Arc<MicroPartition>>,
         rt_context: Arc<RuntimeStatsContext>,
     ) -> DaftResult<()> {
         let span = info_span!("IntermediateOp::execute");
         let compute_runtime = get_compute_runtime();
-        let state_wrapper = IntermediateOperatorState::new(op.make_state()?);
-        while let Some((idx, morsel)) = receiver.recv().await {
+        let mut state = op.make_state()?;
+        while let Some(morsel) = receiver.recv().await {
             loop {
-                let op = op.clone();
-                let morsel = morsel.clone();
-                let span = span.clone();
-                let rt_context = rt_context.clone();
-                let state_wrapper = state_wrapper.clone();
-                let fut = async move {
-                    rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper))
-                };
-                let result = compute_runtime.spawn(fut).await??;
-                match result {
+                let result = rt_context
+                    .in_span(&span, || {
+                        op.execute(morsel.clone(), state, &compute_runtime)
+                    })
+                    .await??;
+                state = result.0;
+                match result.1 {
                     IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
-                        if sender.send(mp.into()).await.is_err() {
+                        if sender.send(mp).await.is_err() {
                             return Ok(());
                         }
                         break;
@@ -141,7 +130,7 @@ impl IntermediateNode {
                         break;
                     }
                     IntermediateOperatorResult::HasMoreOutput(mp) => {
-                        if sender.send(mp.into()).await.is_err() {
+                        if sender.send(mp).await.is_err() {
                             return Ok(());
                         }
                     }

     pub fn spawn_workers(
         &self,
-        num_workers: usize,
-        destination_channel: &mut PipelineChannel,
-        runtime_handle: &mut ExecutionRuntimeHandle,
-    ) -> Vec<Sender<(usize, PipelineResultType)>> {
-        let mut worker_senders = Vec::with_capacity(num_workers);
-        for _ in 0..num_workers {
-            let (worker_sender, worker_receiver) = create_channel(1);
-            let destination_sender =
-                destination_channel.get_next_sender_with_stats(&self.runtime_stats);
+        input_receivers: Vec<Receiver<Arc<MicroPartition>>>,
+        runtime_handle: &mut ExecutionRuntimeContext,
+        maintain_order: bool,
+    ) -> OrderingAwareReceiver<Arc<MicroPartition>> {
+        let (output_sender, output_receiver) =
+            create_ordering_aware_receiver_channel(maintain_order, input_receivers.len());
+        for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_sender) {
             runtime_handle.spawn(
                 Self::run_worker(
                     self.intermediate_op.clone(),
-                    worker_receiver,
-                    destination_sender,
+                    input_receiver,
+                    output_sender,
                     self.runtime_stats.clone(),
                 ),
                 self.intermediate_op.name(),
             );
-            worker_senders.push(worker_sender);
         }
-        worker_senders
-    }
-
-    pub async fn send_to_workers(
-        receivers: Vec<CountingReceiver>,
-        worker_senders: Vec<Sender<(usize, PipelineResultType)>>,
-        morsel_size: usize,
-    ) -> DaftResult<()> {
-        let mut next_worker_idx = 0;
-        let mut send_to_next_worker = |idx, data: PipelineResultType| {
-            let next_worker_sender = worker_senders.get(next_worker_idx).unwrap();
-            next_worker_idx = (next_worker_idx + 1) % worker_senders.len();
-            next_worker_sender.send((idx, data))
-        };
-
-        for (idx, mut receiver) in receivers.into_iter().enumerate() {
-            let mut buffer = RowBasedBuffer::new(morsel_size);
-            while let Some(morsel) = receiver.recv().await {
-                if morsel.should_broadcast() {
-                    for worker_sender in &worker_senders {
-                        if worker_sender.send((idx, morsel.clone())).await.is_err() {
-                            return Ok(());
-                        }
-                    }
-                } else {
-                    buffer.push(morsel.as_data());
-                    if let Some(ready) = buffer.pop_enough()? {
-                        for part in ready {
-                            if send_to_next_worker(idx, part.into()).await.is_err() {
-                                return Ok(());
-                            }
-                        }
-                    }
-                }
-            }
-            if let Some(ready) = buffer.pop_all()? {
-                let _ = send_to_next_worker(idx, ready.into()).await;
-            }
-        }
-        Ok(())
+        output_receiver
     }
 }

@@ -248,32 +195,53 @@ impl PipelineNode for IntermediateNode {
     }

     fn start(
-        &mut self,
+        &self,
         maintain_order: bool,
-        runtime_handle: &mut ExecutionRuntimeHandle,
-    ) -> crate::Result<PipelineChannel> {
+        runtime_handle: &mut ExecutionRuntimeContext,
+    ) -> crate::Result<Receiver<Arc<MicroPartition>>> {
         let mut child_result_receivers = Vec::with_capacity(self.children.len());
-        for child in &mut self.children {
-            let child_result_channel = child.start(maintain_order, runtime_handle)?;
-            child_result_receivers
-                .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats));
+        for child in &self.children {
+            let child_result_receiver = child.start(maintain_order, runtime_handle)?;
+            child_result_receivers.push(CountingReceiver::new(
+                child_result_receiver,
+                self.runtime_stats.clone(),
+            ));
         }
         let op = self.intermediate_op.clone();
         let num_workers = op.max_concurrency();
-        let mut destination_channel = PipelineChannel::new(num_workers, maintain_order);
+        let (destination_sender, destination_receiver) = create_channel(1);
+        let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone());
+
+        let dispatch_spawner = self
+            .intermediate_op
+            .dispatch_spawner(runtime_handle, maintain_order);
+        let spawned_dispatch_result = dispatch_spawner.spawn_dispatch(
+            child_result_receivers,
+            num_workers,
+            &mut runtime_handle.handle(),
+        );
+        runtime_handle.spawn(
+            async move { spawned_dispatch_result.spawned_dispatch_task.await? },
+            self.name(),
+        );

-        let worker_senders =
-            self.spawn_workers(num_workers, &mut destination_channel, runtime_handle);
+        let mut output_receiver = self.spawn_workers(
+            spawned_dispatch_result.worker_receivers,
+            runtime_handle,
+            maintain_order,
+        );
         runtime_handle.spawn(
-            Self::send_to_workers(
-                child_result_receivers,
-                worker_senders,
-                op.morsel_size()
-                    .unwrap_or_else(|| runtime_handle.default_morsel_size()),
-            ),
+            async move {
+                while let Some(morsel) = output_receiver.recv().await {
+                    if counting_sender.send(morsel).await.is_err() {
+                        return Ok(());
+                    }
+                }
+                Ok(())
+            },
             op.name(),
         );
-        Ok(destination_channel)
+        Ok(destination_receiver)
     }

     fn as_tree_display(&self) -> &dyn TreeDisplay {
         self
diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs
index 370de989aa..d0854cf238 100644
--- a/src/daft-local-execution/src/intermediate_ops/project.rs
+++ b/src/daft-local-execution/src/intermediate_ops/project.rs
@@ -1,21 +1,24 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_dsl::ExprRef;
+use daft_micropartition::MicroPartition;
 use tracing::instrument;

 use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;

 pub struct ProjectOperator {
-    projection: Vec<ExprRef>,
+    projection: Arc<Vec<ExprRef>>,
 }

 impl ProjectOperator {
     pub fn new(projection: Vec<ExprRef>) -> Self {
-        Self { projection }
+        Self {
+            projection: Arc::new(projection),
+        }
     }
 }

@@ -23,14 +26,20 @@ impl IntermediateOperator for ProjectOperator {
     #[instrument(skip_all, name = "ProjectOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out = input.as_data().eval_expression_list(&self.projection)?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let projection = self.projection.clone();
+        runtime
+            .spawn(async move {
+                let out = input.eval_expression_list(&projection)?;
+                Ok((
+                    state,
+                    IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))),
+                ))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
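The new `execute` signature threads the worker's state by value: the operator takes `Box<dyn IntermediateOpState>` in and hands it back with the result, and `run_worker` reassigns it (`state = result.0`), which removes the old `Mutex`-wrapped shared state entirely. A standalone model of that ownership dance (stand-in types, not the Daft traits):

```rust
trait OpState {
    fn rows_seen(&mut self) -> &mut usize;
}

struct CountState(usize);
impl OpState for CountState {
    fn rows_seen(&mut self) -> &mut usize {
        &mut self.0
    }
}

// The state moves into the spawned task and is returned with the result,
// so no lock is needed even though execution hops onto another runtime.
async fn execute(
    input: Vec<u64>,
    mut state: Box<dyn OpState + Send>,
) -> (Box<dyn OpState + Send>, usize) {
    *state.rows_seen() += input.len();
    let seen = *state.rows_seen();
    (state, seen)
}

#[tokio::main]
async fn main() {
    let mut state: Box<dyn OpState + Send> = Box::new(CountState(0));
    for batch in [vec![1, 2], vec![3]] {
        // Mirrors `state = result.0` in `run_worker` above.
        let (next_state, seen) = tokio::spawn(execute(batch, state)).await.unwrap();
        state = next_state;
        println!("rows seen so far: {seen}");
    }
}
```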
diff --git a/src/daft-local-execution/src/intermediate_ops/sample.rs b/src/daft-local-execution/src/intermediate_ops/sample.rs
index b0e4610292..311b15ae91 100644
--- a/src/daft-local-execution/src/intermediate_ops/sample.rs
+++ b/src/daft-local-execution/src/intermediate_ops/sample.rs
@@ -1,25 +1,32 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_runtime::RuntimeRef;
+use daft_micropartition::MicroPartition;
 use tracing::instrument;

 use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;

-pub struct SampleOperator {
+struct SampleParams {
     fraction: f64,
     with_replacement: bool,
     seed: Option<u64>,
 }

+pub struct SampleOperator {
+    params: Arc<SampleParams>,
+}
+
 impl SampleOperator {
     pub fn new(fraction: f64, with_replacement: bool, seed: Option<u64>) -> Self {
         Self {
-            fraction,
-            with_replacement,
-            seed,
+            params: Arc::new(SampleParams {
+                fraction,
+                with_replacement,
+                seed,
+            }),
         }
     }
 }
@@ -28,17 +35,24 @@ impl IntermediateOperator for SampleOperator {
     #[instrument(skip_all, name = "SampleOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out =
-            input
-                .as_data()
-                .sample_by_fraction(self.fraction, self.with_replacement, self.seed)?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let params = self.params.clone();
+        runtime
+            .spawn(async move {
+                let out = input.sample_by_fraction(
+                    params.fraction,
+                    params.with_replacement,
+                    params.seed,
+                )?;
+                Ok((
+                    state,
+                    IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))),
+                ))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
diff --git a/src/daft-local-execution/src/intermediate_ops/unpivot.rs b/src/daft-local-execution/src/intermediate_ops/unpivot.rs
index 5171f9ad42..ca2d4631d5 100644
--- a/src/daft-local-execution/src/intermediate_ops/unpivot.rs
+++ b/src/daft-local-execution/src/intermediate_ops/unpivot.rs
@@ -1,20 +1,24 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_dsl::ExprRef;
+use daft_micropartition::MicroPartition;
 use tracing::instrument;

 use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+    IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
+    IntermediateOperatorResult,
 };
-use crate::pipeline::PipelineResultType;

-pub struct UnpivotOperator {
+struct UnpivotParams {
     ids: Vec<ExprRef>,
     values: Vec<ExprRef>,
     variable_name: String,
     value_name: String,
 }
+pub struct UnpivotOperator {
+    params: Arc<UnpivotParams>,
+}

 impl UnpivotOperator {
     pub fn new(
@@ -24,10 +28,12 @@ impl UnpivotOperator {
         value_name: String,
     ) -> Self {
         Self {
-            ids,
-            values,
-            variable_name,
-            value_name,
+            params: Arc::new(UnpivotParams {
+                ids,
+                values,
+                variable_name,
+                value_name,
+            }),
         }
     }
 }
@@ -36,19 +42,25 @@ impl IntermediateOperator for UnpivotOperator {
     #[instrument(skip_all, name = "UnpivotOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out = input.as_data().unpivot(
-            &self.ids,
-            &self.values,
-            &self.variable_name,
-            &self.value_name,
-        )?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
+        input: Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime: &RuntimeRef,
+    ) -> IntermediateOpExecuteResult {
+        let params = self.params.clone();
+        runtime
+            .spawn(async move {
+                let out = input.unpivot(
+                    &params.ids,
+                    &params.values,
+                    &params.variable_name,
+                    &params.value_name,
+                )?;
+                Ok((
+                    state,
+                    IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))),
+                ))
+            })
+            .into()
     }

     fn name(&self) -> &'static str {
diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs
index 719da409c4..bda9bfcd09 100644
--- a/src/daft-local-execution/src/lib.rs
+++ b/src/daft-local-execution/src/lib.rs
@@ -9,15 +9,58 @@ mod runtime_stats;
 mod sinks;
 mod sources;

+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
 use common_error::{DaftError, DaftResult};
+use common_runtime::RuntimeTask;
 use lazy_static::lazy_static;
 pub use run::{run_local, NativeExecutor};
-use snafu::{futures::TryFutureExt, Snafu};
+use snafu::{futures::TryFutureExt, ResultExt, Snafu};

 lazy_static! {
     pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
 }

+/// The `OperatorOutput` enum represents the output of an operator.
+/// It can be either `Ready` or `Pending`.
+/// If the output is `Ready`, the value is immediately available.
+/// If the output is `Pending`, the value is not yet available and a `RuntimeTask` is returned.
+#[pin_project::pin_project(project = OperatorOutputProj)]
+pub(crate) enum OperatorOutput<T> {
+    Ready(Option<T>),
+    Pending(#[pin] RuntimeTask<T>),
+}
+
+impl<T> Future for OperatorOutput<T> {
+    type Output = DaftResult<T>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        match self.project() {
+            OperatorOutputProj::Ready(value) => {
+                let value = value.take().unwrap();
+                Poll::Ready(Ok(value))
+            }
+            OperatorOutputProj::Pending(task) => task.poll(cx),
+        }
+    }
+}
+
+impl<T> From<T> for OperatorOutput<T> {
+    fn from(value: T) -> Self {
+        Self::Ready(Some(value))
+    }
+}
+
+impl<T> From<RuntimeTask<T>> for OperatorOutput<T> {
+    fn from(task: RuntimeTask<T>) -> Self {
+        Self::Pending(task)
+    }
+}
+
 pub(crate) struct TaskSet<T> {
     inner: tokio::task::JoinSet<T>,
 }
@@ -45,12 +88,34 @@ impl<T> TaskSet<T> {
     }
 }

-pub struct ExecutionRuntimeHandle {
+#[pin_project::pin_project]
+struct SpawnedTask<T>(#[pin] tokio::task::JoinHandle<T>);
+impl<T> Future for SpawnedTask<T> {
+    type Output = crate::Result<T>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.project().0.poll(cx).map(|r| r.context(JoinSnafu))
+    }
+}
+
+struct RuntimeHandle(tokio::runtime::Handle);
+impl RuntimeHandle {
+    fn spawn<F>(&self, future: F) -> SpawnedTask<F::Output>
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        let join_handle = self.0.spawn(future);
+        SpawnedTask(join_handle)
+    }
+}
+
+pub struct ExecutionRuntimeContext {
     worker_set: TaskSet<crate::Result<()>>,
     default_morsel_size: usize,
 }

-impl ExecutionRuntimeHandle {
+impl ExecutionRuntimeContext {
     #[must_use]
     pub fn new(default_morsel_size: usize) -> Self {
         Self {
@@ -80,6 +145,10 @@ impl ExecutionRuntimeContext {
     pub fn default_morsel_size(&self) -> usize {
         self.default_morsel_size
     }
+
+    pub(crate) fn handle(&self) -> RuntimeHandle {
+        RuntimeHandle(tokio::runtime::Handle::current())
+    }
 }

 #[cfg(feature = "python")]
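The point of `OperatorOutput` is that trivially cheap operators can hand back an already-computed value without paying for a spawn, while heavy ones return a task; both await uniformly. A standalone model (a `tokio::task::JoinHandle` stands in for Daft's `RuntimeTask`, and the error plumbing is dropped):

```rust
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

// A future that is either an already-computed value or a spawned task.
enum MaybeSpawned<T> {
    Ready(Option<T>),
    Pending(tokio::task::JoinHandle<T>),
}

// `T: Unpin` keeps the enum Unpin so `get_mut` is safe here; the patch above
// uses `pin-project` instead, which avoids needing this bound at all.
impl<T: Unpin> Future for MaybeSpawned<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            Self::Ready(v) => Poll::Ready(v.take().expect("polled after completion")),
            Self::Pending(handle) => {
                // JoinHandle is Unpin, so re-pinning it is free.
                Pin::new(handle).poll(cx).map(|r| r.expect("task panicked"))
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let fast: MaybeSpawned<u64> = MaybeSpawned::Ready(Some(1));
    let slow = MaybeSpawned::Pending(tokio::spawn(async { 2u64 }));
    println!("{} {}", fast.await, slow.await);
}
```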
diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs
index 98767b239e..a54e6d8b85 100644
--- a/src/daft-local-execution/src/pipeline.rs
+++ b/src/daft-local-execution/src/pipeline.rs
@@ -19,13 +19,12 @@ use daft_logical_plan::JoinType;
 use daft_micropartition::MicroPartition;
 use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages};
 use daft_scan::ScanTaskRef;
-use daft_table::ProbeState;
 use daft_writers::make_physical_writer_factory;
 use indexmap::IndexSet;
 use snafu::ResultExt;

 use crate::{
-    channel::PipelineChannel,
+    channel::Receiver,
     intermediate_ops::{
         actor_pool_project::ActorPoolProjectOperator, aggregate::AggregateOperator,
         anti_semi_hash_join_probe::AntiSemiProbeOperator, explode::ExplodeOperator,
@@ -37,7 +36,7 @@ use crate::{
         aggregate::AggregateSink,
         blocking_sink::BlockingSinkNode,
         concat::ConcatSink,
-        hash_join_build::HashJoinBuildSink,
+        hash_join_build::{HashJoinBuildSink, ProbeStateBridge},
         limit::LimitSink,
         outer_hash_join_probe::OuterHashJoinProbeSink,
         pivot::PivotSink,
@@ -46,55 +45,17 @@ use crate::{
         write::{WriteFormat, WriteSink},
     },
     sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource},
-    ExecutionRuntimeHandle, PipelineCreationSnafu,
+    ExecutionRuntimeContext, PipelineCreationSnafu,
 };

-#[derive(Clone)]
-pub enum PipelineResultType {
-    Data(Arc<MicroPartition>),
-    ProbeState(Arc<ProbeState>),
-}
-
-impl From<Arc<MicroPartition>> for PipelineResultType {
-    fn from(data: Arc<MicroPartition>) -> Self {
-        Self::Data(data)
-    }
-}
-
-impl From<Arc<ProbeState>> for PipelineResultType {
-    fn from(probe_state: Arc<ProbeState>) -> Self {
-        Self::ProbeState(probe_state)
-    }
-}
-
-impl PipelineResultType {
-    pub fn as_data(&self) -> &Arc<MicroPartition> {
-        match self {
-            Self::Data(data) => data,
-            _ => panic!("Expected data"),
-        }
-    }
-
-    pub fn as_probe_state(&self) -> &Arc<ProbeState> {
-        match self {
-            Self::ProbeState(probe_state) => probe_state,
-            _ => panic!("Expected probe table"),
-        }
-    }
-
-    pub fn should_broadcast(&self) -> bool {
-        matches!(self, Self::ProbeState(_))
-    }
-}
-
-pub trait PipelineNode: Sync + Send + TreeDisplay {
+pub(crate) trait PipelineNode: Sync + Send + TreeDisplay {
     fn children(&self) -> Vec<&dyn PipelineNode>;
     fn name(&self) -> &'static str;
     fn start(
-        &mut self,
+        &self,
         maintain_order: bool,
-        runtime_handle: &mut ExecutionRuntimeHandle,
-    ) -> crate::Result<PipelineChannel>;
+        runtime_handle: &mut ExecutionRuntimeContext,
+    ) -> crate::Result<Receiver<Arc<MicroPartition>>>;

     fn as_tree_display(&self) -> &dyn TreeDisplay;
 }
@@ -402,11 +363,13 @@
                 .map(|(e, f)| e.clone().cast(&f.dtype))
                 .collect::<Vec<_>>();
             // we should move to a builder pattern
+            let probe_state_bridge = ProbeStateBridge::new();
             let build_sink = HashJoinBuildSink::new(
                 key_schema,
                 casted_build_on,
                 null_equals_null.clone(),
                 join_type,
+                probe_state_bridge.clone(),
             )?;
             let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?;
             let build_node =
@@ -420,6 +383,7 @@
                         casted_probe_on,
                         join_type,
                         schema,
+                        probe_state_bridge,
                     )),
                     vec![build_node, probe_child_node],
                 )
@@ -432,6 +396,7 @@
                         build_on_left,
                         common_join_keys,
                         schema,
+                        probe_state_bridge,
                     )),
                     vec![build_node, probe_child_node],
                 )
@@ -445,6 +410,7 @@
                         *join_type,
                         common_join_keys,
                         schema,
+                        probe_state_bridge,
                     )),
                     vec![build_node, probe_child_node],
                 )
diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs
index 0f01ec61e6..039df990cd 100644
--- a/src/daft-local-execution/src/run.rs
+++ b/src/daft-local-execution/src/run.rs
@@ -11,6 +11,7 @@ use common_error::DaftResult;
 use common_tracing::refresh_chrome_trace;
 use daft_local_plan::{translate, LocalPhysicalPlan};
 use daft_micropartition::MicroPartition;
+use tokio_util::sync::CancellationToken;
 #[cfg(feature = "python")]
 use {
     common_daft_config::PyDaftExecutionConfig,
@@ -22,7 +23,7 @@ use {
 use crate::{
     channel::{create_channel, Receiver},
     pipeline::{physical_plan_to_pipeline, viz_pipeline},
-    Error, ExecutionRuntimeHandle,
+    Error, ExecutionRuntimeContext,
 };

 #[cfg(feature = "python")]
@@ -46,6 +47,13 @@ impl LocalPartitionIterator {
 #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
 pub struct NativeExecutor {
     local_physical_plan: Arc<LocalPhysicalPlan>,
+    cancel: CancellationToken,
+}
+
+impl Drop for NativeExecutor {
+    fn drop(&mut self) {
+        self.cancel.cancel();
+    }
 }

 #[cfg(feature = "python")]
@@ -61,6 +69,7 @@ impl NativeExecutor {
             let local_physical_plan = translate(&logical_plan)?;
             Ok(Self {
                 local_physical_plan,
+                cancel: CancellationToken::new(),
             })
         })
     }
@@ -90,6 +99,7 @@ impl NativeExecutor {
                 native_psets,
                 cfg.config,
                 results_buffer_size,
+                self.cancel.clone(),
             )
         })?;
         let iter = Box::new(out.map(|part| {
@@ -116,9 +126,10 @@ pub fn run_local(
     psets: HashMap<String, Vec<Arc<MicroPartition>>>,
     cfg: Arc<DaftExecutionConfig>,
     results_buffer_size: Option<usize>,
+    cancel: CancellationToken,
 ) -> DaftResult<Box<dyn Iterator<Item = DaftResult<Arc<MicroPartition>>> + Send>> {
     refresh_chrome_trace();
-    let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?;
+    let pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?;
     let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1));
     let handle = std::thread::spawn(move || {
         let runtime = tokio::runtime::Builder::new_current_thread()
             .enable_all()
             .build()
             .expect("Failed to create tokio runtime");
         let execution_task = async {
-            let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size);
-            let mut receiver = pipeline.start(true, &mut runtime_handle)?.get_receiver();
+            let mut runtime_handle = ExecutionRuntimeContext::new(cfg.default_morsel_size);
+            let receiver = pipeline.start(true, &mut runtime_handle)?;

             while let Some(val) = receiver.recv().await {
-                if tx.send(val.as_data().clone()).await.is_err() {
+                if tx.send(val).await.is_err() {
                     break;
                 }
             }
@@ -164,6 +175,10 @@ pub fn run_local(
         local_set.block_on(&runtime, async {
             tokio::select! {
                 biased;
+                () = cancel.cancelled() => {
+                    log::info!("Execution engine cancelled");
+                    Ok(())
+                }
                 _ = tokio::signal::ctrl_c() => {
                     log::info!("Received Ctrl-C, shutting down execution engine");
                     Ok(())
diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs
index 566d253e9c..3ef944ba9b 100644
--- a/src/daft-local-execution/src/runtime_stats.rs
+++ b/src/daft-local-execution/src/runtime_stats.rs
@@ -5,12 +5,10 @@ use std::{
     time::Instant,
 };

-use tokio::sync::mpsc::error::SendError;
+use daft_micropartition::MicroPartition;
+use loole::SendError;

-use crate::{
-    channel::{PipelineReceiver, Sender},
-    pipeline::PipelineResultType,
-};
+use crate::channel::{Receiver, Sender};

 #[derive(Default)]
 pub struct RuntimeStatsContext {
@@ -109,51 +107,42 @@ impl RuntimeStatsContext {
     }
 }

 pub struct CountingSender {
-    sender: Sender<PipelineResultType>,
+    sender: Sender<Arc<MicroPartition>>,
     rt: Arc<RuntimeStatsContext>,
 }

 impl CountingSender {
-    pub(crate) fn new(sender: Sender<PipelineResultType>, rt: Arc<RuntimeStatsContext>) -> Self {
+    pub(crate) fn new(sender: Sender<Arc<MicroPartition>>, rt: Arc<RuntimeStatsContext>) -> Self {
         Self { sender, rt }
     }
     #[inline]
     pub(crate) async fn send(
         &self,
-        v: PipelineResultType,
-    ) -> Result<(), SendError<PipelineResultType>> {
-        let len = match v {
-            PipelineResultType::Data(ref mp) => mp.len(),
-            PipelineResultType::ProbeState(ref state) => {
-                state.get_tables().iter().map(|t| t.len()).sum()
-            }
-        };
+        v: Arc<MicroPartition>,
+    ) -> Result<(), SendError<Arc<MicroPartition>>> {
+        self.rt.mark_rows_emitted(v.len() as u64);
         self.sender.send(v).await?;
-        self.rt.mark_rows_emitted(len as u64);
         Ok(())
     }
 }

 pub struct CountingReceiver {
-    receiver: PipelineReceiver,
+    receiver: Receiver<Arc<MicroPartition>>,
     rt: Arc<RuntimeStatsContext>,
 }

 impl CountingReceiver {
-    pub(crate) fn new(receiver: PipelineReceiver, rt: Arc<RuntimeStatsContext>) -> Self {
+    pub(crate) fn new(
+        receiver: Receiver<Arc<MicroPartition>>,
+        rt: Arc<RuntimeStatsContext>,
+    ) -> Self {
         Self { receiver, rt }
     }
     #[inline]
-    pub(crate) async fn recv(&mut self) -> Option<PipelineResultType> {
+    pub(crate) async fn recv(&self) -> Option<Arc<MicroPartition>> {
         let v = self.receiver.recv().await;
         if let Some(ref v) = v {
-            let len = match v {
-                PipelineResultType::Data(ref mp) => mp.len(),
-                PipelineResultType::ProbeState(state) => {
-                    state.get_tables().iter().map(|t| t.len()).sum()
-                }
-            };
-            self.rt.mark_rows_received(len as u64);
+            self.rt.mark_rows_received(v.len() as u64);
         }
         v
     }
daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; +use crate::NUM_CPUS; enum AggregateState { Accumulating(Vec>), @@ -39,16 +43,22 @@ impl BlockingSinkState for AggregateState { } } -pub struct AggregateSink { +struct AggParams { agg_exprs: Vec, group_by: Vec, } +pub struct AggregateSink { + agg_sink_params: Arc, +} + impl AggregateSink { pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { Self { - agg_exprs, - group_by, + agg_sink_params: Arc::new(AggParams { + agg_exprs, + group_by, + }), } } } @@ -57,32 +67,39 @@ impl BlockingSink for AggregateSink { #[instrument(skip_all, name = "AggregateSink::sink")] fn sink( &self, - input: &Arc, + input: Arc, mut state: Box, - ) -> DaftResult { + _runtime: &RuntimeRef, + ) -> BlockingSinkSinkResult { state .as_any_mut() .downcast_mut::() .expect("AggregateSink should have AggregateState") - .push(input.clone()); - Ok(BlockingSinkStatus::NeedMoreInput(state)) + .push(input); + Ok(BlockingSinkStatus::NeedMoreInput(state)).into() } #[instrument(skip_all, name = "AggregateSink::finalize")] fn finalize( &self, states: Vec>, - ) -> DaftResult> { - let all_parts = states.into_iter().flat_map(|mut state| { - state - .as_any_mut() - .downcast_mut::() - .expect("AggregateSink should have AggregateState") - .finalize() - }); - let concated = MicroPartition::concat(all_parts)?; - let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); - Ok(Some(agged.into())) + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { + let params = self.agg_sink_params.clone(); + runtime + .spawn(async move { + let all_parts = states.into_iter().flat_map(|mut state| { + state + .as_any_mut() + .downcast_mut::() + .expect("AggregateSink should have AggregateState") + .finalize() + }); + let concated = MicroPartition::concat(all_parts)?; + let agged = Arc::new(concated.agg(&params.agg_exprs, &params.group_by)?); + Ok(Some(agged)) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 51ef590d09..03660d72f3 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -2,17 +2,17 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; -use common_runtime::get_compute_runtime; +use common_runtime::{get_compute_runtime, RuntimeRef}; use daft_micropartition::MicroPartition; use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - channel::{create_channel, PipelineChannel, Receiver}, - dispatcher::{Dispatcher, RoundRobinBufferedDispatcher}, - pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, JoinSnafu, TaskSet, + channel::{create_channel, Receiver}, + dispatcher::{DispatchSpawner, UnorderedDispatcher}, + pipeline::PipelineNode, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, + ExecutionRuntimeContext, JoinSnafu, OperatorOutput, TaskSet, }; pub trait BlockingSinkState: Send + Sync { fn as_any_mut(&mut self) -> &mut dyn std::any::Any; } @@ -24,22 +24,30 @@ pub enum BlockingSinkStatus { Finished(Box), } +pub(crate) type BlockingSinkSinkResult = 
OperatorOutput>; +pub(crate) type BlockingSinkFinalizeResult = + OperatorOutput>>>; pub trait BlockingSink: Send + Sync { fn sink( &self, - input: &Arc, + input: Arc, state: Box, - ) -> DaftResult; + runtime: &RuntimeRef, + ) -> BlockingSinkSinkResult; fn finalize( &self, states: Vec>, - ) -> DaftResult>; + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult; fn name(&self) -> &'static str; fn make_state(&self) -> DaftResult>; - fn make_dispatcher(&self, runtime_handle: &ExecutionRuntimeHandle) -> Arc { - Arc::new(RoundRobinBufferedDispatcher::new( + fn dispatch_spawner( + &self, + runtime_handle: &ExecutionRuntimeContext, + ) -> Arc { + Arc::new(UnorderedDispatcher::new(Some( runtime_handle.default_morsel_size(), - )) + ))) } fn max_concurrency(&self) -> usize; } @@ -68,19 +76,16 @@ impl BlockingSinkNode { #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] async fn run_worker( op: Arc, - mut input_receiver: Receiver, + input_receiver: Receiver>, rt_context: Arc, ) -> DaftResult> { let span = info_span!("BlockingSink::Sink"); let compute_runtime = get_compute_runtime(); let mut state = op.make_state()?; while let Some(morsel) = input_receiver.recv().await { - let op = op.clone(); - let morsel = morsel.clone(); - let span = span.clone(); - let rt_context = rt_context.clone(); - let fut = async move { rt_context.in_span(&span, || op.sink(morsel.as_data(), state)) }; - let result = compute_runtime.spawn(fut).await??; + let result = rt_context + .in_span(&span, || op.sink(morsel, state, &compute_runtime)) + .await??; match result { BlockingSinkStatus::NeedMoreInput(new_state) => { state = new_state; @@ -96,7 +101,7 @@ impl BlockingSinkNode { fn spawn_workers( op: Arc, - input_receivers: Vec>, + input_receivers: Vec>>, task_set: &mut TaskSet>>, stats: Arc, ) { @@ -135,29 +140,29 @@ impl PipelineNode for BlockingSinkNode { } fn start( - &mut self, - maintain_order: bool, - runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { - let child = self.child.as_mut(); - let child_results_receiver = child - .start(false, runtime_handle)? - .get_receiver_with_stats(&self.runtime_stats); - - let mut destination_channel = PipelineChannel::new(1, maintain_order); - let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + &self, + _maintain_order: bool, + runtime_handle: &mut ExecutionRuntimeContext, + ) -> crate::Result>> { + let child_results_receiver = self.child.start(false, runtime_handle)?; + let counting_receiver = + CountingReceiver::new(child_results_receiver, self.runtime_stats.clone()); + + let (destination_sender, destination_receiver) = create_channel(1); + let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone()); + let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); - let dispatcher = op.make_dispatcher(runtime_handle); + + let dispatch_spawner = op.dispatch_spawner(runtime_handle); + let spawned_dispatch_result = dispatch_spawner.spawn_dispatch( + vec![counting_receiver], + num_workers, + &mut runtime_handle.handle(), + ); runtime_handle.spawn( - async move { - dispatcher - .dispatch(child_results_receiver, input_senders) - .await - }, + async move { spawned_dispatch_result.spawned_dispatch_task.await? 
}, self.name(), ); @@ -166,7 +171,7 @@ impl PipelineNode for BlockingSinkNode { let mut task_set = TaskSet::new(); Self::spawn_workers( op.clone(), - input_receivers, + spawned_dispatch_result.worker_receivers, &mut task_set, runtime_stats.clone(), ); @@ -178,21 +183,19 @@ impl PipelineNode for BlockingSinkNode { } let compute_runtime = get_compute_runtime(); - let finalized_result = compute_runtime - .spawn(async move { - runtime_stats.in_span(&info_span!("BlockingSinkNode::finalize"), || { - op.finalize(finished_states) - }) + let finalized_result = runtime_stats + .in_span(&info_span!("BlockingSinkNode::finalize"), || { + op.finalize(finished_states, &compute_runtime) }) .await??; if let Some(res) = finalized_result { - let _ = destination_sender.send(res).await; + let _ = counting_sender.send(res).await; } Ok(()) }, self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 3fc710c691..fb765ccd99 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -1,19 +1,20 @@ use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +use common_runtime::RuntimeRef; use daft_micropartition::MicroPartition; use tracing::instrument; use super::streaming_sink::{ - DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, + StreamingSink, StreamingSinkExecuteResult, StreamingSinkFinalizeResult, StreamingSinkOutput, + StreamingSinkState, +}; +use crate::{ + dispatcher::{DispatchSpawner, RoundRobinDispatcher, UnorderedDispatcher}, + ExecutionRuntimeContext, NUM_CPUS, }; -use crate::pipeline::PipelineResultType; -struct ConcatSinkState { - // The index of the last morsel of data that was received, which should be strictly non-decreasing. - pub curr_idx: usize, -} -impl DynStreamingSinkState for ConcatSinkState { +struct ConcatSinkState {} +impl StreamingSinkState for ConcatSinkState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -22,27 +23,17 @@ impl DynStreamingSinkState for ConcatSinkState { pub struct ConcatSink {} impl StreamingSink for ConcatSink { - /// Execute for the ConcatSink operator does not do any computation and simply returns the input data. - /// It only expects that the indices of the input data are strictly non-decreasing. - /// TODO(Colin): If maintain_order is false, technically we could accept any index. Make this optimization later. + /// By default, if the streaming_sink is called with maintain_order = true, input is distributed round-robin to the workers, + /// and the output is received in the same order. Therefore, the 'execute' method does not need to do anything. + /// If maintain_order = false, the input is distributed randomly to the workers, and the output is received in random order. #[instrument(skip_all, name = "ConcatSink::sink")] fn execute( &self, - index: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, - ) -> DaftResult { - state_handle.with_state_mut::(|state| { - // If the index is the same as the current index or one more than the current index, then we can accept the morsel. - if state.curr_idx == index || state.curr_idx + 1 == index { - state.curr_idx = index; - Ok(StreamingSinkOutput::NeedMoreInput(Some( - input.as_data().clone(), - ))) - } else { - Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. 
Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index))) - } - }) + input: Arc, + state: Box, + _runtime_ref: &RuntimeRef, + ) -> StreamingSinkExecuteResult { + Ok((state, StreamingSinkOutput::NeedMoreInput(Some(input)))).into() } fn name(&self) -> &'static str { @@ -51,17 +42,33 @@ impl StreamingSink for ConcatSink { fn finalize( &self, - _states: Vec>, - ) -> DaftResult>> { - Ok(None) + _states: Vec>, + _runtime_ref: &RuntimeRef, + ) -> StreamingSinkFinalizeResult { + Ok(None).into() } - fn make_state(&self) -> Box { - Box::new(ConcatSinkState { curr_idx: 0 }) + fn make_state(&self) -> Box { + Box::new(ConcatSinkState {}) } - /// Since the ConcatSink does not do any computation, it does not need to spawn multiple workers. fn max_concurrency(&self) -> usize { - 1 + *NUM_CPUS + } + + fn dispatch_spawner( + &self, + runtime_handle: &ExecutionRuntimeContext, + maintain_order: bool, + ) -> Arc { + if maintain_order { + Arc::new(RoundRobinDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) + } else { + Arc::new(UnorderedDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) + } } } diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index ba8e85d0d0..594420f0f8 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -1,14 +1,53 @@ -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_logical_plan::JoinType; use daft_micropartition::MicroPartition; use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; -use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; + +/// ProbeStateBridge is a bridge between the build and probe phase of a hash join. +/// It is used to pass the probe state from the build phase to the probe phase. +/// The build phase sets the probe state once building is complete, and the probe phase +/// waits for the probe state to be set via the `get_probe_state` method. 
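For illustration only, the set-once/await-many handshake described above reduces to std::sync::OnceLock plus tokio::sync::Notify; the sketch below uses invented names and is not part of the patch:

use std::sync::{Arc, OnceLock};
use tokio::sync::Notify;

// Illustrative stand-in for the bridge: the build side sets a value once,
// any number of probe-side tasks await it.
struct Bridge<T> {
    slot: OnceLock<T>,
    notify: Notify,
}

impl<T: Clone> Bridge<T> {
    fn new() -> Arc<Self> {
        Arc::new(Self {
            slot: OnceLock::new(),
            notify: Notify::new(),
        })
    }

    fn set(&self, value: T) {
        assert!(self.slot.set(value).is_ok(), "bridge must be set only once");
        self.notify.notify_waiters();
    }

    async fn get(&self) -> T {
        loop {
            // Register for the notification *before* checking the slot, so a
            // set() racing in between cannot be missed.
            let notified = self.notify.notified();
            if let Some(value) = self.slot.get() {
                return value.clone();
            }
            notified.await;
        }
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let bridge = Bridge::new();
    let probe_side = {
        let bridge = Arc::clone(&bridge);
        tokio::spawn(async move { bridge.get().await })
    };
    bridge.set(42u64); // "build finished"
    assert_eq!(probe_side.await.unwrap(), 42);
}

Registering the Notified future before checking the slot means a set() that races between the check and the wait still wakes the getter.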
+pub(crate) type ProbeStateBridgeRef = Arc; +pub(crate) struct ProbeStateBridge { + inner: OnceLock>, + notify: tokio::sync::Notify, +} + +impl ProbeStateBridge { + pub(crate) fn new() -> Arc { + Arc::new(Self { + inner: OnceLock::new(), + notify: tokio::sync::Notify::new(), + }) + } + + pub(crate) fn set_probe_state(&self, state: Arc) { + assert!( + !self.inner.set(state).is_err(), + "ProbeStateBridge should be set only once" + ); + self.notify.notify_waiters(); + } + + pub(crate) async fn get_probe_state(&self) -> Arc { + loop { + if let Some(state) = self.inner.get() { + return state.clone(); + } + self.notify.notified().await; + } + } +} enum ProbeTableState { Building { @@ -16,9 +55,7 @@ enum ProbeTableState { projection: Vec, tables: Vec, }, - Done { - probe_state: Arc, - }, + Done, } impl ProbeTableState { @@ -59,7 +96,7 @@ impl ProbeTableState { panic!("add_tables can only be used during the Building Phase") } } - fn finalize(&mut self) -> DaftResult<()> { + fn finalize(&mut self) -> ProbeState { if let Self::Building { probe_table_builder, tables, @@ -69,10 +106,9 @@ impl ProbeTableState { let ptb = std::mem::take(probe_table_builder).expect("should be set in building mode"); let pt = ptb.build(); - *self = Self::Done { - probe_state: Arc::new(ProbeState::new(pt, Arc::new(tables.clone()))), - }; - Ok(()) + let ps = ProbeState::new(pt, tables.clone().into()); + *self = Self::Done; + ps } else { panic!("finalize can only be used during the Building Phase") } @@ -90,6 +126,7 @@ pub struct HashJoinBuildSink { projection: Vec, nulls_equal_aware: Option>, join_type: JoinType, + probe_state_bridge: ProbeStateBridgeRef, } impl HashJoinBuildSink { @@ -98,12 +135,14 @@ impl HashJoinBuildSink { projection: Vec, nulls_equal_aware: Option>, join_type: &JoinType, + probe_state_bridge: ProbeStateBridgeRef, ) -> DaftResult { Ok(Self { key_schema, projection, nulls_equal_aware, join_type: *join_type, + probe_state_bridge, }) } } @@ -115,33 +154,37 @@ impl BlockingSink for HashJoinBuildSink { fn sink( &self, - input: &Arc, + input: Arc, mut state: Box, - ) -> DaftResult { - state - .as_any_mut() - .downcast_mut::() - .expect("HashJoinBuildSink should have ProbeTableState") - .add_tables(input)?; - Ok(BlockingSinkStatus::NeedMoreInput(state)) + runtime: &RuntimeRef, + ) -> BlockingSinkSinkResult { + runtime + .spawn(async move { + let probe_table_state: &mut ProbeTableState = state + .as_any_mut() + .downcast_mut::() + .expect("HashJoinBuildSink should have ProbeTableState"); + probe_table_state.add_tables(&input)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) + }) + .into() } fn finalize( &self, states: Vec>, - ) -> DaftResult> { + _runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { assert_eq!(states.len(), 1); let mut state = states.into_iter().next().unwrap(); let probe_table_state = state .as_any_mut() .downcast_mut::() .expect("State type mismatch"); - probe_table_state.finalize()?; - if let ProbeTableState::Done { probe_state } = probe_table_state { - Ok(Some(probe_state.clone().into())) - } else { - panic!("finalize should only be called after the probe table is built") - } + let finalized_probe_state = probe_table_state.finalize(); + self.probe_state_bridge + .set_probe_state(finalized_probe_state.into()); + Ok(None).into() } fn max_concurrency(&self) -> usize { diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index ff24a703e2..0f237c52a9 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ 
b/src/daft-local-execution/src/sinks/limit.rs @@ -1,13 +1,17 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_micropartition::MicroPartition; use tracing::instrument; use super::streaming_sink::{ - DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, + StreamingSink, StreamingSinkExecuteResult, StreamingSinkFinalizeResult, StreamingSinkOutput, + StreamingSinkState, +}; +use crate::{ + dispatcher::{DispatchSpawner, UnorderedDispatcher}, + ExecutionRuntimeContext, }; -use crate::pipeline::PipelineResultType; struct LimitSinkState { remaining: usize, @@ -23,7 +27,7 @@ impl LimitSinkState { } } -impl DynStreamingSinkState for LimitSinkState { +impl StreamingSinkState for LimitSinkState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -43,33 +47,38 @@ impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( &self, - index: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, - ) -> DaftResult { - assert_eq!(index, 0); - let input = input.as_data(); + input: Arc, + mut state: Box, + runtime_ref: &RuntimeRef, + ) -> StreamingSinkExecuteResult { let input_num_rows = input.len(); - state_handle.with_state_mut::(|state| { - let remaining = state.get_remaining_mut(); - use std::cmp::Ordering::{Equal, Greater, Less}; - match input_num_rows.cmp(remaining) { - Less => { - *remaining -= input_num_rows; - Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) - } - Equal => { - *remaining = 0; - Ok(StreamingSinkOutput::Finished(Some(input.clone()))) - } - Greater => { - let taken = input.head(*remaining)?; - *remaining = 0; - Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) - } + let remaining = state + .as_any_mut() + .downcast_mut::() + .expect("Limit sink should have LimitSinkState") + .get_remaining_mut(); + use std::cmp::Ordering::{Equal, Greater, Less}; + match input_num_rows.cmp(remaining) { + Less => { + *remaining -= input_num_rows; + Ok((state, StreamingSinkOutput::NeedMoreInput(Some(input)))).into() + } + Equal => { + *remaining = 0; + Ok((state, StreamingSinkOutput::Finished(Some(input)))).into() + } + Greater => { + let to_head = *remaining; + *remaining = 0; + runtime_ref + .spawn(async move { + let taken = input.head(to_head)?; + Ok((state, StreamingSinkOutput::Finished(Some(taken.into())))) + }) + .into() } - }) + } } fn name(&self) -> &'static str { @@ -78,16 +87,27 @@ impl StreamingSink for LimitSink { fn finalize( &self, - _states: Vec>, - ) -> DaftResult>> { - Ok(None) + _states: Vec>, + _runtime_ref: &RuntimeRef, + ) -> StreamingSinkFinalizeResult { + Ok(None).into() } - fn make_state(&self) -> Box { + fn make_state(&self) -> Box { Box::new(LimitSinkState::new(self.limit)) } fn max_concurrency(&self) -> usize { 1 } + + fn dispatch_spawner( + &self, + _runtime_handle: &ExecutionRuntimeContext, + _maintain_order: bool, + ) -> Arc { + // Limits are greedy, so we don't need to buffer any input. + // They are also not concurrent, so we don't need to worry about ordering. 
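As an aside, the unordered strategy boils down to one bounded mpmc channel that every worker pulls from. A minimal sketch, assuming loole 0.4 and tokio (worker count and payload type invented):

// Cargo: loole = "0.4", tokio = { version = "1", features = ["rt", "macros"] }

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // One bounded channel, one receiver clone per worker: whichever worker
    // is free pulls the next morsel, so completion order is arbitrary.
    let (tx, rx) = loole::bounded::<u64>(4);
    let workers: Vec<_> = (0..3)
        .map(|id| {
            let rx = rx.clone();
            tokio::spawn(async move {
                let mut seen = 0;
                while let Ok(morsel) = rx.recv_async().await {
                    println!("worker {id} got morsel {morsel}");
                    seen += 1;
                }
                seen // channel closed: worker exits
            })
        })
        .collect();
    for morsel in 0..10u64 {
        tx.send_async(morsel).await.unwrap();
    }
    drop(tx); // close the channel so the workers drain and stop
    let mut total = 0;
    for w in workers {
        total += w.await.unwrap();
    }
    assert_eq!(total, 10);
}

For the limit above, passing None additionally skips morsel re-chunking, in line with the comment that a greedy, single-worker limit has no need to buffer input.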
+ Arc::new(UnorderedDispatcher::new(None)) + } } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 88f68aa202..a8ca50130f 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_core::{ prelude::{ bitmap::{and, Bitmap, MutableBitmap}, @@ -12,13 +13,21 @@ use daft_dsl::ExprRef; use daft_logical_plan::JoinType; use daft_micropartition::MicroPartition; use daft_table::{GrowableTable, ProbeState, Table}; +use futures::{stream, StreamExt}; use indexmap::IndexSet; use tracing::{info_span, instrument}; -use super::streaming_sink::{ - DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, +use super::{ + hash_join_build::ProbeStateBridgeRef, + streaming_sink::{ + StreamingSink, StreamingSinkExecuteResult, StreamingSinkFinalizeResult, + StreamingSinkOutput, StreamingSinkState, + }, +}; +use crate::{ + dispatcher::{DispatchSpawner, RoundRobinDispatcher, UnorderedDispatcher}, + ExecutionRuntimeContext, }; -use crate::pipeline::PipelineResultType; struct IndexBitmapBuilder { mutable_bitmaps: Vec, @@ -69,59 +78,60 @@ impl IndexBitmap { } } -enum OuterHashJoinProbeState { - Building, - ReadyToProbe(Arc, Option), +enum OuterHashJoinState { + Building(ProbeStateBridgeRef, bool), + Probing(Arc, Option), } -impl OuterHashJoinProbeState { - fn initialize_probe_state(&mut self, probe_state: Arc, needs_bitmap: bool) { - let tables = probe_state.get_tables(); - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe( - probe_state.clone(), - if needs_bitmap { - Some(IndexBitmapBuilder::new(tables)) - } else { - None - }, - ); - } else { - panic!("OuterHashJoinProbeState should only be in Building state when setting table") - } - } - - fn get_probe_state(&self) -> &ProbeState { - if let Self::ReadyToProbe(probe_state, _) = self { - probe_state - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") +impl OuterHashJoinState { + async fn get_or_build_probe_state(&mut self) -> Arc { + match self { + Self::Building(bridge, needs_bitmap) => { + let probe_state = bridge.get_probe_state().await; + let builder = + needs_bitmap.then(|| IndexBitmapBuilder::new(probe_state.get_tables())); + *self = Self::Probing(probe_state.clone(), builder); + probe_state + } + Self::Probing(probe_state, _) => probe_state.clone(), } } - fn get_bitmap_builder(&mut self) -> &mut Option { - if let Self::ReadyToProbe(_, bitmap_builder) = self { - bitmap_builder - } else { - panic!("get_bitmap can only be used during the ReadyToProbe Phase") + async fn get_or_build_bitmap(&mut self) -> &mut Option { + match self { + Self::Building(bridge, _) => { + let probe_state = bridge.get_probe_state().await; + let builder = IndexBitmapBuilder::new(probe_state.get_tables()); + *self = Self::Probing(probe_state, Some(builder)); + match self { + Self::Probing(_, builder) => builder, + _ => unreachable!(), + } + } + Self::Probing(_, builder) => builder, } } } -impl DynStreamingSinkState for OuterHashJoinProbeState { +impl StreamingSinkState for OuterHashJoinState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } } -pub(crate) struct OuterHashJoinProbeSink { +struct OuterHashJoinParams { probe_on: Vec, common_join_keys: Vec, left_non_join_columns: Vec, right_non_join_columns: Vec, 
right_non_join_schema: SchemaRef, join_type: JoinType, +} + +pub(crate) struct OuterHashJoinProbeSink { + params: Arc, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl OuterHashJoinProbeSink { @@ -132,6 +142,7 @@ impl OuterHashJoinProbeSink { join_type: JoinType, common_join_keys: IndexSet, output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, ) -> Self { let left_non_join_columns = left_schema .fields @@ -150,25 +161,30 @@ impl OuterHashJoinProbeSink { let right_non_join_columns = right_non_join_schema.fields.keys().cloned().collect(); let common_join_keys = common_join_keys.into_iter().collect(); Self { - probe_on, - common_join_keys, - left_non_join_columns, - right_non_join_columns, - right_non_join_schema, - join_type, + params: Arc::new(OuterHashJoinParams { + probe_on, + common_join_keys, + left_non_join_columns, + right_non_join_columns, + right_non_join_schema, + join_type, + }), output_schema: output_schema.clone(), + probe_state_bridge, } } fn probe_left_right( - &self, input: &Arc, - state: &OuterHashJoinProbeState, + probe_state: &ProbeState, + join_type: JoinType, + probe_on: &[ExprRef], + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_columns: &[String], ) -> DaftResult> { - let (probe_table, tables) = { - let probe_state = state.get_probe_state(); - (probe_state.get_probeable(), probe_state.get_tables()) - }; + let probe_table = probe_state.get_probeable().clone(); + let tables = probe_state.get_tables().clone(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); let mut build_side_growable = GrowableTable::new( @@ -185,7 +201,7 @@ impl OuterHashJoinProbeSink { { let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; + let join_keys = table.eval_expression_list(probe_on)?; let idx_mapper = probe_table.probe_indices(&join_keys)?; for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { @@ -209,15 +225,15 @@ impl OuterHashJoinProbeSink { let build_side_table = build_side_growable.build()?; let probe_side_table = probe_side_growable.build()?; - let final_table = if self.join_type == JoinType::Left { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = probe_side_table.get_columns(&self.left_non_join_columns)?; - let right = build_side_table.get_columns(&self.right_non_join_columns)?; + let final_table = if join_type == JoinType::Left { + let join_table = probe_side_table.get_columns(common_join_keys)?; + let left = probe_side_table.get_columns(left_non_join_columns)?; + let right = build_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? } else { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + let join_table = probe_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; + let right = probe_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? 
}; Ok(Arc::new(MicroPartition::new_loaded( @@ -228,18 +244,17 @@ impl OuterHashJoinProbeSink { } fn probe_outer( - &self, input: &Arc, - state: &mut OuterHashJoinProbeState, + probe_state: &ProbeState, + bitmap_builder: &mut IndexBitmapBuilder, + probe_on: &[ExprRef], + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_columns: &[String], ) -> DaftResult> { - let (probe_table, tables) = { - let probe_state = state.get_probe_state(); - ( - probe_state.get_probeable().clone(), - probe_state.get_tables().clone(), - ) - }; - let bitmap_builder = state.get_bitmap_builder(); + let probe_table = probe_state.get_probeable().clone(); + let tables = probe_state.get_tables().clone(); + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); // Need to set use_validity to true here because we add nulls to the build side let mut build_side_growable = GrowableTable::new( @@ -252,15 +267,11 @@ impl OuterHashJoinProbeSink { let mut probe_side_growable = GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; - let left_idx_used = bitmap_builder - .as_mut() - .expect("bitmap should be set in outer join"); - drop(_growables); { let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; + let join_keys = table.eval_expression_list(probe_on)?; let idx_mapper = probe_table.probe_indices(&join_keys)?; for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { @@ -268,7 +279,7 @@ impl OuterHashJoinProbeSink { for (build_side_table_idx, build_row_idx) in inner_iter { let build_side_table_idx = build_side_table_idx as usize; let build_row_idx = build_row_idx as usize; - left_idx_used.mark_used(build_side_table_idx, build_row_idx); + bitmap_builder.mark_used(build_side_table_idx, build_row_idx); build_side_growable.extend(build_side_table_idx, build_row_idx, 1); probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } @@ -283,9 +294,9 @@ impl OuterHashJoinProbeSink { let build_side_table = build_side_growable.build()?; let probe_side_table = probe_side_growable.build()?; - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + let join_table = probe_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; + let right = probe_side_table.get_columns(right_non_join_columns)?; let final_table = join_table.union(&left)?.union(&right)?; Ok(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -294,37 +305,49 @@ impl OuterHashJoinProbeSink { ))) } - fn finalize_outer( - &self, - mut states: Vec>, + async fn finalize_outer( + mut states: Vec>, + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_schema: &SchemaRef, ) -> DaftResult>> { - let states = states - .iter_mut() - .map(|s| { - s.as_any_mut() - .downcast_mut::() - .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState") - }) - .collect::>(); - let tables = states - .first() + let mut states_iter = states.iter_mut(); + let first_state = states_iter + .next() .expect("at least one state should be present") - .get_probe_state() + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + 
let tables = first_state + .get_or_build_probe_state() + .await + .get_tables() + .clone(); + let first_bitmap = first_state + .get_or_build_bitmap() + .await + .take() + .expect("bitmap should be set") + .build(); let merged_bitmap = { - let bitmaps = states.into_iter().map(|s| { - if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap) = s { - bitmap + let bitmaps = stream::once(async move { first_bitmap }) + .chain(stream::iter(states_iter).then(|s| async move { + let state = s + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + state + .get_or_build_bitmap() + .await .take() - .expect("bitmap should be present in outer join") + .expect("bitmap should be set") .build() - } else { - panic!("OuterHashJoinProbeState should be in ReadyToProbe state") - } - }); - bitmaps.fold(None, |acc, x| match acc { + })) + .collect::>() + .await; + + bitmaps.into_iter().fold(None, |acc, x| match acc { None => Some(x), Some(acc) => Some(acc.merge(&x)), }) @@ -339,16 +362,15 @@ impl OuterHashJoinProbeSink { let build_side_table = Table::concat(&leftovers)?; - let join_table = build_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let join_table = build_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; let right = { - let columns = self - .right_non_join_schema + let columns = right_non_join_schema .fields .values() .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) .collect::>(); - Table::new_unchecked(self.right_non_join_schema.clone(), columns, left.len()) + Table::new_unchecked(right_non_join_schema.clone(), columns, left.len()) }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Some(Arc::new(MicroPartition::new_loaded( @@ -363,54 +385,105 @@ impl StreamingSink for OuterHashJoinProbeSink { #[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, - ) -> DaftResult { - match idx { - 0 => { - state_handle.with_state_mut::(|state| { - state.initialize_probe_state( - input.as_probe_state().clone(), - self.join_type == JoinType::Outer, - ); - }); - Ok(StreamingSinkOutput::NeedMoreInput(None)) - } - _ => state_handle.with_state_mut::(|state| { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); - } - let out = match self.join_type { - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - JoinType::Outer => self.probe_outer(input, state), + input: Arc, + mut state: Box, + runtime_ref: &RuntimeRef, + ) -> StreamingSinkExecuteResult { + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok((state, StreamingSinkOutput::NeedMoreInput(Some(empty)))).into(); + } + + let params = self.params.clone(); + runtime_ref + .spawn(async move { + let outer_join_state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink should have OuterHashJoinProbeState"); + let probe_state = outer_join_state.get_or_build_probe_state().await; + let out = match params.join_type { + JoinType::Left | JoinType::Right => Self::probe_left_right( + &input, + &probe_state, + params.join_type, + &params.probe_on, + &params.common_join_keys, + &params.left_non_join_columns, + 
&params.right_non_join_columns, + ), + JoinType::Outer => { + let bitmap_builder = outer_join_state + .get_or_build_bitmap() + .await + .as_mut() + .expect("bitmap should be set"); + Self::probe_outer( + &input, + &probe_state, + bitmap_builder, + &params.probe_on, + &params.common_join_keys, + &params.left_non_join_columns, + &params.right_non_join_columns, + ) + } _ => unreachable!( "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" ), }?; - Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) - }), - } + Ok((state, StreamingSinkOutput::NeedMoreInput(Some(out)))) + }) + .into() } fn name(&self) -> &'static str { "OuterHashJoinProbeSink" } - fn make_state(&self) -> Box { - Box::new(OuterHashJoinProbeState::Building) + fn make_state(&self) -> Box { + Box::new(OuterHashJoinState::Building( + self.probe_state_bridge.clone(), + self.params.join_type == JoinType::Outer, + )) } fn finalize( &self, - states: Vec>, - ) -> DaftResult>> { - if self.join_type == JoinType::Outer { - self.finalize_outer(states) + states: Vec>, + runtime_ref: &RuntimeRef, + ) -> StreamingSinkFinalizeResult { + if self.params.join_type == JoinType::Outer { + let params = self.params.clone(); + runtime_ref + .spawn(async move { + Self::finalize_outer( + states, + &params.common_join_keys, + &params.left_non_join_columns, + &params.right_non_join_schema, + ) + .await + }) + .into() + } else { + Ok(None).into() + } + } + + fn dispatch_spawner( + &self, + runtime_handle: &ExecutionRuntimeContext, + maintain_order: bool, + ) -> Arc { + if maintain_order { + Arc::new(RoundRobinDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) } else { - Ok(None) + Arc::new(UnorderedDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) } } } diff --git a/src/daft-local-execution/src/sinks/pivot.rs b/src/daft-local-execution/src/sinks/pivot.rs index 93cb7b3843..0a4ceb54a1 100644 --- a/src/daft-local-execution/src/sinks/pivot.rs +++ b/src/daft-local-execution/src/sinks/pivot.rs @@ -1,12 +1,16 @@ use std::sync::Arc; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_dsl::{AggExpr, Expr, ExprRef}; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; +use crate::NUM_CPUS; enum PivotState { Accumulating(Vec>), @@ -39,12 +43,16 @@ impl BlockingSinkState for PivotState { } } +struct PivotParams { + group_by: Vec, + pivot_column: ExprRef, + value_column: ExprRef, + aggregation: AggExpr, + names: Vec, +} + pub struct PivotSink { - pub group_by: Vec, - pub pivot_column: ExprRef, - pub value_column: ExprRef, - pub aggregation: AggExpr, - pub names: Vec, + pivot_params: Arc, } impl PivotSink { @@ -56,11 +64,13 @@ impl PivotSink { names: Vec, ) -> Self { Self { - group_by, - pivot_column, - value_column, - aggregation, - names, + pivot_params: Arc::new(PivotParams { + group_by, + pivot_column, + value_column, + aggregation, + names, + }), } } } @@ -69,47 +79,54 @@ impl BlockingSink for PivotSink { #[instrument(skip_all, name = "PivotSink::sink")] fn sink( &self, - input: &Arc, + input: Arc, mut state: Box, - ) -> DaftResult { + _runtime: &RuntimeRef, + ) -> BlockingSinkSinkResult { state .as_any_mut() .downcast_mut::() .expect("PivotSink should have PivotState") - .push(input.clone()); - 
Ok(BlockingSinkStatus::NeedMoreInput(state)) + .push(input); + Ok(BlockingSinkStatus::NeedMoreInput(state)).into() } #[instrument(skip_all, name = "PivotSink::finalize")] fn finalize( &self, states: Vec>, - ) -> DaftResult> { - let all_parts = states.into_iter().flat_map(|mut state| { - state - .as_any_mut() - .downcast_mut::() - .expect("PivotSink should have PivotState") - .finalize() - }); - let concated = MicroPartition::concat(all_parts)?; - let group_by_with_pivot = self - .group_by - .iter() - .chain(std::iter::once(&self.pivot_column)) - .cloned() - .collect::>(); - let agged = concated.agg( - &[Expr::Agg(self.aggregation.clone()).into()], - &group_by_with_pivot, - )?; - let pivoted = Arc::new(agged.pivot( - &self.group_by, - self.pivot_column.clone(), - self.value_column.clone(), - self.names.clone(), - )?); - Ok(Some(pivoted.into())) + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { + let pivot_params = self.pivot_params.clone(); + runtime + .spawn(async move { + let all_parts = states.into_iter().flat_map(|mut state| { + state + .as_any_mut() + .downcast_mut::() + .expect("PivotSink should have PivotState") + .finalize() + }); + let concated = MicroPartition::concat(all_parts)?; + let group_by_with_pivot = pivot_params + .group_by + .iter() + .chain(std::iter::once(&pivot_params.pivot_column)) + .cloned() + .collect::>(); + let agged = concated.agg( + &[Expr::Agg(pivot_params.aggregation.clone()).into()], + &group_by_with_pivot, + )?; + let pivoted = Arc::new(agged.pivot( + &pivot_params.group_by, + pivot_params.pivot_column.clone(), + pivot_params.value_column.clone(), + pivot_params.names.clone(), + )?); + Ok(Some(pivoted)) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 83c933d1ec..9c2b3f3944 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -1,12 +1,16 @@ use std::sync::Arc; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; +use crate::NUM_CPUS; enum SortState { Building(Vec>), @@ -38,16 +42,22 @@ impl BlockingSinkState for SortState { self } } -pub struct SortSink { + +struct SortParams { sort_by: Vec, descending: Vec, } +pub struct SortSink { + params: Arc, +} impl SortSink { pub fn new(sort_by: Vec, descending: Vec) -> Self { Self { - sort_by, - descending, + params: Arc::new(SortParams { + sort_by, + descending, + }), } } } @@ -56,32 +66,39 @@ impl BlockingSink for SortSink { #[instrument(skip_all, name = "SortSink::sink")] fn sink( &self, - input: &Arc, + input: Arc, mut state: Box, - ) -> DaftResult { + _runtime_ref: &RuntimeRef, + ) -> BlockingSinkSinkResult { state .as_any_mut() .downcast_mut::() .expect("SortSink should have sort state") - .push(input.clone()); - Ok(BlockingSinkStatus::NeedMoreInput(state)) + .push(input); + Ok(BlockingSinkStatus::NeedMoreInput(state)).into() } #[instrument(skip_all, name = "SortSink::finalize")] fn finalize( &self, states: Vec>, - ) -> DaftResult> { - let parts = states.into_iter().flat_map(|mut state| { - let state = state - .as_any_mut() - .downcast_mut::() - .expect("State type 
mismatch"); - state.finalize() - }); - let concated = MicroPartition::concat(parts)?; - let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); - Ok(Some(sorted.into())) + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { + let params = self.params.clone(); + runtime + .spawn(async move { + let parts = states.into_iter().flat_map(|mut state| { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + state.finalize() + }); + let concated = MicroPartition::concat(parts)?; + let sorted = Arc::new(concated.sort(¶ms.sort_by, ¶ms.descending)?); + Ok(Some(sorted)) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 911b52d652..9d5452cf4b 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -1,47 +1,27 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; -use common_runtime::get_compute_runtime; +use common_runtime::{get_compute_runtime, RuntimeRef}; use daft_micropartition::MicroPartition; use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - channel::{create_channel, PipelineChannel, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, + channel::{ + create_channel, create_ordering_aware_receiver_channel, OrderingAwareReceiver, Receiver, + Sender, + }, + dispatcher::DispatchSpawner, + pipeline::PipelineNode, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, + ExecutionRuntimeContext, JoinSnafu, OperatorOutput, TaskSet, NUM_CPUS, }; -pub trait DynStreamingSinkState: Send + Sync { +pub trait StreamingSinkState: Send + Sync { fn as_any_mut(&mut self) -> &mut dyn std::any::Any; } -pub(crate) struct StreamingSinkState { - inner: Mutex>, -} - -impl StreamingSinkState { - fn new(inner: Box) -> Arc { - Arc::new(Self { - inner: Mutex::new(inner), - }) - } - - pub(crate) fn with_state_mut(&self, f: F) -> R - where - F: FnOnce(&mut T) -> R, - { - let mut guard = self.inner.lock().unwrap(); - let state = guard - .as_any_mut() - .downcast_mut::() - .expect("State type mismatch"); - f(state) - } -} - pub enum StreamingSinkOutput { NeedMoreInput(Option>), #[allow(dead_code)] @@ -49,34 +29,45 @@ pub enum StreamingSinkOutput { Finished(Option>), } +pub(crate) type StreamingSinkExecuteResult = + OperatorOutput, StreamingSinkOutput)>>; +pub(crate) type StreamingSinkFinalizeResult = + OperatorOutput>>>; pub trait StreamingSink: Send + Sync { /// Execute the StreamingSink operator on the morsel of input data, /// received from the child with the given index, /// with the given state. fn execute( &self, - index: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, - ) -> DaftResult; + input: Arc, + state: Box, + runtime: &RuntimeRef, + ) -> StreamingSinkExecuteResult; /// Finalize the StreamingSink operator, with the given states from each worker. fn finalize( &self, - states: Vec>, - ) -> DaftResult>>; + states: Vec>, + runtime: &RuntimeRef, + ) -> StreamingSinkFinalizeResult; /// The name of the StreamingSink operator. fn name(&self) -> &'static str; /// Create a new worker-local state for this StreamingSink. 
- fn make_state(&self) -> Box; + fn make_state(&self) -> Box; /// The maximum number of concurrent workers that can be spawned for this sink. /// Each worker will have its own StreamingSinkState. fn max_concurrency(&self) -> usize { *NUM_CPUS } + + fn dispatch_spawner( + &self, + runtime_handle: &ExecutionRuntimeContext, + maintain_order: bool, + ) -> Arc; } pub struct StreamingSinkNode { @@ -104,110 +95,66 @@ impl StreamingSinkNode { #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")] async fn run_worker( op: Arc, - mut input_receiver: Receiver<(usize, PipelineResultType)>, + input_receiver: Receiver>, output_sender: Sender>, rt_context: Arc, - ) -> DaftResult> { + ) -> DaftResult> { let span = info_span!("StreamingSink::Execute"); let compute_runtime = get_compute_runtime(); - let state_wrapper = StreamingSinkState::new(op.make_state()); - let mut finished = false; - while let Some((idx, morsel)) = input_receiver.recv().await { - if finished { - break; - } + let mut state = op.make_state(); + while let Some(morsel) = input_receiver.recv().await { loop { - let op = op.clone(); - let morsel = morsel.clone(); - let span = span.clone(); - let rt_context = rt_context.clone(); - let state_wrapper = state_wrapper.clone(); - let fut = async move { - rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref())) - }; - let result = compute_runtime.spawn(fut).await??; - match result { + let output = rt_context.in_span(&span, || { + op.execute(morsel.clone(), state, &compute_runtime) + }); + let result = output.await??; + state = result.0; + match result.1 { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { if output_sender.send(mp).await.is_err() { - finished = true; - break; + return Ok(state); } } break; } StreamingSinkOutput::HasMoreOutput(mp) => { if output_sender.send(mp).await.is_err() { - finished = true; - break; + return Ok(state); } } StreamingSinkOutput::Finished(mp) => { if let Some(mp) = mp { let _ = output_sender.send(mp).await; } - finished = true; - break; + return Ok(state); } } } } - // Take the state out of the Arc and Mutex because we need to return it. - // It should be guaranteed that the ONLY holder of state at this point is this function. - Ok(Arc::into_inner(state_wrapper) - .expect("Completed worker should have exclusive access to state wrapper") - .inner - .into_inner() - .expect("Completed worker should have exclusive access to inner state")) + Ok(state) } fn spawn_workers( op: Arc, - input_receivers: Vec>, - task_set: &mut TaskSet>>, + input_receivers: Vec>>, + task_set: &mut TaskSet>>, stats: Arc, - ) -> Receiver> { - let (output_sender, output_receiver) = create_channel(input_receivers.len()); - for input_receiver in input_receivers { + maintain_order: bool, + ) -> OrderingAwareReceiver> { + let (output_sender, output_receiver) = + create_ordering_aware_receiver_channel(maintain_order, input_receivers.len()); + for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_sender) { task_set.spawn(Self::run_worker( op.clone(), input_receiver, - output_sender.clone(), + output_sender, stats.clone(), )); } output_receiver } - - // Forwards input from the children to the workers in a round-robin fashion. - // Always exhausts the input from one child before moving to the next. 
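The deleted helper that follows carried this round-robin policy inline; in the new design the sending half lives in RoundRobinDispatcher and the receiving half in OrderingAwareReceiver. The receiving half can be sketched standalone, again assuming loole 0.4 and tokio, with invented names:

// Cargo: loole = "0.4", tokio = { version = "1", features = ["rt", "macros"] }

// Receive from each worker's output channel in fixed rotation, so results
// come back in the order the workers were fed (simplified sketch).
async fn recv_round_robin<T>(receivers: &[loole::Receiver<T>]) -> Vec<T> {
    let mut out = Vec::new();
    let mut live = vec![true; receivers.len()];
    let mut idx = 0;
    while live.iter().any(|alive| *alive) {
        if live[idx] {
            match receivers[idx].recv_async().await {
                Ok(v) => out.push(v),
                Err(_) => live[idx] = false, // this worker's sender is gone
            }
        }
        idx = (idx + 1) % receivers.len();
    }
    out
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let (senders, receivers): (Vec<_>, Vec<_>) =
        (0..3).map(|_| loole::bounded::<u32>(1)).unzip();
    tokio::spawn(async move {
        // Feed morsels to the workers round-robin, like the dispatcher does.
        for v in 0..9u32 {
            senders[(v % 3) as usize].send_async(v).await.unwrap();
        }
        // senders drop here, closing all three channels
    });
    assert_eq!(recv_round_robin(&receivers).await, (0..9).collect::<Vec<_>>());
}

Rotating over per-worker channels in the same order they were fed is what lets maintain_order = true survive parallel workers.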
- async fn forward_input_to_workers( - receivers: Vec, - worker_senders: Vec>, - ) -> DaftResult<()> { - let mut next_worker_idx = 0; - let mut send_to_next_worker = |idx, data: PipelineResultType| { - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send((idx, data)) - }; - - for (idx, mut receiver) in receivers.into_iter().enumerate() { - while let Some(morsel) = receiver.recv().await { - if morsel.should_broadcast() { - for worker_sender in &worker_senders { - if worker_sender.send((idx, morsel.clone())).await.is_err() { - return Ok(()); - } - } - } else if send_to_next_worker(idx, morsel.clone()).await.is_err() { - return Ok(()); - } - } - } - Ok(()) - } } impl TreeDisplay for StreamingSinkNode { @@ -244,41 +191,50 @@ impl PipelineNode for StreamingSinkNode { } fn start( - &mut self, + &self, maintain_order: bool, - runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + runtime_handle: &mut ExecutionRuntimeContext, + ) -> crate::Result>> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); - for child in &mut self.children { - let child_result_channel = child.start(maintain_order, runtime_handle)?; - child_result_receivers - .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); + for child in &self.children { + let child_result_receiver = child.start(maintain_order, runtime_handle)?; + child_result_receivers.push(CountingReceiver::new( + child_result_receiver, + self.runtime_stats.clone(), + )); } - let mut destination_channel = PipelineChannel::new(1, maintain_order); - let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let (destination_sender, destination_receiver) = create_channel(1); + let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone()); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + + let dispatch_spawner = op.dispatch_spawner(runtime_handle, maintain_order); + let spawned_dispatch_result = dispatch_spawner.spawn_dispatch( + child_result_receivers, + num_workers, + &mut runtime_handle.handle(), + ); runtime_handle.spawn( - Self::forward_input_to_workers(child_result_receivers, input_senders), + async move { spawned_dispatch_result.spawned_dispatch_task.await? 
}, self.name(), ); + runtime_handle.spawn( async move { let mut task_set = TaskSet::new(); let mut output_receiver = Self::spawn_workers( op.clone(), - input_receivers, + spawned_dispatch_result.worker_receivers, &mut task_set, runtime_stats.clone(), + maintain_order, ); while let Some(morsel) = output_receiver.recv().await { - if destination_sender.send(morsel.into()).await.is_err() { + if counting_sender.send(morsel).await.is_err() { break; } } @@ -290,21 +246,19 @@ impl PipelineNode for StreamingSinkNode { } let compute_runtime = get_compute_runtime(); - let finalized_result = compute_runtime - .spawn(async move { - runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || { - op.finalize(finished_states) - }) + let finalized_result = runtime_stats + .in_span(&info_span!("StreamingSinkNode::finalize"), || { + op.finalize(finished_states, &compute_runtime) }) .await??; if let Some(res) = finalized_result { - let _ = destination_sender.send(res.into()).await; + let _ = counting_sender.send(res).await; } Ok(()) }, self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/write.rs b/src/daft-local-execution/src/sinks/write.rs index 96a2ccd843..6bd3147f8d 100644 --- a/src/daft-local-execution/src/sinks/write.rs +++ b/src/daft-local-execution/src/sinks/write.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; @@ -8,11 +9,13 @@ use daft_table::Table; use daft_writers::{FileWriter, WriterFactory}; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; use crate::{ - dispatcher::{Dispatcher, PartitionedDispatcher, RoundRobinBufferedDispatcher}, - pipeline::PipelineResultType, - NUM_CPUS, + dispatcher::{DispatchSpawner, PartitionedDispatcher, UnorderedDispatcher}, + ExecutionRuntimeContext, NUM_CPUS, }; pub enum WriteFormat { @@ -72,37 +75,48 @@ impl BlockingSink for WriteSink { #[instrument(skip_all, name = "WriteSink::sink")] fn sink( &self, - input: &Arc, + input: Arc, mut state: Box, - ) -> DaftResult { - state - .as_any_mut() - .downcast_mut::() - .expect("WriteSink should have WriteState") - .writer - .write(input)?; - Ok(BlockingSinkStatus::NeedMoreInput(state)) + runtime_ref: &RuntimeRef, + ) -> BlockingSinkSinkResult { + runtime_ref + .spawn(async move { + state + .as_any_mut() + .downcast_mut::() + .expect("WriteSink should have WriteState") + .writer + .write(&input)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) + }) + .into() } #[instrument(skip_all, name = "WriteSink::finalize")] fn finalize( &self, states: Vec>, - ) -> DaftResult> { - let mut results = vec![]; - for mut state in states { - let state = state - .as_any_mut() - .downcast_mut::() - .expect("State type mismatch"); - results.extend(state.writer.close()?); - } - let mp = Arc::new(MicroPartition::new_loaded( - self.file_schema.clone(), - results.into(), - None, - )); - Ok(Some(mp.into())) + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { + let file_schema = self.file_schema.clone(); + runtime + .spawn(async move { + let mut results = vec![]; + for mut state in states { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type 
mismatch"); + results.extend(state.writer.close()?); + } + let mp = Arc::new(MicroPartition::new_loaded( + file_schema, + results.into(), + None, + )); + Ok(Some(mp)) + }) + .into() } fn name(&self) -> &'static str { @@ -124,16 +138,16 @@ impl BlockingSink for WriteSink { Ok(Box::new(WriteState::new(writer)) as Box) } - fn make_dispatcher( + fn dispatch_spawner( &self, - runtime_handle: &crate::ExecutionRuntimeHandle, - ) -> Arc { + runtime_handle: &ExecutionRuntimeContext, + ) -> Arc { if let Some(partition_by) = &self.partition_by { Arc::new(PartitionedDispatcher::new(partition_by.clone())) } else { - Arc::new(RoundRobinBufferedDispatcher::new( + Arc::new(UnorderedDispatcher::new(Some( runtime_handle.default_morsel_size(), - )) + ))) } } diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 0073d18c82..8a2dd8c45c 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -9,8 +9,10 @@ use daft_micropartition::MicroPartition; use futures::{stream::BoxStream, StreamExt}; use crate::{ - channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, + channel::{create_channel, Receiver}, + pipeline::PipelineNode, + runtime_stats::{CountingSender, RuntimeStatsContext}, + ExecutionRuntimeContext, }; pub type SourceStream<'a> = BoxStream<'a, DaftResult>>; @@ -70,33 +72,33 @@ impl PipelineNode for SourceNode { vec![] } fn start( - &mut self, + &self, maintain_order: bool, - runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + runtime_handle: &mut ExecutionRuntimeContext, + ) -> crate::Result>> { let source = self.source.clone(); let io_stats = self.io_stats.clone(); - let mut channel = PipelineChannel::new(1, maintain_order); - let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); + let (destination_sender, destination_receiver) = create_channel(1); + let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone()); runtime_handle.spawn( async move { let mut has_data = false; let mut source_stream = source.get_data(maintain_order, io_stats).await?; while let Some(part) = source_stream.next().await { has_data = true; - if counting_sender.send(part?.into()).await.is_err() { + if counting_sender.send(part?).await.is_err() { return Ok(()); } } if !has_data { let empty = Arc::new(MicroPartition::empty(Some(source.schema().clone()))); - let _ = counting_sender.send(empty.into()).await; + let _ = counting_sender.send(empty).await; } Ok(()) }, self.name(), ); - Ok(channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self