diff --git a/Cargo.lock b/Cargo.lock
index f5a1ef9c25..b85d3fcb72 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2086,7 +2086,6 @@ dependencies = [
 name = "daft-local-execution"
 version = "0.3.0-dev0"
 dependencies = [
- "async-trait",
  "common-daft-config",
  "common-display",
  "common-error",
@@ -2110,6 +2109,7 @@ dependencies = [
  "indexmap 2.5.0",
  "lazy_static",
  "log",
+ "loole",
  "num-format",
  "pyo3",
  "snafu",
@@ -3629,6 +3629,16 @@ version = "0.4.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
 
+[[package]]
+name = "loole"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2998397c725c822c6b2ba605fd9eb4c6a7a0810f1629ba3cc232ef4f0308d96"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+]
+
 [[package]]
 name = "lz4"
 version = "1.26.0"
diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml
index f0be341b6a..06c250ee00 100644
--- a/src/daft-local-execution/Cargo.toml
+++ b/src/daft-local-execution/Cargo.toml
@@ -1,5 +1,4 @@
 [dependencies]
-async-trait = {workspace = true}
 common-daft-config = {path = "../common/daft-config", default-features = false}
 common-display = {path = "../common/display", default-features = false}
 common-error = {path = "../common/error", default-features = false}
@@ -23,6 +22,7 @@ futures = {workspace = true}
 indexmap = {workspace = true}
 lazy_static = {workspace = true}
 log = {workspace = true}
+loole = "0.4.0"
 num-format = "0.4.4"
 pyo3 = {workspace = true, optional = true}
 snafu = {workspace = true}
diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs
index f16e3bd061..21c8a41cf9 100644
--- a/src/daft-local-execution/src/channel.rs
+++ b/src/daft-local-execution/src/channel.rs
@@ -1,92 +1,84 @@
 use std::sync::Arc;
 
-use crate::{
-    pipeline::PipelineResultType,
-    runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
-};
-
-pub type Sender<T> = tokio::sync::mpsc::Sender<T>;
-pub type Receiver<T> = tokio::sync::mpsc::Receiver<T>;
-
-pub fn create_channel<T>(buffer_size: usize) -> (Sender<T>, Receiver<T>) {
-    tokio::sync::mpsc::channel(buffer_size)
-}
-
-pub struct PipelineChannel {
-    sender: PipelineSender,
-    receiver: PipelineReceiver,
+use daft_micropartition::MicroPartition;
+use loole::SendError;
+
+use crate::runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext};
+
+#[derive(Clone)]
+pub(crate) struct Sender<T>(loole::Sender<T>)
+where
+    T: Clone;
+impl<T: Clone> Sender<T> {
+    pub(crate) async fn send(&self, val: T) -> Result<(), SendError<T>> {
+        self.0.send_async(val).await
+    }
 }
 
-impl PipelineChannel {
-    pub fn new(buffer_size: usize, in_order: bool) -> Self {
-        if in_order {
-            let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip();
-            let sender = PipelineSender::InOrder(RoundRobinSender::new(senders));
-            let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers));
-            Self { sender, receiver }
-        } else {
-            let (sender, receiver) = create_channel(buffer_size);
-            let sender = PipelineSender::OutOfOrder(sender);
-            let receiver = PipelineReceiver::OutOfOrder(receiver);
-            Self { sender, receiver }
-        }
-    }
-
-    fn get_next_sender(&mut self) -> Sender<PipelineResultType> {
-        match &mut self.sender {
-            PipelineSender::InOrder(rr) => rr.get_next_sender(),
-            PipelineSender::OutOfOrder(sender) => sender.clone(),
-        }
+impl Sender<Arc<MicroPartition>> {
+    pub(crate) fn into_counting_sender(self, rt: Arc<RuntimeStatsContext>) -> CountingSender {
+        CountingSender::new(self, rt)
     }
+}
 
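Context for this hunk: unlike tokio::sync::mpsc, loole's channel is MPMC, so a Receiver can be cloned and shared by several workers; the UnorderedDispatcher introduced later in this diff relies on exactly that (`vec![worker_receiver; num_workers]`). A minimal standalone sketch of that behavior, assuming loole 0.4 and a tokio runtime (illustrative only, not part of the diff):

#[tokio::main]
async fn main() {
    let (tx, rx) = loole::bounded::<u64>(2);
    // Clone the receiver for each worker: loole is MPMC, so all clones
    // drain the same queue (tokio::sync::mpsc would not allow this).
    let workers: Vec<_> = (0..2)
        .map(|worker| {
            let rx = rx.clone();
            tokio::spawn(async move {
                while let Ok(morsel) = rx.recv_async().await {
                    println!("worker {worker} got morsel {morsel}");
                }
            })
        })
        .collect();
    for morsel in 0..4 {
        tx.send_async(morsel).await.unwrap();
    }
    drop(tx); // closing the last sender ends the workers' recv loops
    for w in workers {
        w.await.unwrap();
    }
}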
-    pub(crate) fn get_next_sender_with_stats(
-        &mut self,
-        rt: &Arc<RuntimeStatsContext>,
-    ) -> CountingSender {
-        CountingSender::new(self.get_next_sender(), rt.clone())
+#[derive(Clone)]
+pub(crate) struct Receiver<T>(loole::Receiver<T>)
+where
+    T: Clone;
+impl<T: Clone> Receiver<T> {
+    pub(crate) async fn recv(&self) -> Option<T> {
+        self.0.recv_async().await.ok()
     }
 
-    pub fn get_receiver(self) -> PipelineReceiver {
-        self.receiver
+    pub(crate) fn blocking_recv(&self) -> Option<T> {
+        self.0.recv().ok()
     }
 
-    pub(crate) fn get_receiver_with_stats(self, rt: &Arc<RuntimeStatsContext>) -> CountingReceiver {
-        CountingReceiver::new(self.get_receiver(), rt.clone())
+    pub(crate) fn into_stream(self) -> impl futures::Stream<Item = T> {
+        self.0.into_stream()
     }
 }
 
-pub enum PipelineSender {
-    InOrder(RoundRobinSender<PipelineResultType>),
-    OutOfOrder(Sender<PipelineResultType>),
+impl Receiver<Arc<MicroPartition>> {
+    pub(crate) fn into_counting_receiver(self, rt: Arc<RuntimeStatsContext>) -> CountingReceiver {
+        CountingReceiver::new(self, rt)
+    }
 }
 
-pub struct RoundRobinSender<T> {
-    senders: Vec<Sender<T>>,
-    curr_sender_idx: usize,
+pub(crate) fn create_channel<T: Clone>(buffer_size: usize) -> (Sender<T>, Receiver<T>) {
+    let (tx, rx) = loole::bounded(buffer_size);
+    (Sender(tx), Receiver(rx))
 }
 
-impl<T> RoundRobinSender<T> {
-    pub fn new(senders: Vec<Sender<T>>) -> Self {
-        Self {
-            senders,
-            curr_sender_idx: 0,
+pub(crate) fn create_ordering_aware_receiver_channel<T: Clone>(
+    ordered: bool,
+    buffer_size: usize,
+) -> (Vec<Sender<T>>, OrderingAwareReceiver<T>) {
+    match ordered {
+        true => {
+            let (senders, receiver) = (0..buffer_size).map(|_| create_channel::<T>(1)).unzip();
+            (
+                senders,
+                OrderingAwareReceiver::InOrder(RoundRobinReceiver::new(receiver)),
+            )
+        }
+        false => {
+            let (sender, receiver) = create_channel::<T>(buffer_size);
+            (
+                (0..buffer_size).map(|_| sender.clone()).collect(),
+                OrderingAwareReceiver::OutOfOrder(receiver),
+            )
         }
-    }
-
-    pub fn get_next_sender(&mut self) -> Sender<T> {
-        let next_idx = self.curr_sender_idx;
-        self.curr_sender_idx = (next_idx + 1) % self.senders.len();
-        self.senders[next_idx].clone()
     }
 }
 
-pub enum PipelineReceiver {
-    InOrder(RoundRobinReceiver<PipelineResultType>),
-    OutOfOrder(Receiver<PipelineResultType>),
+pub(crate) enum OrderingAwareReceiver<T: Clone> {
+    InOrder(RoundRobinReceiver<T>),
+    OutOfOrder(Receiver<T>),
 }
 
-impl PipelineReceiver {
-    pub async fn recv(&mut self) -> Option<PipelineResultType> {
+impl<T: Clone> OrderingAwareReceiver<T> {
+    pub(crate) async fn recv(&mut self) -> Option<T> {
         match self {
             Self::InOrder(rr) => rr.recv().await,
             Self::OutOfOrder(r) => r.recv().await,
@@ -94,14 +86,14 @@ impl PipelineReceiver {
     }
 }
 
-pub struct RoundRobinReceiver<T> {
+pub(crate) struct RoundRobinReceiver<T: Clone> {
     receivers: Vec<Receiver<T>>,
     curr_receiver_idx: usize,
     is_done: bool,
 }
 
-impl<T> RoundRobinReceiver<T> {
-    pub fn new(receivers: Vec<Receiver<T>>) -> Self {
+impl<T: Clone> RoundRobinReceiver<T> {
+    fn new(receivers: Vec<Receiver<T>>) -> Self {
         Self {
             receivers,
             curr_receiver_idx: 0,
@@ -109,7 +101,7 @@ impl<T> RoundRobinReceiver<T> {
         }
     }
 
-    pub async fn recv(&mut self) -> Option<T> {
+    async fn recv(&mut self) -> Option<T> {
         if self.is_done {
             return None;
         }
diff --git a/src/daft-local-execution/src/dispatcher.rs b/src/daft-local-execution/src/dispatcher.rs
index d21fc306b6..79184ce1ba 100644
--- a/src/daft-local-execution/src/dispatcher.rs
+++ b/src/daft-local-execution/src/dispatcher.rs
@@ -1,70 +1,149 @@
 use std::sync::Arc;
 
-use async_trait::async_trait;
 use common_error::DaftResult;
 use daft_dsl::ExprRef;
+use daft_micropartition::MicroPartition;
 
 use crate::{
-    buffer::RowBasedBuffer, channel::Sender, pipeline::PipelineResultType,
+    buffer::RowBasedBuffer,
+    channel::{create_channel, Receiver, Sender},
     runtime_stats::CountingReceiver,
+    ExecutionRuntimeHandle,
 };
 
-#[async_trait]
 pub(crate) trait Dispatcher {
-    async fn dispatch(
+    fn dispatch(
         &self,
-        receiver: CountingReceiver,
-        worker_senders: Vec<Sender<PipelineResultType>>,
-    ) -> DaftResult<()>;
+        input_receivers: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut ExecutionRuntimeHandle,
+        name: &'static str,
+    ) -> Vec<Receiver<Arc<MicroPartition>>>;
 }
 
-pub(crate) struct RoundRobinBufferedDispatcher {
-    morsel_size: usize,
+pub(crate) struct RoundRobinDispatcher {
+    morsel_size: Option<usize>,
 }
 
-impl RoundRobinBufferedDispatcher {
-    pub(crate) fn new(morsel_size: usize) -> Self {
+impl RoundRobinDispatcher {
+    pub(crate) fn new(morsel_size: Option<usize>) -> Self {
         Self { morsel_size }
     }
-}
 
-#[async_trait]
-impl Dispatcher for RoundRobinBufferedDispatcher {
-    async fn dispatch(
-        &self,
-        mut receiver: CountingReceiver,
-        worker_senders: Vec<Sender<PipelineResultType>>,
+    async fn dispatch_inner(
+        worker_senders: Vec<Sender<Arc<MicroPartition>>>,
+        input_receivers: Vec<CountingReceiver>,
+        morsel_size: Option<usize>,
     ) -> DaftResult<()> {
         let mut next_worker_idx = 0;
-        let mut send_to_next_worker = |data: PipelineResultType| {
+        let mut send_to_next_worker = |data: Arc<MicroPartition>| {
             let next_worker_sender = worker_senders.get(next_worker_idx).unwrap();
             next_worker_idx = (next_worker_idx + 1) % worker_senders.len();
             next_worker_sender.send(data)
         };
 
-        let mut buffer = RowBasedBuffer::new(self.morsel_size);
-        while let Some(morsel) = receiver.recv().await {
-            if morsel.should_broadcast() {
-                for worker_sender in &worker_senders {
-                    let _ = worker_sender.send(morsel.clone()).await;
-                }
-            } else {
-                buffer.push(morsel.as_data());
-                if let Some(ready) = buffer.pop_enough()? {
-                    for r in ready {
-                        let _ = send_to_next_worker(r.into()).await;
+        for receiver in input_receivers {
+            let mut buffer = morsel_size.map(RowBasedBuffer::new);
+            while let Some(morsel) = receiver.recv().await {
+                if let Some(buffer) = &mut buffer {
+                    buffer.push(&morsel);
+                    if let Some(ready) = buffer.pop_enough()? {
+                        for r in ready {
+                            let _ = send_to_next_worker(r).await;
+                        }
                     }
+                } else {
+                    let _ = send_to_next_worker(morsel).await;
+                }
+            }
+            // Clear all remaining morsels
+            if let Some(buffer) = &mut buffer {
+                if let Some(last_morsel) = buffer.pop_all()? {
+                    let _ = send_to_next_worker(last_morsel).await;
                 }
             }
         }
-        // Clear all remaining morsels
-        if let Some(last_morsel) = buffer.pop_all()? {
-            let _ = send_to_next_worker(last_morsel.into()).await;
+        Ok(())
+    }
+}
+
+impl Dispatcher for RoundRobinDispatcher {
+    fn dispatch(
+        &self,
+        input_receivers: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut ExecutionRuntimeHandle,
+        name: &'static str,
+    ) -> Vec<Receiver<Arc<MicroPartition>>> {
+        let (worker_senders, worker_receivers): (Vec<_>, Vec<_>) =
+            (0..num_workers).map(|_| create_channel(1)).unzip();
+        let morsel_size = self.morsel_size;
+        runtime_handle.spawn(
+            Self::dispatch_inner(worker_senders, input_receivers, morsel_size),
+            name,
+        );
+        worker_receivers
+    }
+}
+
+pub(crate) struct UnorderedDispatcher {
+    morsel_size: Option<usize>,
+}
+
+impl UnorderedDispatcher {
+    pub(crate) fn new(morsel_size: Option<usize>) -> Self {
+        Self { morsel_size }
+    }
+
+    async fn dispatch_inner(
+        worker_sender: Sender<Arc<MicroPartition>>,
+        input_receivers: Vec<CountingReceiver>,
+        morsel_size: Option<usize>,
+    ) -> DaftResult<()> {
+        for receiver in input_receivers {
+            let mut buffer = morsel_size.map(RowBasedBuffer::new);
+            while let Some(morsel) = receiver.recv().await {
+                if let Some(buffer) = &mut buffer {
+                    buffer.push(&morsel);
+                    if let Some(ready) = buffer.pop_enough()? {
+                        for r in ready {
+                            let _ = worker_sender.send(r).await;
+                        }
+                    }
+                } else {
+                    let _ = worker_sender.send(morsel).await;
+                }
+            }
+            // Clear all remaining morsels
+            if let Some(buffer) = &mut buffer {
+                if let Some(last_morsel) = buffer.pop_all()? {
+                    let _ = worker_sender.send(last_morsel).await;
+                }
+            }
         }
         Ok(())
     }
 }
 
+impl Dispatcher for UnorderedDispatcher {
+    fn dispatch(
+        &self,
+        receiver: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut ExecutionRuntimeHandle,
+        name: &'static str,
+    ) -> Vec<Receiver<Arc<MicroPartition>>> {
+        let (worker_sender, worker_receiver) = create_channel(num_workers);
+        let worker_receivers = vec![worker_receiver; num_workers];
+        let morsel_size = self.morsel_size;
+        runtime_handle.spawn(
+            Self::dispatch_inner(worker_sender, receiver, morsel_size),
+            name,
+        );
+        worker_receivers
+    }
+}
+
 pub(crate) struct PartitionedDispatcher {
     partition_by: Vec<ExprRef>,
 }
 
@@ -73,30 +152,40 @@ impl PartitionedDispatcher {
     pub(crate) fn new(partition_by: Vec<ExprRef>) -> Self {
         Self { partition_by }
     }
-}
 
-#[async_trait]
-impl Dispatcher for PartitionedDispatcher {
-    async fn dispatch(
-        &self,
-        mut receiver: CountingReceiver,
-        worker_senders: Vec<Sender<PipelineResultType>>,
+    async fn dispatch_inner(
+        worker_senders: Vec<Sender<Arc<MicroPartition>>>,
+        input_receivers: Vec<CountingReceiver>,
+        partition_by: Vec<ExprRef>,
     ) -> DaftResult<()> {
-        while let Some(morsel) = receiver.recv().await {
-            if morsel.should_broadcast() {
-                for worker_sender in &worker_senders {
-                    let _ = worker_sender.send(morsel.clone()).await;
-                }
-            } else {
-                let partitions = morsel
-                    .as_data()
-                    .partition_by_hash(&self.partition_by, worker_senders.len())?;
+        for receiver in input_receivers {
+            while let Some(morsel) = receiver.recv().await {
+                let partitions = morsel.partition_by_hash(&partition_by, worker_senders.len())?;
                 for (partition, worker_sender) in
                     partitions.into_iter().zip(worker_senders.iter())
                 {
-                    let _ = worker_sender.send(Arc::new(partition).into()).await;
+                    let _ = worker_sender.send(Arc::new(partition)).await;
                 }
             }
         }
         Ok(())
     }
 }
+
+impl Dispatcher for PartitionedDispatcher {
+    fn dispatch(
+        &self,
+        input_receivers: Vec<CountingReceiver>,
+        num_workers: usize,
+        runtime_handle: &mut ExecutionRuntimeHandle,
+        name: &'static str,
+    ) -> Vec<Receiver<Arc<MicroPartition>>> {
+        let (worker_senders, worker_receivers): (Vec<_>, Vec<_>) =
+            (0..num_workers).map(|_| create_channel(1)).unzip();
+        let partition_by = self.partition_by.clone();
+        runtime_handle.spawn(
+            Self::dispatch_inner(worker_senders, input_receivers, partition_by),
+            name,
+        );
+        worker_receivers
+    }
+}
diff --git a/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs b/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
index 58a785e62b..e2b5031dee 100644
--- a/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
+++ b/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;
 
 use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 #[cfg(feature = "python")]
 use daft_dsl::python::PyExpr;
 use daft_dsl::{functions::python::extract_stateful_udf_exprs, ExprRef};
@@ -12,10 +13,10 @@ use pyo3::prelude::*;
 use tracing::instrument;
 
 use super::intermediate_op::{
-    DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult,
-    IntermediateOperatorState,
+    IntermediateOpState, IntermediateOperator, IntermediateOperatorResult,
+    IntermediateOperatorResultType,
 };
-use crate::pipeline::PipelineResultType;
+use crate::dispatcher::{RoundRobinDispatcher, UnorderedDispatcher};
 
 struct ActorHandle {
     #[cfg(feature = "python")]
@@ -108,7 +109,7 @@ struct ActorPoolProjectState {
     pub actor_handle: ActorHandle,
 }
 
-impl DynIntermediateOpState for ActorPoolProjectState {
+impl IntermediateOpState for ActorPoolProjectState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
@@ -145,23 +146,31 @@ impl IntermediateOperator for ActorPoolProjectOperator {
     #[instrument(skip_all, name = "ActorPoolProjectOperator::execute")]
     fn execute(
         &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        state.with_state_mut::<ActorPoolProjectState, _, _>(|state| {
-            state
-                .actor_handle
-                .eval_input(input.as_data().clone())
-                .map(|result| IntermediateOperatorResult::NeedMoreInput(Some(result)))
-        })
+        input: &Arc<MicroPartition>,
+        mut state: Box<dyn IntermediateOpState>,
+        runtime_ref: &RuntimeRef,
+    ) -> IntermediateOperatorResult {
+        let input = input.clone();
+        runtime_ref
+            .spawn(async move {
+                let app_state = state
+                    .as_any_mut()
+                    .downcast_mut::<ActorPoolProjectState>()
+                    .expect("ActorPoolProjectState should be used with ActorPoolProjectOperator");
+                let res = app_state.actor_handle.eval_input(input.clone())?;
+                Ok((
+                    state,
+                    IntermediateOperatorResultType::NeedMoreInput(Some(res)),
+                ))
+            })
+            .into()
     }
 
     fn name(&self) -> &'static str {
         "ActorPoolProject"
     }
 
-    fn make_state(&self) -> DaftResult<Box<dyn DynIntermediateOpState>> {
+    fn make_state(&self) -> DaftResult<Box<dyn IntermediateOpState>> {
         // TODO: Pass relevant CUDA_VISIBLE_DEVICES to the actor
         Ok(Box::new(ActorPoolProjectState {
             actor_handle: ActorHandle::try_new(&self.projection)?,
@@ -172,7 +181,21 @@ impl IntermediateOperator for ActorPoolProjectOperator {
         self.concurrency
     }
 
-    fn morsel_size(&self) -> Option<usize> {
-        self.batch_size
+    fn make_dispatcher(
+        &self,
+        runtime_handle: &crate::ExecutionRuntimeHandle,
+        maintain_order: bool,
+    ) -> Arc<dyn crate::dispatcher::Dispatcher> {
+        if maintain_order {
+            Arc::new(RoundRobinDispatcher::new(Some(
+                self.batch_size
+                    .unwrap_or_else(|| runtime_handle.default_morsel_size()),
+            )))
+        } else {
+            Arc::new(UnorderedDispatcher::new(Some(
+                self.batch_size
+                    .unwrap_or_else(|| runtime_handle.default_morsel_size()),
+            )))
+        }
     }
 }
diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs
deleted file mode 100644
index 4b8fa7bbb6..0000000000
--- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs
+++ /dev/null
@@ -1,43 +0,0 @@
-use std::sync::Arc;
-
-use common_error::DaftResult;
-use daft_dsl::ExprRef;
-use tracing::instrument;
-
-use super::intermediate_op::{
-    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
-};
-use crate::pipeline::PipelineResultType;
-
-pub struct AggregateOperator {
-    agg_exprs: Vec<ExprRef>,
-    group_by: Vec<ExprRef>,
-}
-
-impl AggregateOperator {
-    pub fn new(agg_exprs: Vec<ExprRef>, group_by: Vec<ExprRef>) -> Self {
-        Self {
-            agg_exprs,
-            group_by,
-        }
-    }
-}
-
-impl IntermediateOperator for AggregateOperator {
-    #[instrument(skip_all, name = "AggregateOperator::execute")]
-    fn execute(
-        &self,
-        _idx: usize,
-        input: &PipelineResultType,
-        _state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult> {
-        let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?;
-        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
-            out,
-        ))))
-    }
-
-    fn name(&self) -> &'static str {
-        "AggregateOperator"
-    }
-}
diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs
index 9338fd2dc8..0360f0dab5 100644
--- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs
+++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;
 
 use common_error::DaftResult;
+use common_runtime::RuntimeRef;
 use daft_core::prelude::SchemaRef;
 use daft_dsl::ExprRef;
 use daft_micropartition::MicroPartition;
@@ -9,64 +10,66 @@ use daft_table::{GrowableTable, Probeable};
 use tracing::{info_span,
instrument}; use super::intermediate_op::{ - DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult, - IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; +use crate::sinks::hash_join_build::ProbeStateBridgeRef; enum AntiSemiProbeState { - Building, - ReadyToProbe(Arc), + Building(ProbeStateBridgeRef), + Probing(Arc), } impl AntiSemiProbeState { - fn set_table(&mut self, table: &Arc) { - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe(table.clone()); - } else { - panic!("AntiSemiProbeState should only be in Building state when setting table") - } - } - - fn get_probeable(&self) -> &Arc { - if let Self::ReadyToProbe(probeable) = self { - probeable - } else { - panic!("AntiSemiProbeState should only be in ReadyToProbe state when getting probeable") + async fn get_or_await_probeable(&mut self) -> Arc { + match self { + Self::Building(bridge) => { + let probe_state = bridge.get_probe_state().await; + let probeable = probe_state.get_probeable(); + *self = Self::Probing(probeable.clone()); + probeable.clone() + } + Self::Probing(probeable) => probeable.clone(), } } } -impl DynIntermediateOpState for AntiSemiProbeState { +impl IntermediateOpState for AntiSemiProbeState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } } -pub struct AntiSemiProbeOperator { +pub(crate) struct AntiSemiProbeOperator { probe_on: Vec, is_semi: bool, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl AntiSemiProbeOperator { const DEFAULT_GROWABLE_SIZE: usize = 20; - pub fn new(probe_on: Vec, join_type: &JoinType, output_schema: &SchemaRef) -> Self { + pub fn new( + probe_on: Vec, + join_type: &JoinType, + output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, + ) -> Self { Self { probe_on, is_semi: *join_type == JoinType::Semi, output_schema: output_schema.clone(), + probe_state_bridge, } } fn probe_anti_semi( - &self, + probe_on: &[ExprRef], + probe_set: &Arc, input: &Arc, - state: &AntiSemiProbeState, + is_semi: bool, ) -> DaftResult> { - let probe_set = state.get_probeable(); - let _growables = info_span!("AntiSemiOperator::build_growables").entered(); let input_tables = input.get_tables()?; @@ -81,11 +84,11 @@ impl AntiSemiProbeOperator { { let _loop = info_span!("AntiSemiOperator::eval_and_probe").entered(); for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; + let join_keys = table.eval_expression_list(probe_on)?; let iter = probe_set.probe_exists(&join_keys)?; for (probe_row_idx, matched) in iter.enumerate() { - match (self.is_semi, matched) { + match (is_semi, matched) { (true, true) | (false, false) => { probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } @@ -107,32 +110,44 @@ impl IntermediateOperator for AntiSemiProbeOperator { #[instrument(skip_all, name = "AntiSemiOperator::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, - state: &IntermediateOperatorState, - ) -> DaftResult { - state.with_state_mut::(|state| { - if idx == 0 { - let probe_state = input.as_probe_state(); - state.set_table(probe_state.get_probeable()); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } else { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return 
Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); - } - let out = self.probe_anti_semi(input, state)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } - }) + input: &Arc, + mut state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(empty)), + )) + .into(); + } + let probe_on = self.probe_on.clone(); + let is_semi = self.is_semi; + let input = input.clone(); + runtime_ref + .spawn(async move { + let probe_state = state + .as_any_mut() + .downcast_mut::() + .expect("AntiSemiProbeState should be used with AntiSemiProbeOperator"); + let probeable = probe_state.get_or_await_probeable().await; + let res = Self::probe_anti_semi(&probe_on, &probeable, &input, is_semi); + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(res?)), + )) + }) + .into() } fn name(&self) -> &'static str { "AntiSemiProbeOperator" } - fn make_state(&self) -> DaftResult> { - Ok(Box::new(AntiSemiProbeState::Building)) + fn make_state(&self) -> DaftResult> { + Ok(Box::new(AntiSemiProbeState::Building( + self.probe_state_bridge.clone(), + ))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/explode.rs b/src/daft-local-execution/src/intermediate_ops/explode.rs index 30ec8b02b5..3cd2d991b8 100644 --- a/src/daft-local-execution/src/intermediate_ops/explode.rs +++ b/src/daft-local-execution/src/intermediate_ops/explode.rs @@ -1,21 +1,22 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_dsl::ExprRef; use daft_functions::list::explode; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; -pub struct ExplodeOperator { +pub(crate) struct ExplodeOperator { to_explode: Vec, } impl ExplodeOperator { - pub fn new(to_explode: Vec) -> Self { + pub(crate) fn new(to_explode: Vec) -> Self { Self { to_explode: to_explode.into_iter().map(explode).collect(), } @@ -26,14 +27,21 @@ impl IntermediateOperator for ExplodeOperator { #[instrument(skip_all, name = "ExplodeOperator::execute")] fn execute( &self, - _idx: usize, - input: &PipelineResultType, - _state: &IntermediateOperatorState, - ) -> DaftResult { - let out = input.as_data().explode(&self.to_explode)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, - )))) + input: &Arc, + state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + let input = input.clone(); + let to_explode = self.to_explode.clone(); + runtime_ref + .spawn(async move { + let out = input.explode(&to_explode)?; + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(Arc::new(out))), + )) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index aad3bd7e7d..9e24678076 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use 
tracing::instrument; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; pub struct FilterOperator { predicate: ExprRef, @@ -23,14 +24,21 @@ impl IntermediateOperator for FilterOperator { #[instrument(skip_all, name = "FilterOperator::execute")] fn execute( &self, - _idx: usize, - input: &PipelineResultType, - _state: &IntermediateOperatorState, - ) -> DaftResult { - let out = input.as_data().filter(&[self.predicate.clone()])?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, - )))) + input: &Arc, + state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + let input = input.clone(); + let predicate = self.predicate.clone(); + runtime_ref + .spawn(async move { + let filtered = input.filter(&[predicate.clone()])?; + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(Arc::new(filtered))), + )) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index ce0da48089..b1373adee7 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; @@ -9,35 +10,30 @@ use indexmap::IndexSet; use tracing::{info_span, instrument}; use super::intermediate_op::{ - DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult, - IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; +use crate::{sinks::hash_join_build::ProbeStateBridgeRef, MaybeFuture}; enum InnerHashJoinProbeState { - Building, - ReadyToProbe(Arc), + Building(ProbeStateBridgeRef), + Probing(Arc), } impl InnerHashJoinProbeState { - fn set_probe_state(&mut self, probe_state: Arc) { - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe(probe_state); - } else { - panic!("InnerHashJoinProbeState should only be in Building state when setting table") - } - } - - fn get_probe_state(&self) -> &Arc { - if let Self::ReadyToProbe(probe_state) = self { - probe_state - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + async fn get_or_await_probe_state(&mut self) -> Arc { + match self { + Self::Building(bridge) => { + let probe_state = bridge.get_probe_state().await; + *self = Self::Probing(probe_state.clone()); + probe_state + } + Self::Probing(probeable) => probeable.clone(), } } } -impl DynIntermediateOpState for InnerHashJoinProbeState { +impl IntermediateOpState for InnerHashJoinProbeState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -50,6 +46,7 @@ pub struct InnerHashJoinProbeOperator { right_non_join_columns: Vec, build_on_left: bool, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl InnerHashJoinProbeOperator { @@ -62,6 +59,7 @@ impl InnerHashJoinProbeOperator { build_on_left: bool, common_join_keys: IndexSet, output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, ) -> Self { let left_non_join_columns 
= left_schema .fields @@ -83,18 +81,21 @@ impl InnerHashJoinProbeOperator { right_non_join_columns, build_on_left, output_schema: output_schema.clone(), + probe_state_bridge, } } fn probe_inner( - &self, input: &Arc, - state: &InnerHashJoinProbeState, + probe_state: &Arc, + probe_on: &[ExprRef], + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_columns: &[String], + build_on_left: bool, ) -> DaftResult> { - let (probe_table, tables) = { - let probe_state = state.get_probe_state(); - (probe_state.get_probeable(), probe_state.get_tables()) - }; + let probe_table = probe_state.get_probeable(); + let tables = probe_state.get_tables(); let _growables = info_span!("InnerHashJoinOperator::build_growables").entered(); @@ -117,7 +118,7 @@ impl InnerHashJoinProbeOperator { let _loop = info_span!("InnerHashJoinOperator::eval_and_probe").entered(); for (probe_side_table_idx, table) in input_tables.iter().enumerate() { // we should emit one table at a time when this is streaming - let join_keys = table.eval_expression_list(&self.probe_on)?; + let join_keys = table.eval_expression_list(probe_on)?; let idx_mapper = probe_table.probe_indices(&join_keys)?; for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { @@ -138,15 +139,15 @@ impl InnerHashJoinProbeOperator { let build_side_table = build_side_growable.build()?; let probe_side_table = probe_side_growable.build()?; - let (left_table, right_table) = if self.build_on_left { + let (left_table, right_table) = if build_on_left { (build_side_table, probe_side_table) } else { (probe_side_table, build_side_table) }; - let join_keys_table = left_table.get_columns(&self.common_join_keys)?; - let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; - let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; + let join_keys_table = left_table.get_columns(common_join_keys)?; + let left_non_join_columns = left_table.get_columns(left_non_join_columns)?; + let right_non_join_columns = right_table.get_columns(right_non_join_columns)?; let final_table = join_keys_table .union(&left_non_join_columns)? 
.union(&right_non_join_columns)?; @@ -163,33 +164,54 @@ impl IntermediateOperator for InnerHashJoinProbeOperator { #[instrument(skip_all, name = "InnerHashJoinOperator::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, - state: &IntermediateOperatorState, - ) -> DaftResult { - state.with_state_mut::(|state| match idx { - 0 => { - let probe_state = input.as_probe_state(); - state.set_probe_state(probe_state.clone()); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } - _ => { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); - } - let out = self.probe_inner(input, state)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } - }) + input: &Arc, + mut state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return MaybeFuture::Immediate(Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(empty)), + ))); + } + + let input = input.clone(); + let probe_on = self.probe_on.clone(); + let common_join_keys = self.common_join_keys.clone(); + let left_non_join_columns = self.left_non_join_columns.clone(); + let right_non_join_columns = self.right_non_join_columns.clone(); + let build_on_left = self.build_on_left; + let fut = runtime_ref.spawn(async move { + let inner_join_state = state + .as_any_mut() + .downcast_mut::() + .expect("InnerHashJoinProbeState should be used with InnerHashJoinProbeOperator"); + let probe_state = inner_join_state.get_or_await_probe_state().await; + let res = Self::probe_inner( + &input, + &probe_state, + &probe_on, + &common_join_keys, + &left_non_join_columns, + &right_non_join_columns, + build_on_left, + ); + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(res?)), + )) + }); + MaybeFuture::Future(fut) } fn name(&self) -> &'static str { "InnerHashJoinProbeOperator" } - fn make_state(&self) -> DaftResult> { - Ok(Box::new(InnerHashJoinProbeState::Building)) + fn make_state(&self) -> DaftResult> { + Ok(Box::new(InnerHashJoinProbeState::Building( + self.probe_state_bridge.clone(), + ))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index d268120c06..e89baddea3 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -1,69 +1,48 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; -use common_runtime::get_compute_runtime; +use common_runtime::{get_compute_runtime, RuntimeRef}; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; use crate::{ - buffer::RowBasedBuffer, - channel::{create_channel, PipelineChannel, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, - ExecutionRuntimeHandle, NUM_CPUS, + channel::{create_channel, create_ordering_aware_receiver_channel, Receiver, Sender}, + dispatcher::{Dispatcher, RoundRobinDispatcher, UnorderedDispatcher}, + pipeline::PipelineNode, + runtime_stats::RuntimeStatsContext, + ExecutionRuntimeHandle, MaybeFuture, NUM_CPUS, }; -pub(crate) trait DynIntermediateOpState: Send + Sync { +pub(crate) trait 
IntermediateOpState: Send + Sync {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
 }
 
 struct DefaultIntermediateOperatorState {}
-impl DynIntermediateOpState for DefaultIntermediateOperatorState {
+impl IntermediateOpState for DefaultIntermediateOperatorState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
 }
 
-pub(crate) struct IntermediateOperatorState {
-    inner: Mutex<Box<dyn DynIntermediateOpState>>,
-}
-
-impl IntermediateOperatorState {
-    fn new(inner: Box<dyn DynIntermediateOpState>) -> Arc<Self> {
-        Arc::new(Self {
-            inner: Mutex::new(inner),
-        })
-    }
-
-    pub(crate) fn with_state_mut<T: DynIntermediateOpState + 'static, F, R>(&self, f: F) -> R
-    where
-        F: FnOnce(&mut T) -> R,
-    {
-        let mut guard = self.inner.lock().unwrap();
-        let state = guard
-            .as_any_mut()
-            .downcast_mut::<T>()
-            .expect("State type mismatch");
-        f(state)
-    }
-}
-
-pub enum IntermediateOperatorResult {
+pub(crate) enum IntermediateOperatorResultType {
     NeedMoreInput(Option<Arc<MicroPartition>>),
     #[allow(dead_code)]
     HasMoreOutput(Arc<MicroPartition>),
 }
 
+pub(crate) type IntermediateOperatorResult =
+    MaybeFuture<DaftResult<(Box<dyn IntermediateOpState>, IntermediateOperatorResultType)>>;
+
 pub trait IntermediateOperator: Send + Sync {
     fn execute(
         &self,
-        idx: usize,
-        input: &PipelineResultType,
-        state: &IntermediateOperatorState,
-    ) -> DaftResult<IntermediateOperatorResult>;
+        input: &Arc<MicroPartition>,
+        state: Box<dyn IntermediateOpState>,
+        runtime_ref: &RuntimeRef,
+    ) -> IntermediateOperatorResult;
     fn name(&self) -> &'static str;
-    fn make_state(&self) -> DaftResult<Box<dyn DynIntermediateOpState>> {
+    fn make_state(&self) -> DaftResult<Box<dyn IntermediateOpState>> {
         Ok(Box::new(DefaultIntermediateOperatorState {}))
     }
     /// The maximum number of concurrent workers that can be spawned for this operator.
@@ -72,9 +51,20 @@ pub trait IntermediateOperator: Send + Sync {
         *NUM_CPUS
     }
 
-    /// The input morsel size expected by this operator. If None, use the default size.
-    fn morsel_size(&self) -> Option<usize> {
-        None
+    fn make_dispatcher(
+        &self,
+        runtime_handle: &ExecutionRuntimeHandle,
+        maintain_order: bool,
+    ) -> Arc<dyn Dispatcher> {
+        if maintain_order {
+            Arc::new(RoundRobinDispatcher::new(Some(
+                runtime_handle.default_morsel_size(),
+            )))
+        } else {
+            Arc::new(UnorderedDispatcher::new(Some(
+                runtime_handle.default_morsel_size(),
+            )))
+        }
     }
 }
 
@@ -112,34 +102,30 @@ impl IntermediateNode {
     #[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")]
     pub async fn run_worker(
         op: Arc<dyn IntermediateOperator>,
-        mut receiver: Receiver<(usize, PipelineResultType)>,
-        sender: CountingSender,
+        receiver: Receiver<Arc<MicroPartition>>,
+        sender: Sender<Arc<MicroPartition>>,
         rt_context: Arc<RuntimeStatsContext>,
     ) -> DaftResult<()> {
         let span = info_span!("IntermediateOp::execute");
         let compute_runtime = get_compute_runtime();
-        let state_wrapper = IntermediateOperatorState::new(op.make_state()?);
-        while let Some((idx, morsel)) = receiver.recv().await {
+        let mut state = op.make_state()?;
+        while let Some(morsel) = receiver.recv().await {
             loop {
-                let op = op.clone();
-                let morsel = morsel.clone();
-                let span = span.clone();
-                let rt_context = rt_context.clone();
-                let state_wrapper = state_wrapper.clone();
-                let fut = async move {
-                    rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper))
-                };
-                let result = compute_runtime.spawn(fut).await??;
-                match result {
-                    IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
-                        let _ = sender.send(mp.into()).await;
+                let result = rt_context
+                    .in_span(&span, || op.execute(&morsel, state, &compute_runtime))
+                    .output()
+                    .await??;
+                state = result.0;
+                match result.1 {
+                    IntermediateOperatorResultType::NeedMoreInput(Some(mp)) => {
+                        let _ = sender.send(mp).await;
                         break;
                     }
-                    IntermediateOperatorResult::NeedMoreInput(None) => {
+                    IntermediateOperatorResultType::NeedMoreInput(None) => {
                         break;
                     }
-
IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = sender.send(mp.into()).await; + IntermediateOperatorResultType::HasMoreOutput(mp) => { + let _ = sender.send(mp).await; } } } @@ -149,62 +135,21 @@ impl IntermediateNode { pub fn spawn_workers( &self, - num_workers: usize, - destination_channel: &mut PipelineChannel, + output_senders: Vec>>, + input_receivers: Vec>>, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> Vec> { - let mut worker_senders = Vec::with_capacity(num_workers); - for _ in 0..num_workers { - let (worker_sender, worker_receiver) = create_channel(1); - let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + ) { + for (receiver, destination_channel) in input_receivers.into_iter().zip(output_senders) { runtime_handle.spawn( Self::run_worker( self.intermediate_op.clone(), - worker_receiver, - destination_sender, + receiver, + destination_channel, self.runtime_stats.clone(), ), self.intermediate_op.name(), ); - worker_senders.push(worker_sender); } - worker_senders - } - - pub async fn send_to_workers( - receivers: Vec, - worker_senders: Vec>, - morsel_size: usize, - ) -> DaftResult<()> { - let mut next_worker_idx = 0; - let mut send_to_next_worker = |idx, data: PipelineResultType| { - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send((idx, data)) - }; - - for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = RowBasedBuffer::new(morsel_size); - while let Some(morsel) = receiver.recv().await { - if morsel.should_broadcast() { - for worker_sender in &worker_senders { - let _ = worker_sender.send((idx, morsel.clone())).await; - } - } else { - buffer.push(morsel.as_data()); - if let Some(ready) = buffer.pop_enough()? { - for part in ready { - let _ = send_to_next_worker(idx, part.into()).await; - } - } - } - } - if let Some(ready) = buffer.pop_all()? { - let _ = send_to_next_worker(idx, ready.into()).await; - } - } - Ok(()) } } @@ -240,32 +185,46 @@ impl PipelineNode for IntermediateNode { } fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result>> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); - for child in &mut self.children { - let child_result_channel = child.start(maintain_order, runtime_handle)?; - child_result_receivers - .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); + for child in &self.children { + let child_result_receiver = child + .start(maintain_order, runtime_handle)? 
+ .into_counting_receiver(self.runtime_stats.clone()); + child_result_receivers.push(child_result_receiver); } - let op = self.intermediate_op.clone(); - let num_workers = op.max_concurrency(); - let mut destination_channel = PipelineChannel::new(num_workers, maintain_order); - let worker_senders = - self.spawn_workers(num_workers, &mut destination_channel, runtime_handle); + let (destination_sender, destination_receiver) = create_channel(1); + let destination_sender = + destination_sender.into_counting_sender(self.runtime_stats.clone()); + + let num_workers = self.intermediate_op.max_concurrency(); + let dispatcher = self + .intermediate_op + .make_dispatcher(runtime_handle, maintain_order); + let input_receivers = dispatcher.dispatch( + child_result_receivers, + num_workers, + runtime_handle, + self.name(), + ); + + let (output_senders, mut output_receiver) = + create_ordering_aware_receiver_channel(maintain_order, num_workers); + self.spawn_workers(output_senders, input_receivers, runtime_handle); runtime_handle.spawn( - Self::send_to_workers( - child_result_receivers, - worker_senders, - op.morsel_size() - .unwrap_or_else(|| runtime_handle.default_morsel_size()), - ), - op.name(), + async move { + while let Some(morsel) = output_receiver.recv().await { + let _ = destination_sender.send(morsel).await; + } + Ok(()) + }, + self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index f512343e96..b0b0d85051 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,5 +1,4 @@ pub mod actor_pool_project; -pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod explode; pub mod filter; diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index 370de989aa..a80286fcab 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; pub struct ProjectOperator { projection: Vec, @@ -23,14 +24,21 @@ impl IntermediateOperator for ProjectOperator { #[instrument(skip_all, name = "ProjectOperator::execute")] fn execute( &self, - _idx: usize, - input: &PipelineResultType, - _state: &IntermediateOperatorState, - ) -> DaftResult { - let out = input.as_data().eval_expression_list(&self.projection)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, - )))) + input: &Arc, + state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + let input = input.clone(); + let projection = self.projection.clone(); + runtime_ref + .spawn(async move { + let out = input.eval_expression_list(&projection)?; + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(Arc::new(out))), + )) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/intermediate_ops/sample.rs 
b/src/daft-local-execution/src/intermediate_ops/sample.rs index b0e4610292..9ce2b53aeb 100644 --- a/src/daft-local-execution/src/intermediate_ops/sample.rs +++ b/src/daft-local-execution/src/intermediate_ops/sample.rs @@ -1,12 +1,13 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; pub struct SampleOperator { fraction: f64, @@ -28,17 +29,23 @@ impl IntermediateOperator for SampleOperator { #[instrument(skip_all, name = "SampleOperator::execute")] fn execute( &self, - _idx: usize, - input: &PipelineResultType, - _state: &IntermediateOperatorState, - ) -> DaftResult { - let out = - input - .as_data() - .sample_by_fraction(self.fraction, self.with_replacement, self.seed)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, - )))) + input: &Arc, + state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + let input = input.clone(); + let fraction = self.fraction; + let with_replacement = self.with_replacement; + let seed = self.seed; + runtime_ref + .spawn(async move { + let out = input.sample_by_fraction(fraction, with_replacement, seed)?; + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(Arc::new(out))), + )) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/intermediate_ops/unpivot.rs b/src/daft-local-execution/src/intermediate_ops/unpivot.rs index 5171f9ad42..f184f14df6 100644 --- a/src/daft-local-execution/src/intermediate_ops/unpivot.rs +++ b/src/daft-local-execution/src/intermediate_ops/unpivot.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorResultType, }; -use crate::pipeline::PipelineResultType; pub struct UnpivotOperator { ids: Vec, @@ -36,19 +37,25 @@ impl IntermediateOperator for UnpivotOperator { #[instrument(skip_all, name = "UnpivotOperator::execute")] fn execute( &self, - _idx: usize, - input: &PipelineResultType, - _state: &IntermediateOperatorState, - ) -> DaftResult { - let out = input.as_data().unpivot( - &self.ids, - &self.values, - &self.variable_name, - &self.value_name, - )?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, - )))) + input: &Arc, + state: Box, + runtime_ref: &RuntimeRef, + ) -> IntermediateOperatorResult { + let input = input.clone(); + let ids = self.ids.clone(); + let values = self.values.clone(); + let variable_name = self.variable_name.clone(); + let value_name = self.value_name.clone(); + + runtime_ref + .spawn(async move { + let out = input.unpivot(&ids, &values, &variable_name, &value_name)?; + Ok(( + state, + IntermediateOperatorResultType::NeedMoreInput(Some(Arc::new(out))), + )) + }) + .into() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 553ad18b40..64913bff5f 100644 --- a/src/daft-local-execution/src/lib.rs +++ 
b/src/daft-local-execution/src/lib.rs @@ -10,6 +10,7 @@ mod sinks; mod sources; use common_error::{DaftError, DaftResult}; +use common_runtime::RuntimeTask; use lazy_static::lazy_static; pub use run::NativeExecutor; use snafu::{futures::TryFutureExt, Snafu}; @@ -18,6 +19,32 @@ lazy_static! { pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); } +pub(crate) enum MaybeFuture { + Immediate(T), + Future(RuntimeTask), +} + +impl MaybeFuture { + pub(crate) async fn output(self) -> DaftResult { + match self { + Self::Immediate(v) => Ok(v), + Self::Future(f) => f.await, + } + } +} + +impl From for MaybeFuture { + fn from(v: T) -> Self { + Self::Immediate(v) + } +} + +impl From> for MaybeFuture { + fn from(f: RuntimeTask) -> Self { + Self::Future(f) + } +} + pub(crate) struct TaskSet { inner: tokio::task::JoinSet, } @@ -45,20 +72,20 @@ impl TaskSet { } } -pub struct ExecutionRuntimeHandle { +pub(crate) struct ExecutionRuntimeHandle { worker_set: TaskSet>, default_morsel_size: usize, } impl ExecutionRuntimeHandle { #[must_use] - pub fn new(default_morsel_size: usize) -> Self { + fn new(default_morsel_size: usize) -> Self { Self { worker_set: TaskSet::new(), default_morsel_size, } } - pub fn spawn( + fn spawn( &mut self, task: impl std::future::Future> + 'static, node_name: &str, @@ -68,16 +95,15 @@ impl ExecutionRuntimeHandle { .spawn(task.with_context(|_| PipelineExecutionSnafu { node_name })); } - pub async fn join_next(&mut self) -> Option, tokio::task::JoinError>> { + async fn join_next(&mut self) -> Option, tokio::task::JoinError>> { self.worker_set.join_next().await } - pub async fn shutdown(&mut self) { + async fn shutdown(&mut self) { self.worker_set.shutdown().await; } - #[must_use] - pub fn default_morsel_size(&self) -> usize { + fn default_morsel_size(&self) -> usize { self.default_morsel_size } } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index f15ab50543..076190b88a 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -9,23 +9,22 @@ use daft_core::{ prelude::{Schema, SchemaRef}, utils::supertype, }; -use daft_dsl::{col, join::get_common_join_keys, Expr}; +use daft_dsl::join::get_common_join_keys; use daft_micropartition::MicroPartition; use daft_physical_plan::{ ActorPoolProject, Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, PhysicalWrite, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; -use daft_plan::{populate_aggregation_stages, JoinType}; -use daft_table::ProbeState; +use daft_plan::JoinType; use daft_writers::make_writer_factory; use indexmap::IndexSet; use snafu::ResultExt; use crate::{ - channel::PipelineChannel, + channel::Receiver, intermediate_ops::{ - actor_pool_project::ActorPoolProjectOperator, aggregate::AggregateOperator, + actor_pool_project::ActorPoolProjectOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, explode::ExplodeOperator, filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, project::ProjectOperator, sample::SampleOperator, @@ -35,7 +34,7 @@ use crate::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, - hash_join_build::HashJoinBuildSink, + hash_join_build::{HashJoinBuildSink, ProbeStateBridge}, limit::LimitSink, outer_hash_join_probe::OuterHashJoinProbeSink, pivot::PivotSink, @@ -47,57 +46,19 @@ use crate::{ ExecutionRuntimeHandle, PipelineCreationSnafu, }; 
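Context for the pipeline.rs changes that follow: operators now return the MaybeFuture defined in lib.rs above, answering cheap cases inline (such as the empty-input early returns in the probe operators) and spawning expensive work on the compute runtime. A simplified, self-contained sketch of the pattern, with tokio::task::JoinHandle standing in for common_runtime::RuntimeTask and String for DaftError (assumptions, not the real types):

use tokio::task::JoinHandle;

// Simplified stand-in for the MaybeFuture in lib.rs: JoinHandle replaces
// common_runtime::RuntimeTask and String replaces DaftError.
enum MaybeFuture<T> {
    Immediate(T),
    Future(JoinHandle<T>),
}

impl<T> MaybeFuture<T> {
    // Callers await output() without knowing whether the value was computed
    // inline or on a spawned task.
    async fn output(self) -> Result<T, String> {
        match self {
            Self::Immediate(v) => Ok(v),
            Self::Future(f) => f.await.map_err(|e| format!("join error: {e}")),
        }
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // Cheap path: no task spawned at all.
    let fast: MaybeFuture<Result<u64, String>> = MaybeFuture::Immediate(Ok(1));
    // Expensive path: the work runs on a spawned task.
    let slow: MaybeFuture<Result<u64, String>> =
        MaybeFuture::Future(tokio::spawn(async { Ok(41) }));
    // The double `?` mirrors the `.output().await??` in run_worker: one for
    // the join/spawn error, one for the operator's own error.
    let total = fast.output().await?? + slow.output().await??;
    assert_eq!(total, 42);
    Ok(())
}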
-#[derive(Clone)] -pub enum PipelineResultType { - Data(Arc), - ProbeState(Arc), -} - -impl From> for PipelineResultType { - fn from(data: Arc) -> Self { - Self::Data(data) - } -} - -impl From> for PipelineResultType { - fn from(probe_state: Arc) -> Self { - Self::ProbeState(probe_state) - } -} - -impl PipelineResultType { - pub fn as_data(&self) -> &Arc { - match self { - Self::Data(data) => data, - _ => panic!("Expected data"), - } - } - - pub fn as_probe_state(&self) -> &Arc { - match self { - Self::ProbeState(probe_state) => probe_state, - _ => panic!("Expected probe table"), - } - } - - pub fn should_broadcast(&self) -> bool { - matches!(self, Self::ProbeState(_)) - } -} - -pub trait PipelineNode: Sync + Send + TreeDisplay { +pub(crate) trait PipelineNode: Sync + Send + TreeDisplay { fn children(&self) -> Vec<&dyn PipelineNode>; fn name(&self) -> &'static str; fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result; + ) -> crate::Result>>; fn as_tree_display(&self) -> &dyn TreeDisplay; } -pub fn viz_pipeline(root: &dyn PipelineNode) -> String { +pub(crate) fn viz_pipeline(root: &dyn PipelineNode) -> String { let mut output = String::new(); let mut visitor = MermaidDisplayVisitor::new( &mut output, @@ -109,7 +70,7 @@ pub fn viz_pipeline(root: &dyn PipelineNode) -> String { output } -pub fn physical_plan_to_pipeline( +pub(crate) fn physical_plan_to_pipeline( physical_plan: &LocalPhysicalPlan, psets: &HashMap>>, cfg: &Arc, @@ -190,34 +151,9 @@ pub fn physical_plan_to_pipeline( schema, .. }) => { - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, schema, &[]); - let first_stage_agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; - let post_first_agg_node = - IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); - let second_stage_node = - BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + let agg_sink = AggregateSink::new(aggregations, &[], schema); + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::HashAggregate(HashAggregate { input, @@ -226,40 +162,9 @@ pub fn physical_plan_to_pipeline( schema, .. 
}) => { - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, schema, group_by); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; - let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { - let agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - ( - IntermediateNode::new(Arc::new(agg_op), vec![child_node]).boxed(), - &group_by.iter().map(|e| col(e.name())).collect(), - ) - } else { - (child_node, group_by) - }; - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - let second_stage_node = - BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + let agg_sink = AggregateSink::new(aggregations, group_by, schema); + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::Unpivot(Unpivot { input, @@ -373,7 +278,13 @@ pub fn physical_plan_to_pipeline( .collect::>(); // we should move to a builder pattern - let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; + let probe_state_bridge = ProbeStateBridge::new(); + let build_sink = HashJoinBuildSink::new( + key_schema, + casted_build_on, + join_type, + probe_state_bridge.clone(), + )?; let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?; let build_node = BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); @@ -386,6 +297,7 @@ pub fn physical_plan_to_pipeline( casted_probe_on, join_type, schema, + probe_state_bridge, )), vec![build_node, probe_child_node], ) @@ -398,6 +310,7 @@ pub fn physical_plan_to_pipeline( build_on_left, common_join_keys, schema, + probe_state_bridge, )), vec![build_node, probe_child_node], ) @@ -411,6 +324,7 @@ pub fn physical_plan_to_pipeline( *join_type, common_join_keys, schema, + probe_state_bridge, )), vec![build_node, probe_child_node], ) diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 3cab41ef36..6732f2aa4a 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -111,14 +111,14 @@ fn should_enable_explain_analyze() -> bool { } } -pub fn run_local( +fn run_local( physical_plan: &LocalPhysicalPlan, psets: HashMap>>, cfg: Arc, results_buffer_size: Option, ) -> DaftResult>> + Send>> { refresh_chrome_trace(); - let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; + let pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_current_thread() @@ -127,10 +127,10 @@ pub fn run_local( .expect("Failed to create tokio runtime"); let execution_task = async { let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); - let mut receiver = pipeline.start(true, &mut runtime_handle)?.get_receiver(); + let receiver = pipeline.start(true, &mut runtime_handle)?; while let Some(val) = receiver.recv().await { - let _ = tx.send(val.as_data().clone()).await; + let _ = tx.send(val).await; } while let Some(result) = runtime_handle.join_next().await { diff --git 
a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index 566d253e9c..b760c4f32c 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -5,25 +5,23 @@ use std::{ time::Instant, }; -use tokio::sync::mpsc::error::SendError; +use daft_micropartition::MicroPartition; +use loole::SendError; -use crate::{ - channel::{PipelineReceiver, Sender}, - pipeline::PipelineResultType, -}; +use crate::channel::{Receiver, Sender}; #[derive(Default)] -pub struct RuntimeStatsContext { +pub(crate) struct RuntimeStatsContext { rows_received: AtomicU64, rows_emitted: AtomicU64, cpu_us: AtomicU64, } #[derive(Debug)] -pub struct RuntimeStats { - pub rows_received: u64, - pub rows_emitted: u64, - pub cpu_us: u64, +pub(crate) struct RuntimeStats { + rows_received: u64, + rows_emitted: u64, + cpu_us: u64, } impl RuntimeStats { @@ -108,51 +106,44 @@ impl RuntimeStatsContext { } } -pub struct CountingSender { - sender: Sender, +pub(crate) struct CountingSender { + sender: Sender>, rt: Arc, } impl CountingSender { - pub(crate) fn new(sender: Sender, rt: Arc) -> Self { + pub(crate) fn new(sender: Sender>, rt: Arc) -> Self { Self { sender, rt } } #[inline] pub(crate) async fn send( &self, - v: PipelineResultType, - ) -> Result<(), SendError> { - let len = match v { - PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeState(ref state) => { - state.get_tables().iter().map(|t| t.len()).sum() - } - }; + v: Arc, + ) -> Result<(), SendError>> { + let len = v.len(); self.sender.send(v).await?; self.rt.mark_rows_emitted(len as u64); Ok(()) } } -pub struct CountingReceiver { - receiver: PipelineReceiver, +pub(crate) struct CountingReceiver { + receiver: Receiver>, rt: Arc, } impl CountingReceiver { - pub(crate) fn new(receiver: PipelineReceiver, rt: Arc) -> Self { + pub(crate) fn new( + receiver: Receiver>, + rt: Arc, + ) -> Self { Self { receiver, rt } } #[inline] - pub(crate) async fn recv(&mut self) -> Option { + pub(crate) async fn recv(&self) -> Option> { let v = self.receiver.recv().await; if let Some(ref v) = v { - let len = match v { - PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeState(state) => { - state.get_tables().iter().map(|t| t.len()).sum() - } - }; + let len = v.len(); self.rt.mark_rows_received(len as u64); } v diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index abc8acce4c..2fb64c27f8 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -1,12 +1,14 @@ use std::sync::Arc; use common_error::DaftResult; -use daft_dsl::ExprRef; +use daft_core::prelude::SchemaRef; +use daft_dsl::{col, AggExpr, Expr, ExprRef}; use daft_micropartition::MicroPartition; +use daft_plan::populate_aggregation_stages; use tracing::instrument; use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use crate::NUM_CPUS; enum AggregateState { Accumulating(Vec>), @@ -40,15 +42,38 @@ impl BlockingSinkState for AggregateState { } pub struct AggregateSink { - agg_exprs: Vec, - group_by: Vec, + sink_aggs: Vec, + finalize_aggs: Vec, + final_projections: Vec, + sink_group_by: Vec, + finalize_group_by: Vec, } impl AggregateSink { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { + pub fn new(agg_exprs: &[AggExpr], group_by: &[ExprRef], schema: &SchemaRef) -> Self { + let (sink_aggs, 
finalize_aggs, final_projections) = + populate_aggregation_stages(agg_exprs, schema, group_by); + let sink_aggs = sink_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let finalize_aggs = finalize_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let finalize_group_by = if sink_aggs.is_empty() { + group_by.to_vec() + } else { + group_by.iter().map(|e| col(e.name())).collect() + }; Self { - agg_exprs, - group_by, + sink_aggs, + finalize_aggs, + final_projections, + sink_group_by: group_by.to_vec(), + finalize_group_by, } } } @@ -60,11 +85,16 @@ impl BlockingSink for AggregateSink { input: &Arc, mut state: Box, ) -> DaftResult { - state + let agg_state = state .as_any_mut() .downcast_mut::() - .expect("AggregateSink should have AggregateState") - .push(input.clone()); + .expect("AggregateSink should have AggregateState"); + if self.sink_aggs.is_empty() { + agg_state.push(input.clone()); + } else { + let agged = input.agg(&self.sink_aggs, &self.sink_group_by)?; + agg_state.push(agged.into()); + } Ok(BlockingSinkStatus::NeedMoreInput(state)) } @@ -72,7 +102,7 @@ impl BlockingSink for AggregateSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { let all_parts = states.into_iter().flat_map(|mut state| { state .as_any_mut() @@ -81,8 +111,9 @@ impl BlockingSink for AggregateSink { .finalize() }); let concated = MicroPartition::concat(all_parts)?; - let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); - Ok(Some(agged.into())) + let agged = concated.agg(&self.finalize_aggs, &self.finalize_group_by)?; + let projected = Arc::new(agged.eval_expression_list(&self.final_projections)?); + Ok(Some(projected)) } fn name(&self) -> &'static str { @@ -96,4 +127,13 @@ impl BlockingSink for AggregateSink { fn make_state(&self) -> DaftResult> { Ok(Box::new(AggregateState::Accumulating(vec![]))) } + + fn make_dispatcher( + &self, + runtime_handle: &crate::ExecutionRuntimeHandle, + ) -> Arc { + Arc::new(crate::dispatcher::UnorderedDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) + } } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 51ef590d09..aa8c646c90 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -8,9 +8,9 @@ use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - channel::{create_channel, PipelineChannel, Receiver}, - dispatcher::{Dispatcher, RoundRobinBufferedDispatcher}, - pipeline::{PipelineNode, PipelineResultType}, + channel::{create_channel, Receiver}, + dispatcher::Dispatcher, + pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, JoinSnafu, TaskSet, }; @@ -33,14 +33,10 @@ pub trait BlockingSink: Send + Sync { fn finalize( &self, states: Vec>, - ) -> DaftResult>; + ) -> DaftResult>>; fn name(&self) -> &'static str; fn make_state(&self) -> DaftResult>; - fn make_dispatcher(&self, runtime_handle: &ExecutionRuntimeHandle) -> Arc { - Arc::new(RoundRobinBufferedDispatcher::new( - runtime_handle.default_morsel_size(), - )) - } + fn make_dispatcher(&self, runtime_handle: &ExecutionRuntimeHandle) -> Arc; fn max_concurrency(&self) -> usize; } @@ -68,7 +64,7 @@ impl BlockingSinkNode { #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] async fn run_worker( op: Arc, - mut input_receiver: Receiver, + input_receiver: Receiver>, rt_context: Arc, ) -> 
DaftResult> { let span = info_span!("BlockingSink::Sink"); @@ -76,10 +72,9 @@ impl BlockingSinkNode { let mut state = op.make_state()?; while let Some(morsel) = input_receiver.recv().await { let op = op.clone(); - let morsel = morsel.clone(); let span = span.clone(); let rt_context = rt_context.clone(); - let fut = async move { rt_context.in_span(&span, || op.sink(morsel.as_data(), state)) }; + let fut = async move { rt_context.in_span(&span, || op.sink(&morsel, state)) }; let result = compute_runtime.spawn(fut).await??; match result { BlockingSinkStatus::NeedMoreInput(new_state) => { @@ -96,7 +91,7 @@ impl BlockingSinkNode { fn spawn_workers( op: Arc, - input_receivers: Vec>, + input_receivers: Vec>>, task_set: &mut TaskSet>>, stats: Arc, ) { @@ -135,29 +130,27 @@ impl PipelineNode for BlockingSinkNode { } fn start( - &mut self, - maintain_order: bool, + &self, + _maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { - let child = self.child.as_mut(); - let child_results_receiver = child + ) -> crate::Result>> { + let child_results_receiver = self + .child .start(false, runtime_handle)? - .get_receiver_with_stats(&self.runtime_stats); + .into_counting_receiver(self.runtime_stats.clone()); - let mut destination_channel = PipelineChannel::new(1, maintain_order); + let (destination_sender, destination_receiver) = create_channel(1); let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + destination_sender.into_counting_sender(self.runtime_stats.clone()); + let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); let dispatcher = op.make_dispatcher(runtime_handle); - runtime_handle.spawn( - async move { - dispatcher - .dispatch(child_results_receiver, input_senders) - .await - }, + let input_receivers = dispatcher.dispatch( + vec![child_results_receiver], + num_workers, + runtime_handle, self.name(), ); @@ -192,7 +185,7 @@ impl PipelineNode for BlockingSinkNode { }, self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 3fc710c691..9fb86c435c 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -1,19 +1,20 @@ use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +use common_runtime::RuntimeRef; use daft_micropartition::MicroPartition; use tracing::instrument; use super::streaming_sink::{ - DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, + StreamingSink, StreamingSinkExecuteOutput, StreamingSinkFinalizeOutput, + StreamingSinkOutputType, StreamingSinkState, +}; +use crate::{ + dispatcher::{Dispatcher, RoundRobinDispatcher, UnorderedDispatcher}, + ExecutionRuntimeHandle, MaybeFuture, NUM_CPUS, }; -use crate::pipeline::PipelineResultType; -struct ConcatSinkState { - // The index of the last morsel of data that was received, which should be strictly non-decreasing. 
-    pub curr_idx: usize,
-}
-impl DynStreamingSinkState for ConcatSinkState {
+struct ConcatSinkState {}
+impl StreamingSinkState for ConcatSinkState {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
@@ -22,27 +23,20 @@
 pub struct ConcatSink {}

 impl StreamingSink for ConcatSink {
-    /// Execute for the ConcatSink operator does not do any computation and simply returns the input data.
-    /// It only expects that the indices of the input data are strictly non-decreasing.
-    /// TODO(Colin): If maintain_order is false, technically we could accept any index. Make this optimization later.
+    /// When the StreamingSink is run with maintain_order = true, input is distributed round-robin to the workers
+    /// and output is collected in the same order, so `execute` does not need to do anything here.
+    /// When maintain_order = false, input is distributed to whichever worker is free, and output is collected in arbitrary order.
     #[instrument(skip_all, name = "ConcatSink::sink")]
     fn execute(
         &self,
-        index: usize,
-        input: &PipelineResultType,
-        state_handle: &StreamingSinkState,
-    ) -> DaftResult<StreamingSinkOutput> {
-        state_handle.with_state_mut::<ConcatSinkState, _, _>(|state| {
-            // If the index is the same as the current index or one more than the current index, then we can accept the morsel.
-            if state.curr_idx == index || state.curr_idx + 1 == index {
-                state.curr_idx = index;
-                Ok(StreamingSinkOutput::NeedMoreInput(Some(
-                    input.as_data().clone(),
-                )))
-            } else {
-                Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index)))
-            }
-        })
+        input: &Arc<MicroPartition>,
+        state: Box<dyn StreamingSinkState>,
+        _runtime_ref: &RuntimeRef,
+    ) -> StreamingSinkExecuteOutput {
+        MaybeFuture::Immediate(Ok((
+            state,
+            StreamingSinkOutputType::NeedMoreInput(Some(input.clone())),
+        )))
     }

     fn name(&self) -> &'static str {
@@ -51,17 +45,33 @@
     fn finalize(
         &self,
-        _states: Vec<Box<dyn DynStreamingSinkState>>,
-    ) -> DaftResult<Option<Arc<MicroPartition>>> {
-        Ok(None)
+        _states: Vec<Box<dyn StreamingSinkState>>,
+        _runtime_ref: &RuntimeRef,
+    ) -> StreamingSinkFinalizeOutput {
+        MaybeFuture::Immediate(Ok(None))
     }

-    fn make_state(&self) -> Box<dyn DynStreamingSinkState> {
-        Box::new(ConcatSinkState { curr_idx: 0 })
+    fn make_state(&self) -> Box<dyn StreamingSinkState> {
+        Box::new(ConcatSinkState {})
     }

-    /// Since the ConcatSink does not do any computation, it does not need to spawn multiple workers.
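
The round-robin contract described in the comment above is the crux: per-worker channels are written in a fixed rotation and drained in the same rotation, so global order is preserved without any index bookkeeping. A minimal, self-contained sketch of that idea using plain tokio mpsc channels (the crate itself uses loole-backed wrappers; the worker logic and names here are illustrative only):

use tokio::sync::mpsc;

// Round-robin dispatch to N workers, then round-robin receive from them:
// output order matches input order, assuming each worker emits exactly one
// output per input morsel (as ConcatSink's pass-through execute does).
#[tokio::main]
async fn main() {
    let num_workers = 4;
    let mut in_txs = Vec::new();
    let mut out_rxs = Vec::new();
    for _ in 0..num_workers {
        let (in_tx, mut in_rx) = mpsc::channel::<u64>(1);
        let (out_tx, out_rx) = mpsc::channel::<u64>(1);
        // Stand-in worker: applies a trivial transform to each morsel.
        tokio::spawn(async move {
            while let Some(morsel) = in_rx.recv().await {
                let _ = out_tx.send(morsel * 2).await;
            }
        });
        in_txs.push(in_tx);
        out_rxs.push(out_rx);
    }
    // Dispatch morsel i to worker i % N.
    tokio::spawn(async move {
        for i in 0u64..16 {
            let _ = in_txs[(i as usize) % num_workers].send(i).await;
        }
    });
    // Drain workers in the same rotation; results come back in input order.
    let mut results = Vec::new();
    'outer: loop {
        for rx in &mut out_rxs {
            match rx.recv().await {
                Some(v) => results.push(v),
                None => break 'outer,
            }
        }
    }
    assert_eq!(results, (0u64..16).map(|i| i * 2).collect::<Vec<_>>());
}
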
     fn max_concurrency(&self) -> usize {
-        1
+        *NUM_CPUS
+    }
+
+    fn make_dispatcher(
+        &self,
+        runtime_handle: &ExecutionRuntimeHandle,
+        maintain_order: bool,
+    ) -> Arc<dyn Dispatcher> {
+        if maintain_order {
+            Arc::new(RoundRobinDispatcher::new(Some(
+                runtime_handle.default_morsel_size(),
+            )))
+        } else {
+            Arc::new(UnorderedDispatcher::new(Some(
+                runtime_handle.default_morsel_size(),
+            )))
+        }
     }
 }
diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs
index 677f63279d..aa4071d100 100644
--- a/src/daft-local-execution/src/sinks/hash_join_build.rs
+++ b/src/daft-local-execution/src/sinks/hash_join_build.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};

 use common_error::DaftResult;
 use daft_core::prelude::SchemaRef;
@@ -8,7 +8,39 @@
 use daft_plan::JoinType;
 use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table};

 use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus};
-use crate::pipeline::PipelineResultType;
+use crate::dispatcher::{Dispatcher, UnorderedDispatcher};
+
+pub(crate) type ProbeStateBridgeRef = Arc<ProbeStateBridge>;
+pub(crate) struct ProbeStateBridge {
+    inner: OnceLock<Arc<ProbeState>>,
+    notify: tokio::sync::Notify,
+}
+
+impl ProbeStateBridge {
+    pub(crate) fn new() -> Arc<Self> {
+        Arc::new(Self {
+            inner: OnceLock::new(),
+            notify: tokio::sync::Notify::new(),
+        })
+    }
+
+    pub(crate) fn set_probe_state(&self, state: Arc<ProbeState>) {
+        assert!(
+            self.inner.set(state).is_ok(),
+            "ProbeStateBridge should be set only once"
+        );
+        self.notify.notify_waiters();
+    }
+
+    pub(crate) async fn get_probe_state(&self) -> Arc<ProbeState> {
+        loop {
+            if let Some(state) = self.inner.get() {
+                return state.clone();
+            }
+            self.notify.notified().await;
+        }
+    }
+}

 enum ProbeTableState {
     Building {
@@ -16,9 +48,7 @@
         probe_table_builder: Option<Box<dyn ProbeableBuilder>>,
         projection: Vec<ExprRef>,
         tables: Vec<Table>,
     },
-    Done {
-        probe_state: Arc<ProbeState>,
-    },
+    Done,
 }

 impl ProbeTableState {
@@ -54,7 +84,7 @@
             panic!("add_tables can only be used during the Building Phase")
         }
     }
-    fn finalize(&mut self) -> DaftResult<()> {
+    fn finalize(&mut self) -> DaftResult<ProbeState> {
         if let Self::Building {
             probe_table_builder,
             tables,
@@ -64,10 +94,9 @@
             let ptb = std::mem::take(probe_table_builder).expect("should be set in building mode");
             let pt = ptb.build();

-            *self = Self::Done {
-                probe_state: Arc::new(ProbeState::new(pt, Arc::new(tables.clone()))),
-            };
-            Ok(())
+            let ps = ProbeState::new(pt, tables.clone().into());
+            *self = Self::Done;
+            Ok(ps)
         } else {
             panic!("finalize can only be used during the Building Phase")
         }
@@ -84,6 +113,7 @@
 pub struct HashJoinBuildSink {
     key_schema: SchemaRef,
     projection: Vec<ExprRef>,
     join_type: JoinType,
+    probe_state_bridge: ProbeStateBridgeRef,
 }

 impl HashJoinBuildSink {
@@ -91,11 +121,13 @@
         key_schema: SchemaRef,
         projection: Vec<ExprRef>,
         join_type: &JoinType,
+        probe_state_bridge: ProbeStateBridgeRef,
     ) -> DaftResult<Self> {
         Ok(Self {
             key_schema,
             projection,
             join_type: *join_type,
+            probe_state_bridge,
         })
     }
 }
@@ -121,19 +153,17 @@
     fn finalize(
         &self,
         states: Vec<Box<dyn BlockingSinkState>>,
-    ) -> DaftResult<Option<PipelineResultType>> {
+    ) -> DaftResult<Option<Arc<MicroPartition>>> {
         assert_eq!(states.len(), 1);
         let mut state = states.into_iter().next().unwrap();
         let probe_table_state = state
             .as_any_mut()
             .downcast_mut::<ProbeTableState>()
             .expect("State type mismatch");
-        probe_table_state.finalize()?;
-        if let ProbeTableState::Done { probe_state } = probe_table_state {
-            Ok(Some(probe_state.clone().into()))
-        } else {
-            panic!("finalize should only be called after the 
probe table is built") - } + let finalized_probe_state = probe_table_state.finalize()?; + self.probe_state_bridge + .set_probe_state(finalized_probe_state.into()); + Ok(None) } fn max_concurrency(&self) -> usize { @@ -147,4 +177,11 @@ impl BlockingSink for HashJoinBuildSink { &self.join_type, )?)) } + + fn make_dispatcher( + &self, + _runtime_handle: &crate::ExecutionRuntimeHandle, + ) -> Arc { + Arc::new(UnorderedDispatcher::new(None)) + } } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index ff24a703e2..2baa5cd359 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -1,13 +1,17 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_micropartition::MicroPartition; use tracing::instrument; use super::streaming_sink::{ - DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, + StreamingSink, StreamingSinkExecuteOutput, StreamingSinkFinalizeOutput, + StreamingSinkOutputType, StreamingSinkState, +}; +use crate::{ + dispatcher::{Dispatcher, UnorderedDispatcher}, + ExecutionRuntimeHandle, MaybeFuture, }; -use crate::pipeline::PipelineResultType; struct LimitSinkState { remaining: usize, @@ -23,7 +27,7 @@ impl LimitSinkState { } } -impl DynStreamingSinkState for LimitSinkState { +impl StreamingSinkState for LimitSinkState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -43,33 +47,45 @@ impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( &self, - index: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, - ) -> DaftResult { - assert_eq!(index, 0); - let input = input.as_data(); + input: &Arc, + mut state: Box, + runtime_ref: &RuntimeRef, + ) -> StreamingSinkExecuteOutput { let input_num_rows = input.len(); - state_handle.with_state_mut::(|state| { - let remaining = state.get_remaining_mut(); - use std::cmp::Ordering::{Equal, Greater, Less}; - match input_num_rows.cmp(remaining) { - Less => { - *remaining -= input_num_rows; - Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) - } - Equal => { - *remaining = 0; - Ok(StreamingSinkOutput::Finished(Some(input.clone()))) - } - Greater => { - let taken = input.head(*remaining)?; - *remaining = 0; - Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) - } + let remaining = state + .as_any_mut() + .downcast_mut::() + .expect("Limit sink should have LimitSinkState") + .get_remaining_mut(); + use std::cmp::Ordering::{Equal, Greater, Less}; + match input_num_rows.cmp(remaining) { + Less => { + *remaining -= input_num_rows; + MaybeFuture::Immediate(Ok(( + state, + StreamingSinkOutputType::NeedMoreInput(Some(input.clone())), + ))) + } + Equal => { + *remaining = 0; + MaybeFuture::Immediate(Ok(( + state, + StreamingSinkOutputType::Finished(Some(input.clone())), + ))) + } + Greater => { + let input = input.clone(); + let to_head = *remaining; + *remaining = 0; + let fut = runtime_ref.spawn(async move { + let taken = input.head(to_head)?; + Ok((state, StreamingSinkOutputType::Finished(Some(taken.into())))) + }); + + MaybeFuture::Future(fut) } - }) + } } fn name(&self) -> &'static str { @@ -78,16 +94,27 @@ impl StreamingSink for LimitSink { fn finalize( &self, - _states: Vec>, - ) -> DaftResult>> { - Ok(None) + _states: Vec>, + _runtime_ref: &RuntimeRef, + ) -> StreamingSinkFinalizeOutput { + MaybeFuture::Immediate(Ok(None)) } - fn make_state(&self) -> Box { + fn make_state(&self) -> Box { 
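        // The state constructed below is just the remaining row count. A worked
        // example of how `execute` drains it (hypothetical values): with limit = 5,
        // a 3-row morsel takes the Less arm (remaining becomes 2 and the morsel
        // passes through as NeedMoreInput), and a following 4-row morsel takes the
        // Greater arm, which spawns head(2) on the compute runtime and returns
        // Finished, so no further morsels are requested from upstream.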
Box::new(LimitSinkState::new(self.limit)) } fn max_concurrency(&self) -> usize { 1 } + + fn make_dispatcher( + &self, + _runtime_handle: &ExecutionRuntimeHandle, + _maintain_order: bool, + ) -> Arc { + // LimitSink should be greedy, and accept all input as soon as possible. + // It is also not concurrent, so we don't need to worry about ordering. + Arc::new(UnorderedDispatcher::new(None)) + } } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 23cefecf92..e8dd7d2f09 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use common_runtime::RuntimeRef; use daft_core::{ prelude::{ bitmap::{and, Bitmap, MutableBitmap}, @@ -12,13 +13,21 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; use daft_table::{GrowableTable, ProbeState, Table}; +use futures::{stream, StreamExt}; use indexmap::IndexSet; use tracing::{info_span, instrument}; -use super::streaming_sink::{ - DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, +use super::{ + hash_join_build::ProbeStateBridgeRef, + streaming_sink::{ + StreamingSink, StreamingSinkExecuteOutput, StreamingSinkFinalizeOutput, + StreamingSinkOutputType, StreamingSinkState, + }, +}; +use crate::{ + dispatcher::{Dispatcher, RoundRobinDispatcher, UnorderedDispatcher}, + ExecutionRuntimeHandle, MaybeFuture, }; -use crate::pipeline::PipelineResultType; struct IndexBitmapBuilder { mutable_bitmaps: Vec, @@ -69,59 +78,56 @@ impl IndexBitmap { } } -enum OuterHashJoinProbeState { - Building, - ReadyToProbe(Arc, Option), +enum OuterHashJoinState { + Building(ProbeStateBridgeRef, bool), + Probing(Arc, Option), } -impl OuterHashJoinProbeState { - fn initialize_probe_state(&mut self, probe_state: Arc, needs_bitmap: bool) { - let tables = probe_state.get_tables(); - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe( - probe_state.clone(), - if needs_bitmap { - Some(IndexBitmapBuilder::new(tables)) - } else { - None - }, - ); - } else { - panic!("OuterHashJoinProbeState should only be in Building state when setting table") - } - } - - fn get_probe_state(&self) -> &ProbeState { - if let Self::ReadyToProbe(probe_state, _) = self { - probe_state - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") +impl OuterHashJoinState { + async fn get_or_build_probe_state(&mut self) -> Arc { + match self { + Self::Building(bridge, needs_bitmap) => { + let probe_state = bridge.get_probe_state().await; + let builder = + needs_bitmap.then(|| IndexBitmapBuilder::new(probe_state.get_tables())); + *self = Self::Probing(probe_state.clone(), builder); + probe_state + } + Self::Probing(probe_state, _) => probe_state.clone(), } } - fn get_bitmap_builder(&mut self) -> &mut Option { - if let Self::ReadyToProbe(_, bitmap_builder) = self { - bitmap_builder - } else { - panic!("get_bitmap can only be used during the ReadyToProbe Phase") + async fn get_or_build_bitmap(&mut self) -> &mut Option { + match self { + Self::Building(bridge, _) => { + let probe_state = bridge.get_probe_state().await; + let builder = IndexBitmapBuilder::new(probe_state.get_tables()); + *self = Self::Probing(probe_state, Some(builder)); + match self { + Self::Probing(_, builder) => builder, + _ => unreachable!(), + } + } + Self::Probing(_, builder) => builder, } } } -impl 
DynStreamingSinkState for OuterHashJoinProbeState { +impl StreamingSinkState for OuterHashJoinState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } } pub(crate) struct OuterHashJoinProbeSink { - probe_on: Vec, - common_join_keys: Vec, - left_non_join_columns: Vec, - right_non_join_columns: Vec, + probe_on: Arc>, + common_join_keys: Arc>, + left_non_join_columns: Arc>, + right_non_join_columns: Arc>, right_non_join_schema: SchemaRef, join_type: JoinType, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl OuterHashJoinProbeSink { @@ -132,13 +138,16 @@ impl OuterHashJoinProbeSink { join_type: JoinType, common_join_keys: IndexSet, output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, ) -> Self { - let left_non_join_columns = left_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); + let left_non_join_columns = Arc::new( + left_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(), + ); let right_non_join_fields = right_schema .fields .values() @@ -147,28 +156,32 @@ impl OuterHashJoinProbeSink { .collect(); let right_non_join_schema = Arc::new(Schema::new(right_non_join_fields).expect("right schema should be valid")); - let right_non_join_columns = right_non_join_schema.fields.keys().cloned().collect(); - let common_join_keys = common_join_keys.into_iter().collect(); + let right_non_join_columns = + Arc::new(right_non_join_schema.fields.keys().cloned().collect()); + let common_join_keys = Arc::new(common_join_keys.into_iter().collect()); Self { - probe_on, + probe_on: Arc::new(probe_on), common_join_keys, left_non_join_columns, right_non_join_columns, right_non_join_schema, join_type, output_schema: output_schema.clone(), + probe_state_bridge, } } fn probe_left_right( - &self, input: &Arc, - state: &OuterHashJoinProbeState, + probe_state: &ProbeState, + join_type: JoinType, + probe_on: &[ExprRef], + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_columns: &[String], ) -> DaftResult> { - let (probe_table, tables) = { - let probe_state = state.get_probe_state(); - (probe_state.get_probeable(), probe_state.get_tables()) - }; + let probe_table = probe_state.get_probeable().clone(); + let tables = probe_state.get_tables().clone(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); let mut build_side_growable = GrowableTable::new( @@ -185,7 +198,7 @@ impl OuterHashJoinProbeSink { { let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; + let join_keys = table.eval_expression_list(probe_on)?; let idx_mapper = probe_table.probe_indices(&join_keys)?; for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { @@ -209,15 +222,15 @@ impl OuterHashJoinProbeSink { let build_side_table = build_side_growable.build()?; let probe_side_table = probe_side_growable.build()?; - let final_table = if self.join_type == JoinType::Left { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = probe_side_table.get_columns(&self.left_non_join_columns)?; - let right = build_side_table.get_columns(&self.right_non_join_columns)?; + let final_table = if join_type == JoinType::Left { + let join_table = probe_side_table.get_columns(common_join_keys)?; + let left = probe_side_table.get_columns(left_non_join_columns)?; + let 
right = build_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? } else { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + let join_table = probe_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; + let right = probe_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? }; Ok(Arc::new(MicroPartition::new_loaded( @@ -228,18 +241,17 @@ impl OuterHashJoinProbeSink { } fn probe_outer( - &self, input: &Arc, - state: &mut OuterHashJoinProbeState, + probe_state: &ProbeState, + bitmap_builder: &mut IndexBitmapBuilder, + probe_on: &[ExprRef], + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_columns: &[String], ) -> DaftResult> { - let (probe_table, tables) = { - let probe_state = state.get_probe_state(); - ( - probe_state.get_probeable().clone(), - probe_state.get_tables().clone(), - ) - }; - let bitmap_builder = state.get_bitmap_builder(); + let probe_table = probe_state.get_probeable().clone(); + let tables = probe_state.get_tables().clone(); + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); // Need to set use_validity to true here because we add nulls to the build side let mut build_side_growable = GrowableTable::new( @@ -252,15 +264,11 @@ impl OuterHashJoinProbeSink { let mut probe_side_growable = GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; - let left_idx_used = bitmap_builder - .as_mut() - .expect("bitmap should be set in outer join"); - drop(_growables); { let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; + let join_keys = table.eval_expression_list(probe_on)?; let idx_mapper = probe_table.probe_indices(&join_keys)?; for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { @@ -268,7 +276,7 @@ impl OuterHashJoinProbeSink { for (build_side_table_idx, build_row_idx) in inner_iter { let build_side_table_idx = build_side_table_idx as usize; let build_row_idx = build_row_idx as usize; - left_idx_used.mark_used(build_side_table_idx, build_row_idx); + bitmap_builder.mark_used(build_side_table_idx, build_row_idx); build_side_growable.extend(build_side_table_idx, build_row_idx, 1); probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } @@ -283,9 +291,9 @@ impl OuterHashJoinProbeSink { let build_side_table = build_side_growable.build()?; let probe_side_table = probe_side_growable.build()?; - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + let join_table = probe_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; + let right = probe_side_table.get_columns(right_non_join_columns)?; let final_table = join_table.union(&left)?.union(&right)?; Ok(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -294,37 +302,49 @@ impl OuterHashJoinProbeSink { ))) } - fn finalize_outer( - &self, - mut states: Vec>, + async fn finalize_outer( + mut states: Vec>, + common_join_keys: &[String], + 
left_non_join_columns: &[String], + right_non_join_schema: &SchemaRef, ) -> DaftResult>> { - let states = states - .iter_mut() - .map(|s| { - s.as_any_mut() - .downcast_mut::() - .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState") - }) - .collect::>(); - let tables = states - .first() + let mut states_iter = states.iter_mut(); + let first_state = states_iter + .next() .expect("at least one state should be present") - .get_probe_state() + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let tables = first_state + .get_or_build_probe_state() + .await .get_tables() .clone(); + let first_bitmap = first_state + .get_or_build_bitmap() + .await + .take() + .expect("bitmap should be set") + .build(); let merged_bitmap = { - let bitmaps = states.into_iter().map(|s| { - if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap) = s { - bitmap + let bitmaps = stream::once(async move { first_bitmap }) + .chain(stream::iter(states_iter).then(|s| async move { + let state = s + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + state + .get_or_build_bitmap() + .await .take() - .expect("bitmap should be present in outer join") + .expect("bitmap should be set") .build() - } else { - panic!("OuterHashJoinProbeState should be in ReadyToProbe state") - } - }); - bitmaps.fold(None, |acc, x| match acc { + })) + .collect::>() + .await; + + bitmaps.into_iter().fold(None, |acc, x| match acc { None => Some(x), Some(acc) => Some(acc.merge(&x)), }) @@ -339,16 +359,15 @@ impl OuterHashJoinProbeSink { let build_side_table = Table::concat(&leftovers)?; - let join_table = build_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let join_table = build_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; let right = { - let columns = self - .right_non_join_schema + let columns = right_non_join_schema .fields .values() .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) .collect::>(); - Table::new_unchecked(self.right_non_join_schema.clone(), columns, left.len()) + Table::new_unchecked(right_non_join_schema.clone(), columns, left.len()) }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Some(Arc::new(MicroPartition::new_loaded( @@ -363,54 +382,113 @@ impl StreamingSink for OuterHashJoinProbeSink { #[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, - ) -> DaftResult { - match idx { - 0 => { - state_handle.with_state_mut::(|state| { - state.initialize_probe_state( - input.as_probe_state().clone(), - self.join_type == JoinType::Outer, - ); - }); - Ok(StreamingSinkOutput::NeedMoreInput(None)) - } - _ => state_handle.with_state_mut::(|state| { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); - } - let out = match self.join_type { - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - JoinType::Outer => self.probe_outer(input, state), - _ => unreachable!( - "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" - ), - }?; - Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) - }), + input: &Arc, + mut state: Box, + runtime_ref: 
&RuntimeRef, + ) -> StreamingSinkExecuteOutput { + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return MaybeFuture::Immediate(Ok(( + state, + StreamingSinkOutputType::NeedMoreInput(Some(empty)), + ))); } + + let join_type = self.join_type; + let probe_on = self.probe_on.clone(); + let common_join_keys = self.common_join_keys.clone(); + let left_non_join_columns = self.left_non_join_columns.clone(); + let right_non_join_columns = self.right_non_join_columns.clone(); + let input = input.clone(); + let fut = runtime_ref.spawn(async move { + let outer_join_state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink should have OuterHashJoinProbeState"); + let probe_state = outer_join_state.get_or_build_probe_state().await; + let out = match join_type { + JoinType::Left | JoinType::Right => Self::probe_left_right( + &input, + &probe_state, + join_type, + &probe_on, + &common_join_keys, + &left_non_join_columns, + &right_non_join_columns, + ), + JoinType::Outer => { + let bitmap_builder = outer_join_state + .get_or_build_bitmap() + .await + .as_mut() + .expect("bitmap should be set"); + Self::probe_outer( + &input, + &probe_state, + bitmap_builder, + &probe_on, + &common_join_keys, + &left_non_join_columns, + &right_non_join_columns, + ) + } + _ => unreachable!( + "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" + ), + }?; + Ok((state, StreamingSinkOutputType::NeedMoreInput(Some(out)))) + }); + MaybeFuture::Future(fut) } fn name(&self) -> &'static str { "OuterHashJoinProbeSink" } - fn make_state(&self) -> Box { - Box::new(OuterHashJoinProbeState::Building) + fn make_state(&self) -> Box { + Box::new(OuterHashJoinState::Building( + self.probe_state_bridge.clone(), + self.join_type == JoinType::Outer, + )) } fn finalize( &self, - states: Vec>, - ) -> DaftResult>> { + states: Vec>, + runtime_ref: &RuntimeRef, + ) -> StreamingSinkFinalizeOutput { if self.join_type == JoinType::Outer { - self.finalize_outer(states) + let common_join_keys = self.common_join_keys.clone(); + let left_non_join_columns = self.left_non_join_columns.clone(); + let right_non_join_schema = self.right_non_join_schema.clone(); + let fut = runtime_ref.spawn(async move { + Self::finalize_outer( + states, + &common_join_keys, + &left_non_join_columns, + &right_non_join_schema, + ) + .await + }); + MaybeFuture::Future(fut) + } else { + MaybeFuture::Immediate(Ok(None)) + } + } + + fn make_dispatcher( + &self, + runtime_handle: &ExecutionRuntimeHandle, + maintain_order: bool, + ) -> Arc { + if maintain_order { + Arc::new(RoundRobinDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) } else { - Ok(None) + Arc::new(UnorderedDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) } } } diff --git a/src/daft-local-execution/src/sinks/pivot.rs b/src/daft-local-execution/src/sinks/pivot.rs index 93cb7b3843..c1a0efa926 100644 --- a/src/daft-local-execution/src/sinks/pivot.rs +++ b/src/daft-local-execution/src/sinks/pivot.rs @@ -6,7 +6,7 @@ use daft_micropartition::MicroPartition; use tracing::instrument; use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use crate::NUM_CPUS; enum PivotState { Accumulating(Vec>), @@ -84,7 +84,7 @@ impl BlockingSink for PivotSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { let all_parts = states.into_iter().flat_map(|mut state| { state .as_any_mut() @@ 
-109,7 +109,7 @@ impl BlockingSink for PivotSink { self.value_column.clone(), self.names.clone(), )?); - Ok(Some(pivoted.into())) + Ok(Some(pivoted)) } fn name(&self) -> &'static str { @@ -123,4 +123,13 @@ impl BlockingSink for PivotSink { fn make_state(&self) -> DaftResult> { Ok(Box::new(PivotState::Accumulating(vec![]))) } + + fn make_dispatcher( + &self, + runtime_handle: &crate::ExecutionRuntimeHandle, + ) -> Arc { + Arc::new(crate::dispatcher::UnorderedDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) + } } diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 83c933d1ec..5609cb4854 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -6,7 +6,7 @@ use daft_micropartition::MicroPartition; use tracing::instrument; use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use crate::NUM_CPUS; enum SortState { Building(Vec>), @@ -71,7 +71,7 @@ impl BlockingSink for SortSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { let parts = states.into_iter().flat_map(|mut state| { let state = state .as_any_mut() @@ -81,7 +81,7 @@ impl BlockingSink for SortSink { }); let concated = MicroPartition::concat(parts)?; let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); - Ok(Some(sorted.into())) + Ok(Some(sorted)) } fn name(&self) -> &'static str { @@ -95,4 +95,13 @@ impl BlockingSink for SortSink { fn max_concurrency(&self) -> usize { *NUM_CPUS } + + fn make_dispatcher( + &self, + runtime_handle: &crate::ExecutionRuntimeHandle, + ) -> Arc { + Arc::new(crate::dispatcher::UnorderedDispatcher::new(Some( + runtime_handle.default_morsel_size(), + ))) + } } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 9ebab6f1a3..1ef4ec0ae8 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -1,82 +1,70 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; -use common_runtime::get_compute_runtime; +use common_runtime::{get_compute_runtime, RuntimeRef}; use daft_micropartition::MicroPartition; use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - channel::{create_channel, PipelineChannel, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, + channel::{create_channel, create_ordering_aware_receiver_channel, Receiver, Sender}, + dispatcher::Dispatcher, + pipeline::PipelineNode, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, MaybeFuture, TaskSet, NUM_CPUS, }; -pub trait DynStreamingSinkState: Send + Sync { +pub trait StreamingSinkState: Send + Sync { fn as_any_mut(&mut self) -> &mut dyn std::any::Any; } -pub(crate) struct StreamingSinkState { - inner: Mutex>, -} - -impl StreamingSinkState { - fn new(inner: Box) -> Arc { - Arc::new(Self { - inner: Mutex::new(inner), - }) - } - - pub(crate) fn with_state_mut(&self, f: F) -> R - where - F: FnOnce(&mut T) -> R, - { - let mut guard = self.inner.lock().unwrap(); - let state = guard - .as_any_mut() - .downcast_mut::() - .expect("State type mismatch"); - f(state) - } -} - -pub enum StreamingSinkOutput { 
+pub(crate) enum StreamingSinkOutputType {
     NeedMoreInput(Option<Arc<MicroPartition>>),
     #[allow(dead_code)]
     HasMoreOutput(Arc<MicroPartition>),
     Finished(Option<Arc<MicroPartition>>),
 }

+pub(crate) type StreamingSinkExecuteOutput =
+    MaybeFuture<DaftResult<(Box<dyn StreamingSinkState>, StreamingSinkOutputType)>>;
+pub(crate) type StreamingSinkFinalizeOutput = MaybeFuture<DaftResult<Option<Arc<MicroPartition>>>>;
+
 pub trait StreamingSink: Send + Sync {
-    /// Execute the StreamingSink operator on the morsel of input data,
-    /// received from the child with the given index,
-    /// with the given state.
+    /// Execute the StreamingSink operator on a morsel of input data,
+    /// with the given worker-local state.
     fn execute(
         &self,
-        index: usize,
-        input: &PipelineResultType,
-        state_handle: &StreamingSinkState,
-    ) -> DaftResult<StreamingSinkOutput>;
+        input: &Arc<MicroPartition>,
+        state_handle: Box<dyn StreamingSinkState>,
+        runtime: &RuntimeRef,
+    ) -> StreamingSinkExecuteOutput;

     /// Finalize the StreamingSink operator, with the given states from each worker.
     fn finalize(
         &self,
-        states: Vec<Box<dyn DynStreamingSinkState>>,
-    ) -> DaftResult<Option<Arc<MicroPartition>>>;
+        states: Vec<Box<dyn StreamingSinkState>>,
+        runtime: &RuntimeRef,
+    ) -> StreamingSinkFinalizeOutput;

     /// The name of the StreamingSink operator.
     fn name(&self) -> &'static str;

     /// Create a new worker-local state for this StreamingSink.
-    fn make_state(&self) -> Box<dyn DynStreamingSinkState>;
+    fn make_state(&self) -> Box<dyn StreamingSinkState>;

     /// The maximum number of concurrent workers that can be spawned for this sink.
     /// Each worker will have its own StreamingSinkState.
     fn max_concurrency(&self) -> usize {
         *NUM_CPUS
     }
+
+    fn make_dispatcher(
+        &self,
+        runtime_handle: &ExecutionRuntimeHandle,
+        maintain_order: bool,
+    ) -> Arc<dyn Dispatcher>;
 }

 pub struct StreamingSinkNode {
@@ -104,39 +92,35 @@ impl StreamingSinkNode {
     #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")]
     async fn run_worker(
         op: Arc<dyn StreamingSink>,
-        mut input_receiver: Receiver<(usize, PipelineResultType)>,
+        input_receiver: Receiver<Arc<MicroPartition>>,
         output_sender: Sender<Arc<MicroPartition>>,
         rt_context: Arc<RuntimeStatsContext>,
-    ) -> DaftResult<Box<dyn DynStreamingSinkState>> {
+    ) -> DaftResult<Box<dyn StreamingSinkState>> {
         let span = info_span!("StreamingSink::Execute");
         let compute_runtime = get_compute_runtime();
-        let state_wrapper = StreamingSinkState::new(op.make_state());
+        let mut state = op.make_state();
         let mut finished = false;
-        while let Some((idx, morsel)) = input_receiver.recv().await {
+        while let Some(morsel) = input_receiver.recv().await {
             if finished {
                 break;
             }
             loop {
-                let op = op.clone();
-                let morsel = morsel.clone();
-                let span = span.clone();
-                let rt_context = rt_context.clone();
-                let state_wrapper = state_wrapper.clone();
-                let fut = async move {
-                    rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref()))
-                };
-                let result = compute_runtime.spawn(fut).await??;
-                match result {
-                    StreamingSinkOutput::NeedMoreInput(mp) => {
+                let result = rt_context
+                    .in_span(&span, || op.execute(&morsel, state, &compute_runtime))
+                    .output()
+                    .await??;
+                state = result.0;
+                match result.1 {
+                    StreamingSinkOutputType::NeedMoreInput(mp) => {
                         if let Some(mp) = mp {
                             let _ = output_sender.send(mp).await;
                         }
                         break;
                     }
-                    StreamingSinkOutput::HasMoreOutput(mp) => {
+                    StreamingSinkOutputType::HasMoreOutput(mp) => {
                         let _ = output_sender.send(mp).await;
                     }
-                    StreamingSinkOutput::Finished(mp) => {
+                    StreamingSinkOutputType::Finished(mp) => {
                         if let Some(mp) = mp {
                             let _ = output_sender.send(mp).await;
                         }
@@ -147,58 +131,24 @@ impl StreamingSinkNode {
             }
         }

-        // Take the state out of the Arc and Mutex because we need to return it.
-        // It should be guaranteed that the ONLY holder of state at this point is this function.
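
The worker loop above leans on the MaybeFuture returned by `execute`: cheap operators answer inline and skip a spawn, expensive ones hand back a task already running on the compute runtime, and `.output().await` unifies the two. A minimal sketch of such a type, assuming tokio's JoinHandle stands in for the crate's compute-runtime task handle (the real definition lives elsewhere in this crate and may differ):

use tokio::task::JoinHandle;

// Either an already-computed value or a handle to work in flight.
enum MaybeFuture<T> {
    Immediate(T),
    Future(JoinHandle<T>),
}

impl<T> MaybeFuture<T> {
    // Resolve to the value; the outer Result surfaces task-join failures,
    // which is why the call sites above use `.output().await??`.
    async fn output(self) -> Result<T, tokio::task::JoinError> {
        match self {
            Self::Immediate(val) => Ok(val),
            Self::Future(handle) => handle.await,
        }
    }
}
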
- Ok(Arc::into_inner(state_wrapper) - .expect("Completed worker should have exclusive access to state wrapper") - .inner - .into_inner() - .expect("Completed worker should have exclusive access to inner state")) + Ok(state) } fn spawn_workers( op: Arc, - input_receivers: Vec>, - task_set: &mut TaskSet>>, + input_receivers: Vec>>, + output_senders: Vec>>, + task_set: &mut TaskSet>>, stats: Arc, - ) -> Receiver> { - let (output_sender, output_receiver) = create_channel(input_receivers.len()); - for input_receiver in input_receivers { + ) { + for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { task_set.spawn(Self::run_worker( op.clone(), input_receiver, - output_sender.clone(), + output_sender, stats.clone(), )); } - output_receiver - } - - // Forwards input from the children to the workers in a round-robin fashion. - // Always exhausts the input from one child before moving to the next. - async fn forward_input_to_workers( - receivers: Vec, - worker_senders: Vec>, - ) -> DaftResult<()> { - let mut next_worker_idx = 0; - let mut send_to_next_worker = |idx, data: PipelineResultType| { - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send((idx, data)) - }; - - for (idx, mut receiver) in receivers.into_iter().enumerate() { - while let Some(morsel) = receiver.recv().await { - if morsel.should_broadcast() { - for worker_sender in &worker_senders { - let _ = worker_sender.send((idx, morsel.clone())).await; - } - } else { - let _ = send_to_next_worker(idx, morsel.clone()).await; - } - } - } - Ok(()) } } @@ -236,41 +186,48 @@ impl PipelineNode for StreamingSinkNode { } fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result>> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); - for child in &mut self.children { + for child in &self.children { let child_result_channel = child.start(maintain_order, runtime_handle)?; - child_result_receivers - .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); + let counting_receiver = + CountingReceiver::new(child_result_channel, self.runtime_stats.clone()); + child_result_receivers.push(counting_receiver); } - let mut destination_channel = PipelineChannel::new(1, maintain_order); + let (destination_sender, destination_receiver) = create_channel(1); let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + CountingSender::new(destination_sender, self.runtime_stats.clone()); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); - runtime_handle.spawn( - Self::forward_input_to_workers(child_result_receivers, input_senders), + + let dispatcher = op.make_dispatcher(runtime_handle, maintain_order); + let input_receivers = dispatcher.dispatch( + child_result_receivers, + num_workers, + runtime_handle, self.name(), ); runtime_handle.spawn( async move { let mut task_set = TaskSet::new(); - let mut output_receiver = Self::spawn_workers( + let (output_senders, mut output_receiver) = + create_ordering_aware_receiver_channel(maintain_order, num_workers); + Self::spawn_workers( op.clone(), input_receivers, + output_senders, &mut task_set, runtime_stats.clone(), ); while let Some(morsel) = 
output_receiver.recv().await { - let _ = destination_sender.send(morsel.into()).await; + let _ = destination_sender.send(morsel).await; } let mut finished_states = Vec::with_capacity(num_workers); @@ -280,21 +237,20 @@ impl PipelineNode for StreamingSinkNode { } let compute_runtime = get_compute_runtime(); - let finalized_result = compute_runtime - .spawn(async move { - runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || { - op.finalize(finished_states) - }) + let finalized_result = runtime_stats + .in_span(&info_span!("StreamingSinkNode::finalize"), || { + op.finalize(finished_states, &compute_runtime) }) + .output() .await??; if let Some(res) = finalized_result { - let _ = destination_sender.send(res.into()).await; + let _ = destination_sender.send(res).await; } Ok(()) }, self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/write.rs b/src/daft-local-execution/src/sinks/write.rs index 002f32a25a..813f4cf09f 100644 --- a/src/daft-local-execution/src/sinks/write.rs +++ b/src/daft-local-execution/src/sinks/write.rs @@ -10,8 +10,7 @@ use tracing::instrument; use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; use crate::{ - dispatcher::{Dispatcher, PartitionedDispatcher, RoundRobinBufferedDispatcher}, - pipeline::PipelineResultType, + dispatcher::{Dispatcher, PartitionedDispatcher, UnorderedDispatcher}, NUM_CPUS, }; @@ -83,7 +82,7 @@ impl BlockingSink for WriteSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { let mut results = vec![]; for mut state in states { let state = state @@ -97,7 +96,7 @@ impl BlockingSink for WriteSink { results.into(), None, )); - Ok(Some(mp.into())) + Ok(Some(mp)) } fn name(&self) -> &'static str { @@ -116,14 +115,12 @@ impl BlockingSink for WriteSink { fn make_dispatcher( &self, - runtime_handle: &crate::ExecutionRuntimeHandle, + _runtime_handle: &crate::ExecutionRuntimeHandle, ) -> Arc { if let Some(partition_by) = &self.partition_by { Arc::new(PartitionedDispatcher::new(partition_by.clone())) } else { - Arc::new(RoundRobinBufferedDispatcher::new( - runtime_handle.default_morsel_size(), - )) + Arc::new(UnorderedDispatcher::new(None)) } } diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 1bd4bac0d7..190ba71838 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -14,7 +14,6 @@ use daft_parquet::read::{read_parquet_bulk_async, ParquetSchemaInferenceOptions} use daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask}; use futures::{Stream, StreamExt}; use snafu::ResultExt; -use tokio_stream::wrappers::ReceiverStream; use tracing::instrument; use crate::{ @@ -105,7 +104,7 @@ impl Source for ScanTaskSource { self.name(), ); - let stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new)); + let stream = futures::stream::iter(receivers.into_iter().map(|r| r.into_stream())); Ok(Box::pin(stream.flatten())) } diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 8c55401db2..0e7e740c3c 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -6,13 +6,15 @@ use daft_micropartition::MicroPartition; use futures::{stream::BoxStream, StreamExt}; use crate::{ - channel::PipelineChannel, 
pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, + channel::{create_channel, Receiver}, + pipeline::PipelineNode, + runtime_stats::{CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, }; -pub type SourceStream<'a> = BoxStream<'a, Arc>; +pub(crate) type SourceStream<'a> = BoxStream<'a, Arc>; -pub trait Source: Send + Sync { +pub(crate) trait Source: Send + Sync { fn name(&self) -> &'static str; fn get_data( &self, @@ -66,26 +68,26 @@ impl PipelineNode for SourceNode { vec![] } fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result>> { let mut source_stream = self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; - let mut channel = PipelineChannel::new(1, maintain_order); - let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); + let (tx, rx) = create_channel(1); + let counting_sender = CountingSender::new(tx, self.runtime_stats.clone()); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { - let _ = counting_sender.send(part.into()).await; + let _ = counting_sender.send(part).await; } Ok(()) }, self.name(), ); - Ok(channel) + Ok(rx) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/tests/dataframe/test_unpivot.py b/tests/dataframe/test_unpivot.py index 1767306fc8..6bd2a5ed8c 100644 --- a/tests/dataframe/test_unpivot.py +++ b/tests/dataframe/test_unpivot.py @@ -4,7 +4,7 @@ from daft.datatype import DataType -@pytest.mark.parametrize("n_partitions", [2]) +@pytest.mark.parametrize("n_partitions", [1, 2, 4]) def test_unpivot(make_df, n_partitions, with_morsel_size): df = make_df( {