From c36af6a91b6cc4a5313409e727387b6209991db0 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 24 Oct 2024 11:28:11 -0700 Subject: [PATCH 01/10] parallel buffered sinks, pivot fix, morsel size 1 for tests --- src/daft-local-execution/src/buffer.rs | 97 +++++++++ .../src/intermediate_ops/aggregate.rs | 43 ---- .../src/intermediate_ops/buffer.rs | 91 -------- .../src/intermediate_ops/intermediate_op.rs | 40 ++-- .../src/intermediate_ops/mod.rs | 2 - .../src/intermediate_ops/pivot.rs | 5 + src/daft-local-execution/src/lib.rs | 16 +- src/daft-local-execution/src/pipeline.rs | 81 ++----- .../src/sinks/aggregate.rs | 147 +++++++++---- .../src/sinks/blocking_sink.rs | 198 +++++++++++++++--- src/daft-local-execution/src/sinks/concat.rs | 5 + .../src/sinks/hash_join_build.rs | 66 ++++-- src/daft-local-execution/src/sinks/limit.rs | 5 + src/daft-local-execution/src/sinks/sort.rs | 118 +++++++---- .../src/sinks/streaming_sink.rs | 25 ++- tests/cookbook/conftest.py | 6 + tests/dataframe/test_aggregations.py | 14 +- tests/dataframe/test_approx_count_distinct.py | 7 + .../test_approx_percentiles_aggregations.py | 7 + tests/dataframe/test_concat.py | 8 + tests/dataframe/test_distinct.py | 7 + tests/dataframe/test_iter.py | 74 ++++--- tests/dataframe/test_joins.py | 14 +- tests/dataframe/test_map_groups.py | 6 + tests/dataframe/test_pivot.py | 8 + tests/dataframe/test_sort.py | 7 + tests/dataframe/test_stddev.py | 6 + tests/dataframe/test_unpivot.py | 7 + 28 files changed, 716 insertions(+), 394 deletions(-) create mode 100644 src/daft-local-execution/src/buffer.rs delete mode 100644 src/daft-local-execution/src/intermediate_ops/aggregate.rs delete mode 100644 src/daft-local-execution/src/intermediate_ops/buffer.rs diff --git a/src/daft-local-execution/src/buffer.rs b/src/daft-local-execution/src/buffer.rs new file mode 100644 index 0000000000..5b3cc60f3b --- /dev/null +++ b/src/daft-local-execution/src/buffer.rs @@ -0,0 +1,97 @@ +use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc}; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; + +// A buffer that accumulates morsels until a threshold is reached +pub struct RowBasedBuffer { + pub buffer: VecDeque>, + pub curr_len: usize, + pub threshold: usize, +} + +impl RowBasedBuffer { + pub fn new(threshold: usize) -> Self { + assert!(threshold > 0); + Self { + buffer: VecDeque::new(), + curr_len: 0, + threshold, + } + } + + // Push a morsel to the buffer + pub fn push(&mut self, part: Arc) { + self.curr_len += part.len(); + self.buffer.push_back(part); + } + + // Pop enough morsels that reach the threshold + // - If the buffer currently has not enough morsels, return None + // - If the buffer has exactly enough morsels, return the morsels + // - If the buffer has more than enough morsels, return a vec of morsels, each correctly sized to the threshold. 
+ // The remaining morsels will be pushed back to the buffer + pub fn pop_enough(&mut self) -> DaftResult>>> { + match self.curr_len.cmp(&self.threshold) { + Less => Ok(None), + Equal => { + if self.buffer.len() == 1 { + let part = self.buffer.pop_front().unwrap(); + self.curr_len = 0; + Ok(Some(vec![part])) + } else { + let chunk = MicroPartition::concat( + &std::mem::take(&mut self.buffer) + .iter() + .map(|x| x.as_ref()) + .collect::>(), + )?; + self.curr_len = 0; + Ok(Some(vec![Arc::new(chunk)])) + } + } + Greater => { + let num_ready_chunks = self.curr_len / self.threshold; + let concated = MicroPartition::concat( + &std::mem::take(&mut self.buffer) + .iter() + .map(|x| x.as_ref()) + .collect::>(), + )?; + let mut start = 0; + let mut parts_to_return = Vec::with_capacity(num_ready_chunks); + for _ in 0..num_ready_chunks { + let end = start + self.threshold; + let part = Arc::new(concated.slice(start, end)?); + parts_to_return.push(part); + start = end; + } + if start < concated.len() { + let part = Arc::new(concated.slice(start, concated.len())?); + self.curr_len = part.len(); + self.buffer.push_back(part); + } else { + self.curr_len = 0; + } + Ok(Some(parts_to_return)) + } + } + } + + // Pop all morsels in the buffer regardless of the threshold + pub fn pop_all(&mut self) -> DaftResult>> { + assert!(self.curr_len < self.threshold); + if self.buffer.is_empty() { + Ok(None) + } else { + let concated = MicroPartition::concat( + &std::mem::take(&mut self.buffer) + .iter() + .map(|x| x.as_ref()) + .collect::>(), + )?; + self.curr_len = 0; + Ok(Some(Arc::new(concated))) + } + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs deleted file mode 100644 index 4b8fa7bbb6..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ /dev/null @@ -1,43 +0,0 @@ -use std::sync::Arc; - -use common_error::DaftResult; -use daft_dsl::ExprRef; -use tracing::instrument; - -use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, -}; -use crate::pipeline::PipelineResultType; - -pub struct AggregateOperator { - agg_exprs: Vec, - group_by: Vec, -} - -impl AggregateOperator { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { - Self { - agg_exprs, - group_by, - } - } -} - -impl IntermediateOperator for AggregateOperator { - #[instrument(skip_all, name = "AggregateOperator::execute")] - fn execute( - &self, - _idx: usize, - input: &PipelineResultType, - _state: &IntermediateOperatorState, - ) -> DaftResult { - let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, - )))) - } - - fn name(&self) -> &'static str { - "AggregateOperator" - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs deleted file mode 100644 index 3c66301610..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/buffer.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::{ - cmp::Ordering::{Equal, Greater, Less}, - collections::VecDeque, - sync::Arc, -}; - -use common_error::DaftResult; -use daft_micropartition::MicroPartition; - -pub struct OperatorBuffer { - pub buffer: VecDeque>, - pub curr_len: usize, - pub threshold: usize, -} - -impl OperatorBuffer { - pub fn new(threshold: usize) -> Self { - assert!(threshold > 0); - Self { - buffer: VecDeque::new(), - curr_len: 0, - threshold, - } - } - - pub fn push(&mut self, part: 
Arc) { - self.curr_len += part.len(); - self.buffer.push_back(part); - } - - pub fn try_clear(&mut self) -> Option>> { - match self.curr_len.cmp(&self.threshold) { - Less => None, - Equal => self.clear_all(), - Greater => Some(self.clear_enough()), - } - } - - fn clear_enough(&mut self) -> DaftResult> { - assert!(self.curr_len > self.threshold); - - let mut to_concat = Vec::with_capacity(self.buffer.len()); - let mut remaining = self.threshold; - - while remaining > 0 { - let part = self.buffer.pop_front().expect("Buffer should not be empty"); - let part_len = part.len(); - if part_len <= remaining { - remaining -= part_len; - to_concat.push(part); - } else { - let (head, tail) = part.split_at(remaining)?; - remaining = 0; - to_concat.push(Arc::new(head)); - self.buffer.push_front(Arc::new(tail)); - break; - } - } - assert_eq!(remaining, 0); - - self.curr_len -= self.threshold; - match to_concat.len() { - 1 => Ok(to_concat.pop().unwrap()), - _ => MicroPartition::concat( - &to_concat - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - ) - .map(Arc::new), - } - } - - pub fn clear_all(&mut self) -> Option>> { - if self.buffer.is_empty() { - return None; - } - - let concated = MicroPartition::concat( - &self - .buffer - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - ) - .map(Arc::new); - self.buffer.clear(); - self.curr_len = 0; - Some(concated) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 412d7641a7..b53cfdbd34 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -1,13 +1,14 @@ use std::sync::{Arc, Mutex}; +use common_daft_config::DaftExecutionConfig; use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; -use super::buffer::OperatorBuffer; use crate::{ + buffer::RowBasedBuffer, channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, @@ -66,6 +67,9 @@ pub trait IntermediateOperator: Send + Sync { fn make_state(&self) -> Box { Box::new(DefaultIntermediateOperatorState {}) } + fn morsel_size(&self) -> Option { + Some(DaftExecutionConfig::default().default_morsel_size) + } } pub struct IntermediateNode { @@ -165,8 +169,9 @@ impl IntermediateNode { pub async fn send_to_workers( receivers: Vec, worker_senders: Vec>, - morsel_size: usize, + morsel_size: Option, ) -> DaftResult<()> { + println!("morsel_size: {:?}", morsel_size); let mut next_worker_idx = 0; let mut send_to_next_worker = |idx, data: PipelineResultType| { let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); @@ -175,26 +180,28 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = OperatorBuffer::new(morsel_size); + let mut buffer = morsel_size.map(RowBasedBuffer::new); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { let _ = worker_sender.send((idx, morsel.clone())).await; } - } else { + } else if let Some(buffer) = buffer.as_mut() { buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; + if let Some(ready) = 
buffer.pop_enough()? { + for r in ready { + let _ = send_to_next_worker(idx, r.into()).await; + } } + } else { + let _ = send_to_next_worker(idx, morsel).await; } } - // Buffer may still have some morsels left above the threshold - while let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; - } - // Clear all remaining morsels - if let Some(last_morsel) = buffer.clear_all() { - let _ = send_to_next_worker(idx, last_morsel?.into()).await; + if let Some(buffer) = buffer.as_mut() { + // Clear all remaining morsels + if let Some(last_morsel) = buffer.pop_all()? { + let _ = send_to_next_worker(idx, last_morsel.into()).await; + } } } Ok(()) @@ -247,12 +254,9 @@ impl PipelineNode for IntermediateNode { let worker_senders = self.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle); + let morsel_size = runtime_handle.determine_morsel_size(self.intermediate_op.morsel_size()); runtime_handle.spawn( - Self::send_to_workers( - child_result_receivers, - worker_senders, - runtime_handle.default_morsel_size(), - ), + Self::send_to_workers(child_result_receivers, worker_senders, morsel_size), self.intermediate_op.name(), ); Ok(destination_channel) diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 7d97464e24..c7598c8e7a 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,6 +1,4 @@ -pub mod aggregate; pub mod anti_semi_hash_join_probe; -pub mod buffer; pub mod explode; pub mod filter; pub mod inner_hash_join_probe; diff --git a/src/daft-local-execution/src/intermediate_ops/pivot.rs b/src/daft-local-execution/src/intermediate_ops/pivot.rs index afac5f9b02..c54ea3e69b 100644 --- a/src/daft-local-execution/src/intermediate_ops/pivot.rs +++ b/src/daft-local-execution/src/intermediate_ops/pivot.rs @@ -54,4 +54,9 @@ impl IntermediateOperator for PivotOperator { fn name(&self) -> &'static str { "PivotOperator" } + + // Don't buffer the input to pivot because it depends on the full output of agg. 
+ fn morsel_size(&self) -> Option { + None + } } diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index e9b7e08b96..399363c1fd 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,4 +1,5 @@ #![feature(let_chains)] +mod buffer; mod channel; mod intermediate_ops; mod pipeline; @@ -7,6 +8,7 @@ mod runtime_stats; mod sinks; mod sources; +use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; use lazy_static::lazy_static; pub use run::NativeExecutor; @@ -74,9 +76,17 @@ impl ExecutionRuntimeHandle { self.worker_set.shutdown().await; } - #[must_use] - pub fn default_morsel_size(&self) -> usize { - self.default_morsel_size + pub fn determine_morsel_size(&self, operator_morsel_size: Option) -> Option { + match operator_morsel_size { + None => None, + Some(_) + if self.default_morsel_size + != DaftExecutionConfig::default().default_morsel_size => + { + Some(self.default_morsel_size) + } + size => size, + } } } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index eccece1a56..fc95182a65 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -7,13 +7,13 @@ use daft_core::{ prelude::{Schema, SchemaRef}, utils::supertype, }; -use daft_dsl::{col, join::get_common_join_keys, Expr}; +use daft_dsl::join::get_common_join_keys; use daft_micropartition::MicroPartition; use daft_physical_plan::{ Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; -use daft_plan::{populate_aggregation_stages, JoinType}; +use daft_plan::JoinType; use daft_table::ProbeState; use indexmap::IndexSet; use snafu::ResultExt; @@ -21,11 +21,10 @@ use snafu::ResultExt; use crate::{ channel::PipelineChannel, intermediate_ops::{ - aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, - explode::ExplodeOperator, filter::FilterOperator, - inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, - pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator, - unpivot::UnpivotOperator, + anti_semi_hash_join_probe::AntiSemiProbeOperator, explode::ExplodeOperator, + filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, + intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator, + sample::SampleOperator, unpivot::UnpivotOperator, }, sinks::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, @@ -172,34 +171,9 @@ pub fn physical_plan_to_pipeline( schema, .. 
}) => { - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, schema, &[]); - let first_stage_agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); let child_node = physical_plan_to_pipeline(input, psets)?; - let post_first_agg_node = - IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); - let second_stage_node = - BlockingSinkNode::new(second_stage_agg_sink.boxed(), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + let agg_sink = AggregateSink::new(aggregations, &[], schema); + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::HashAggregate(HashAggregate { input, @@ -208,40 +182,9 @@ pub fn physical_plan_to_pipeline( schema, .. }) => { - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, schema, group_by); let child_node = physical_plan_to_pipeline(input, psets)?; - let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { - let agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - ( - IntermediateNode::new(Arc::new(agg_op), vec![child_node]).boxed(), - &group_by.iter().map(|e| col(e.name())).collect(), - ) - } else { - (child_node, group_by) - }; - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - let second_stage_node = - BlockingSinkNode::new(second_stage_agg_sink.boxed(), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + let agg_sink = AggregateSink::new(aggregations, group_by, schema); + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::Unpivot(Unpivot { input, @@ -285,7 +228,7 @@ pub fn physical_plan_to_pipeline( }) => { let sort_sink = SortSink::new(sort_by.clone(), descending.clone()); let child_node = physical_plan_to_pipeline(input, psets)?; - BlockingSinkNode::new(sort_sink.boxed(), child_node).boxed() + BlockingSinkNode::new(Arc::new(sort_sink), child_node).boxed() } LocalPhysicalPlan::HashJoin(HashJoin { @@ -356,7 +299,7 @@ pub fn physical_plan_to_pipeline( let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; let build_child_node = physical_plan_to_pipeline(build_child, psets)?; let build_node = - BlockingSinkNode::new(build_sink.boxed(), build_child_node).boxed(); + BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index e94ff7c68b..8c38c81a4f 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -1,71 +1,140 @@ use std::sync::Arc; use common_error::DaftResult; -use daft_dsl::ExprRef; +use daft_core::prelude::SchemaRef; +use daft_dsl::{col, AggExpr, Expr, 
ExprRef}; use daft_micropartition::MicroPartition; +use daft_plan::populate_aggregation_stages; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; +use super::blocking_sink::{ + BlockingSink, BlockingSinkState, BlockingSinkStatus, DynBlockingSinkState, +}; +use crate::{pipeline::PipelineResultType, NUM_CPUS}; enum AggregateState { Accumulating(Vec>), - #[allow(dead_code)] - Done(Arc), + Done, +} + +impl AggregateState { + fn push(&mut self, part: Arc) { + if let Self::Accumulating(ref mut parts) = self { + parts.push(part); + } else { + panic!("AggregateSink should be in Accumulating state"); + } + } + + fn finalize(&mut self) -> DaftResult>> { + let res = if let Self::Accumulating(ref mut parts) = self { + std::mem::take(parts) + } else { + panic!("AggregateSink should be in Accumulating state"); + }; + *self = Self::Done; + Ok(res) + } +} + +impl DynBlockingSinkState for AggregateState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } } pub struct AggregateSink { - agg_exprs: Vec, - group_by: Vec, - state: AggregateState, + sink_aggs: Vec, + finalize_aggs: Vec, + final_projections: Vec, + sink_group_by: Vec, + finalize_group_by: Vec, } impl AggregateSink { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { + pub fn new(agg_exprs: &[AggExpr], group_by: &[ExprRef], schema: &SchemaRef) -> Self { + let (sink_aggs, finalize_aggs, final_projections) = + populate_aggregation_stages(agg_exprs, schema, group_by); + let sink_aggs = sink_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let finalize_aggs = finalize_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let finalize_group_by = if sink_aggs.is_empty() { + group_by.to_vec() + } else { + group_by.iter().map(|e| col(e.name())).collect() + }; Self { - agg_exprs, - group_by, - state: AggregateState::Accumulating(vec![]), + sink_aggs, + finalize_aggs, + final_projections, + sink_group_by: group_by.to_vec(), + finalize_group_by, } } - - pub fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for AggregateSink { #[instrument(skip_all, name = "AggregateSink::sink")] - fn sink(&mut self, input: &Arc) -> DaftResult { - if let AggregateState::Accumulating(parts) = &mut self.state { - parts.push(input.clone()); + fn sink( + &self, + input: &Arc, + state_handle: &BlockingSinkState, + ) -> DaftResult { + state_handle.with_state_mut::(|state| { + if self.sink_aggs.is_empty() { + state.push(input.clone()); + } else { + let agged = input.agg(&self.sink_aggs, &self.sink_group_by)?; + state.push(agged.into()); + } Ok(BlockingSinkStatus::NeedMoreInput) - } else { - panic!("AggregateSink should be in Accumulating state"); - } + }) } #[instrument(skip_all, name = "AggregateSink::finalize")] - fn finalize(&mut self) -> DaftResult> { - if let AggregateState::Accumulating(parts) = &mut self.state { - assert!( - !parts.is_empty(), - "We can not finalize AggregateSink with no data" - ); - let concated = MicroPartition::concat( - &parts - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - )?; - let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); - self.state = AggregateState::Done(agged.clone()); - Ok(Some(agged.into())) - } else { - panic!("AggregateSink should be in Accumulating state"); + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let mut all_parts = vec![]; + for mut state in states { + let state = state + .as_any_mut() + .downcast_mut::() 
+ .expect("State type mismatch"); + all_parts.extend(state.finalize()?); } + assert!( + !all_parts.is_empty(), + "We can not finalize AggregateSink with no data" + ); + let concated = MicroPartition::concat( + &all_parts + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(), + )?; + let agged = Arc::new(concated.agg(&self.finalize_aggs, &self.finalize_group_by)?); + let projected = Arc::new(agged.eval_expression_list(&self.final_projections)?); + Ok(Some(projected.into())) } + fn name(&self) -> &'static str { "AggregateSink" } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(AggregateState::Accumulating(vec![]))) + } } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 3fcbf8d660..80438c8173 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -1,17 +1,48 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; +use common_daft_config::DaftExecutionConfig; use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use tracing::info_span; +use snafu::ResultExt; +use tracing::{info_span, instrument}; use crate::{ - channel::PipelineChannel, + buffer::RowBasedBuffer, + channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, TaskSet, }; +pub trait DynBlockingSinkState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + +pub(crate) struct BlockingSinkState { + inner: Mutex>, +} + +impl BlockingSinkState { + fn new(inner: Box) -> Arc { + Arc::new(Self { + inner: Mutex::new(inner), + }) + } + + pub(crate) fn with_state_mut(&self, f: F) -> R + where + F: FnOnce(&mut T) -> R, + { + let mut guard = self.inner.lock().unwrap(); + let state = guard + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + f(state) + } +} + pub enum BlockingSinkStatus { NeedMoreInput, #[allow(dead_code)] @@ -19,24 +50,35 @@ pub enum BlockingSinkStatus { } pub trait BlockingSink: Send + Sync { - fn sink(&mut self, input: &Arc) -> DaftResult; - fn finalize(&mut self) -> DaftResult>; + fn sink( + &self, + input: &Arc, + state_handle: &BlockingSinkState, + ) -> DaftResult; + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>; fn name(&self) -> &'static str; + fn make_state(&self) -> DaftResult>; + fn max_concurrency(&self) -> usize; + fn morsel_size(&self) -> Option { + Some(DaftExecutionConfig::default().default_morsel_size) + } } pub struct BlockingSinkNode { - // use a RW lock - op: Arc>>, + op: Arc, name: &'static str, child: Box, runtime_stats: Arc, } impl BlockingSinkNode { - pub(crate) fn new(op: Box, child: Box) -> Self { + pub(crate) fn new(op: Arc, child: Box) -> Self { let name = op.name(); Self { - op: Arc::new(tokio::sync::Mutex::new(op)), + op, name, child, runtime_stats: RuntimeStatsContext::new(), @@ -45,6 +87,92 @@ impl BlockingSinkNode { pub(crate) fn boxed(self) -> Box { Box::new(self) } + + #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] + async fn run_worker( + op: Arc, + mut input_receiver: Receiver, + rt_context: Arc, + ) -> DaftResult> { + let span = info_span!("BlockingSink::Sink"); + let compute_runtime = 
get_compute_runtime(); + let state_wrapper = BlockingSinkState::new(op.make_state()?); + while let Some(morsel) = input_receiver.recv().await { + let op = op.clone(); + let morsel = morsel.clone(); + let span = span.clone(); + let rt_context = rt_context.clone(); + let state_wrapper = state_wrapper.clone(); + let fut = async move { + rt_context.in_span(&span, || op.sink(morsel.as_data(), &state_wrapper)) + }; + let result = compute_runtime.await_on(fut).await??; + match result { + BlockingSinkStatus::NeedMoreInput => {} + BlockingSinkStatus::Finished => { + break; + } + } + } + + // Take the state out of the Arc and Mutex because we need to return it. + // It should be guaranteed that the ONLY holder of state at this point is this function. + Ok(Arc::into_inner(state_wrapper) + .expect("Completed worker should have exclusive access to state wrapper") + .inner + .into_inner() + .expect("Completed worker should have exclusive access to inner state")) + } + + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + task_set: &mut TaskSet>>, + stats: Arc, + ) { + for input_receiver in input_receivers { + task_set.spawn(Self::run_worker(op.clone(), input_receiver, stats.clone())); + } + } + + // Forwards input from the child to the workers in a round-robin fashion. + pub async fn forward_input_to_workers( + mut receiver: CountingReceiver, + worker_senders: Vec>, + morsel_size: Option, + ) -> DaftResult<()> { + let mut next_worker_idx = 0; + let mut send_to_next_worker = |data: PipelineResultType| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send(data) + }; + + let mut buffer = morsel_size.map(RowBasedBuffer::new); + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in &worker_senders { + let _ = worker_sender.send(morsel.clone()).await; + } + } else if let Some(buffer) = buffer.as_mut() { + buffer.push(morsel.as_data().clone()); + if let Some(ready) = buffer.pop_enough()? { + for r in ready { + let _ = send_to_next_worker(r.into()).await; + } + } + } else { + let _ = send_to_next_worker(morsel).await; + } + } + if let Some(buffer) = buffer.as_mut() { + // Clear all remaining morsels + if let Some(last_morsel) = buffer.pop_all()? { + let _ = send_to_next_worker(last_morsel.into()).await; + } + } + Ok(()) + } } impl TreeDisplay for BlockingSinkNode { @@ -81,7 +209,7 @@ impl PipelineNode for BlockingSinkNode { runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let child = self.child.as_mut(); - let mut child_results_receiver = child + let child_results_receiver = child .start(false, runtime_handle)? 
.get_receiver_with_stats(&self.runtime_stats); @@ -89,34 +217,40 @@ impl PipelineNode for BlockingSinkNode { let destination_sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); - let rt_context = self.runtime_stats.clone(); + let runtime_stats = self.runtime_stats.clone(); + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + let morsel_size = runtime_handle.determine_morsel_size(op.morsel_size()); + runtime_handle.spawn( + Self::forward_input_to_workers(child_results_receiver, input_senders, morsel_size), + self.name(), + ); runtime_handle.spawn( async move { - let span = info_span!("BlockingSinkNode::execute"); - let compute_runtime = get_compute_runtime(); - while let Some(val) = child_results_receiver.recv().await { - let op = op.clone(); - let span = span.clone(); - let rt_context = rt_context.clone(); - let fut = async move { - let mut guard = op.lock().await; - rt_context.in_span(&span, || guard.sink(val.as_data())) - }; - let result = compute_runtime.await_on(fut).await??; - if matches!(result, BlockingSinkStatus::Finished) { - break; - } + let mut task_set = TaskSet::new(); + Self::spawn_workers( + op.clone(), + input_receivers, + &mut task_set, + runtime_stats.clone(), + ); + + let mut finished_states = Vec::with_capacity(num_workers); + while let Some(result) = task_set.join_next().await { + let state = result.context(JoinSnafu)??; + finished_states.push(state); } + + let compute_runtime = get_compute_runtime(); let finalized_result = compute_runtime .await_on(async move { - let mut guard = op.lock().await; - rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || { - guard.finalize() + runtime_stats.in_span(&info_span!("BlockingSinkNode::finalize"), || { + op.finalize(finished_states) }) }) .await??; - if let Some(part) = finalized_result { - let _ = destination_sender.send(part).await; + if let Some(res) = finalized_result { + let _ = destination_sender.send(res).await; } Ok(()) }, diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 3fc710c691..a178d287e1 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -64,4 +64,9 @@ impl StreamingSink for ConcatSink { fn max_concurrency(&self) -> usize { 1 } + + /// The ConcatSink does not do any computation in the sink method, so no need to buffer. 
+ fn morsel_size(&self) -> Option { + None + } } diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index c8258e281a..c8c584d268 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -7,7 +7,9 @@ use daft_micropartition::MicroPartition; use daft_plan::JoinType; use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; +use super::blocking_sink::{ + BlockingSink, BlockingSinkState, BlockingSinkStatus, DynBlockingSinkState, +}; use crate::pipeline::PipelineResultType; enum ProbeTableState { @@ -74,8 +76,16 @@ impl ProbeTableState { } } +impl DynBlockingSinkState for ProbeTableState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + pub struct HashJoinBuildSink { - probe_table_state: ProbeTableState, + key_schema: SchemaRef, + projection: Vec, + join_type: JoinType, } impl HashJoinBuildSink { @@ -85,13 +95,11 @@ impl HashJoinBuildSink { join_type: &JoinType, ) -> DaftResult { Ok(Self { - probe_table_state: ProbeTableState::new(&key_schema, projection, join_type)?, + key_schema, + projection, + join_type: *join_type, }) } - - pub(crate) fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for HashJoinBuildSink { @@ -99,17 +107,49 @@ impl BlockingSink for HashJoinBuildSink { "HashJoinBuildSink" } - fn sink(&mut self, input: &Arc) -> DaftResult { - self.probe_table_state.add_tables(input)?; - Ok(BlockingSinkStatus::NeedMoreInput) + fn sink( + &self, + input: &Arc, + state_handle: &BlockingSinkState, + ) -> DaftResult { + state_handle.with_state_mut::(|state| { + state.add_tables(input)?; + Ok(BlockingSinkStatus::NeedMoreInput) + }) } - fn finalize(&mut self) -> DaftResult> { - self.probe_table_state.finalize()?; - if let ProbeTableState::Done { probe_state } = &self.probe_table_state { + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + assert_eq!(states.len(), 1); + let mut state = states.into_iter().next().unwrap(); + let probe_table_state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + probe_table_state.finalize()?; + if let ProbeTableState::Done { probe_state } = probe_table_state { Ok(Some(probe_state.clone().into())) } else { panic!("finalize should only be called after the probe table is built") } } + + fn max_concurrency(&self) -> usize { + 1 + } + + // Hash join build is currently single threaded, so we don't need to buffer the input + fn morsel_size(&self) -> Option { + None + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(ProbeTableState::new( + &self.key_schema, + self.projection.clone(), + &self.join_type, + )?)) + } } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index ff24a703e2..e65a4afb1c 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -90,4 +90,9 @@ impl StreamingSink for LimitSink { fn max_concurrency(&self) -> usize { 1 } + + /// Limits are greedy and should consume all input data whenever possible. 
+ fn morsel_size(&self) -> Option { + None + } } diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 169ea9e55d..0e5a2a7e41 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -5,18 +5,44 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; -pub struct SortSink { - sort_by: Vec, - descending: Vec, - state: SortState, -} +use super::blocking_sink::{ + BlockingSink, BlockingSinkState, BlockingSinkStatus, DynBlockingSinkState, +}; +use crate::{pipeline::PipelineResultType, NUM_CPUS}; enum SortState { Building(Vec>), - #[allow(dead_code)] - Done(Arc), + Done, +} + +impl SortState { + fn push(&mut self, part: Arc) { + if let Self::Building(ref mut parts) = self { + parts.push(part); + } else { + panic!("SortSink should be in Building state"); + } + } + + fn finalize(&mut self) -> DaftResult>> { + let res = if let Self::Building(ref mut parts) = self { + std::mem::take(parts) + } else { + panic!("SortSink should be in Building state"); + }; + *self = Self::Done; + Ok(res) + } +} + +impl DynBlockingSinkState for SortState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} +pub struct SortSink { + sort_by: Vec, + descending: Vec, } impl SortSink { @@ -24,46 +50,64 @@ impl SortSink { Self { sort_by, descending, - state: SortState::Building(vec![]), } } - pub fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for SortSink { #[instrument(skip_all, name = "SortSink::sink")] - fn sink(&mut self, input: &Arc) -> DaftResult { - if let SortState::Building(parts) = &mut self.state { - parts.push(input.clone()); - } else { - panic!("SortSink should be in Building state"); - } - Ok(BlockingSinkStatus::NeedMoreInput) + fn sink( + &self, + input: &Arc, + state_handle: &BlockingSinkState, + ) -> DaftResult { + state_handle.with_state_mut::(|state| { + state.push(input.clone()); + Ok(BlockingSinkStatus::NeedMoreInput) + }) } #[instrument(skip_all, name = "SortSink::finalize")] - fn finalize(&mut self) -> DaftResult> { - if let SortState::Building(parts) = &mut self.state { - assert!( - !parts.is_empty(), - "We can not finalize SortSink with no data" - ); - let concated = MicroPartition::concat( - &parts - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - )?; - let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); - self.state = SortState::Done(sorted.clone()); - Ok(Some(sorted.into())) - } else { - panic!("SortSink should be in Building state"); + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let mut parts = Vec::new(); + for mut state in states { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + parts.extend(state.finalize()?); } + assert!( + !parts.is_empty(), + "We can not finalize SortSink with no data" + ); + let concated = MicroPartition::concat( + &parts + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(), + )?; + let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); + Ok(Some(sorted.into())) } + fn name(&self) -> &'static str { "SortResult" } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(SortState::Building(Vec::new()))) + } + + // SortSink currently does not do any computation in the sink method, so no need to buffer. 
+ fn morsel_size(&self) -> Option { + None + } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 102fd39618..112c6d5905 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -8,6 +8,7 @@ use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ + buffer::RowBasedBuffer, channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, @@ -77,6 +78,10 @@ pub trait StreamingSink: Send + Sync { fn max_concurrency(&self) -> usize { *NUM_CPUS } + + fn morsel_size(&self) -> Option { + Some(*NUM_CPUS) + } } pub struct StreamingSinkNode { @@ -179,6 +184,7 @@ impl StreamingSinkNode { async fn forward_input_to_workers( receivers: Vec, worker_senders: Vec>, + morsel_size: Option, ) -> DaftResult<()> { let mut next_worker_idx = 0; let mut send_to_next_worker = |idx, data: PipelineResultType| { @@ -188,13 +194,27 @@ impl StreamingSinkNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { + let mut buffer = morsel_size.map(RowBasedBuffer::new); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { let _ = worker_sender.send((idx, morsel.clone())).await; } + } else if let Some(buffer) = buffer.as_mut() { + buffer.push(morsel.as_data().clone()); + if let Some(ready) = buffer.pop_enough()? { + for r in ready { + let _ = send_to_next_worker(idx, r.into()).await; + } + } } else { - let _ = send_to_next_worker(idx, morsel.clone()).await; + let _ = send_to_next_worker(idx, morsel).await; + } + } + if let Some(buffer) = buffer.as_mut() { + // Clear all remaining morsels + if let Some(last_morsel) = buffer.pop_all()? 
{ + let _ = send_to_next_worker(idx, last_morsel.into()).await; } } } @@ -255,8 +275,9 @@ impl PipelineNode for StreamingSinkNode { let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + let morsel_size = runtime_handle.determine_morsel_size(op.morsel_size()); runtime_handle.spawn( - Self::forward_input_to_workers(child_result_receivers, input_senders), + Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size), self.name(), ); runtime_handle.spawn( diff --git a/tests/cookbook/conftest.py b/tests/cookbook/conftest.py index c04f7ab208..bbc5ea7014 100644 --- a/tests/cookbook/conftest.py +++ b/tests/cookbook/conftest.py @@ -48,3 +48,9 @@ def repartition_nparts(request): partitions that the test case should repartition its dataset into for testing """ return request.param + + +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py index 74fe889ce0..6f5fe7900f 100644 --- a/tests/dataframe/test_aggregations.py +++ b/tests/dataframe/test_aggregations.py @@ -15,6 +15,12 @@ from tests.utils import sort_arrow_table +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) def test_agg_global(make_df, repartition_nparts): daft_df = make_df( @@ -360,6 +366,9 @@ def test_agg_groupby_with_alias(make_df, repartition_nparts): class CustomObject: val: int + def __hash__(self): + return hash(self.val) + def test_agg_pyobjects(): objects = [CustomObject(val=0), None, CustomObject(val=1)] @@ -375,7 +384,7 @@ def test_agg_pyobjects(): res = df.to_pydict() assert res["count"] == [2] - assert res["list"] == [objects] + assert set(res["list"][0]) == set(objects) def test_groupby_agg_pyobjects(): @@ -397,7 +406,8 @@ def test_groupby_agg_pyobjects(): res = df.to_pydict() assert res["groups"] == [1, 2] assert res["count"] == [2, 1] - assert res["list"] == [[objects[0], objects[2], objects[4]], [objects[1], objects[3]]] + assert set(res["list"][0]) == set([objects[0], objects[2], objects[4]]) + assert set(res["list"][1]) == set([objects[1], objects[3]]) @pytest.mark.parametrize("shuffle_aggregation_default_partitions", [None, 20]) diff --git a/tests/dataframe/test_approx_count_distinct.py b/tests/dataframe/test_approx_count_distinct.py index 68d7057ca0..43cdab9f15 100644 --- a/tests/dataframe/test_approx_count_distinct.py +++ b/tests/dataframe/test_approx_count_distinct.py @@ -4,6 +4,13 @@ import daft from daft import col + +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + TESTS = [ [[], 0], [[None] * 10, 0], diff --git a/tests/dataframe/test_approx_percentiles_aggregations.py b/tests/dataframe/test_approx_percentiles_aggregations.py index d64f1a2381..d993f428cf 100644 --- a/tests/dataframe/test_approx_percentiles_aggregations.py +++ b/tests/dataframe/test_approx_percentiles_aggregations.py @@ -4,9 +4,16 @@ import pyarrow as pa import pytest +import daft from daft import col +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + 
yield + + @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) @pytest.mark.parametrize("percentiles_expected", [(0.5, [2.0]), ([0.5], [[2.0]]), ([0.5, 0.5], [[2.0, 2.0]])]) def test_approx_percentiles_global(make_df, repartition_nparts, percentiles_expected): diff --git a/tests/dataframe/test_concat.py b/tests/dataframe/test_concat.py index f3caf56bb1..a21081fe84 100644 --- a/tests/dataframe/test_concat.py +++ b/tests/dataframe/test_concat.py @@ -2,6 +2,14 @@ import pytest +import daft + + +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + def test_simple_concat(make_df): df1 = make_df({"foo": [1, 2, 3]}) diff --git a/tests/dataframe/test_distinct.py b/tests/dataframe/test_distinct.py index 8e4b2c0a85..4b7e8db31c 100644 --- a/tests/dataframe/test_distinct.py +++ b/tests/dataframe/test_distinct.py @@ -3,10 +3,17 @@ import pyarrow as pa import pytest +import daft from daft.datatype import DataType from tests.utils import sort_arrow_table +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + @pytest.mark.parametrize("repartition_nparts", [1, 2, 5]) def test_distinct_with_nulls(make_df, repartition_nparts): daft_df = make_df( diff --git a/tests/dataframe/test_iter.py b/tests/dataframe/test_iter.py index 8658e5da30..c0604568a3 100644 --- a/tests/dataframe/test_iter.py +++ b/tests/dataframe/test_iter.py @@ -3,12 +3,6 @@ import pytest import daft -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) class MockException(Exception): @@ -33,25 +27,26 @@ def test_iter_partitions(make_df, materialized): # Test that df.iter_partitions() produces partitions in the correct order. # It should work regardless of whether the dataframe has already been materialized or not. 
- df = make_df({"a": list(range(10))}).into_partitions(5).with_column("b", daft.col("a") + 100) + with daft.execution_config_ctx(default_morsel_size=2): + df = make_df({"a": list(range(10))}).into_partitions(5).with_column("b", daft.col("a") + 100) - if materialized: - df = df.collect() + if materialized: + df = df.collect() - parts = list(df.iter_partitions()) - if daft.context.get_context().runner_config.name == "ray": - import ray + parts = list(df.iter_partitions()) + if daft.context.get_context().runner_config.name == "ray": + import ray - parts = ray.get(parts) - parts = [_.to_pydict() for _ in parts] + parts = ray.get(parts) + parts = [_.to_pydict() for _ in parts] - assert parts == [ - {"a": [0, 1], "b": [100, 101]}, - {"a": [2, 3], "b": [102, 103]}, - {"a": [4, 5], "b": [104, 105]}, - {"a": [6, 7], "b": [106, 107]}, - {"a": [8, 9], "b": [108, 109]}, - ] + assert parts == [ + {"a": [0, 1], "b": [100, 101]}, + {"a": [2, 3], "b": [102, 103]}, + {"a": [4, 5], "b": [104, 105]}, + {"a": [6, 7], "b": [106, 107]}, + {"a": [8, 9], "b": [108, 109]}, + ] def test_iter_exception(make_df): @@ -94,26 +89,27 @@ def echo_or_trigger(s): else: return s - df = make_df({"a": list(range(200))}).into_partitions(100).with_column("b", echo_or_trigger(daft.col("a"))) + with daft.execution_config_ctx(default_morsel_size=2): + df = make_df({"a": list(range(200))}).into_partitions(100).with_column("b", echo_or_trigger(daft.col("a"))) - it = df.iter_partitions() - part = next(it) - if daft.context.get_context().runner_config.name == "ray": - import ray + it = df.iter_partitions() + part = next(it) + if daft.context.get_context().runner_config.name == "ray": + import ray - part = ray.get(part) - part = part.to_pydict() + part = ray.get(part) + part = part.to_pydict() - assert part == {"a": [0, 1], "b": [0, 1]} + assert part == {"a": [0, 1], "b": [0, 1]} - # Ensure the exception does trigger if execution continues. - with pytest.raises(RuntimeError) as exc_info: - res = list(it) - if daft.context.get_context().runner_config.name == "ray": - ray.get(res) + # Ensure the exception does trigger if execution continues. 
+ with pytest.raises(RuntimeError) as exc_info: + res = list(it) + if daft.context.get_context().runner_config.name == "ray": + ray.get(res) - # Ray's wrapping of the exception loses information about the `.cause`, but preserves it in the string error message - if daft.context.get_context().runner_config.name == "ray": - assert "MockException" in str(exc_info.value) - else: - assert isinstance(exc_info.value.__cause__, MockException) + # Ray's wrapping of the exception loses information about the `.cause`, but preserves it in the string error message + if daft.context.get_context().runner_config.name == "ray": + assert "MockException" in str(exc_info.value) + else: + assert isinstance(exc_info.value.__cause__, MockException) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 4b08abea61..eb131ac2ce 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -10,6 +10,12 @@ from tests.utils import sort_arrow_table +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with context.execution_config_ctx(default_morsel_size=1): + yield + + def skip_invalid_join_strategies(join_strategy, join_type): if context.get_context().daft_execution_config.enable_native_executor is True: if join_strategy not in [None, "hash"]: @@ -151,7 +157,7 @@ def test_dupes_join_key(join_strategy, join_type, make_df, n_partitions: int): ) joined = df.join(df, on="A", strategy=join_strategy, how=join_type) - joined = joined.sort(["A", "B"]) + joined = joined.sort(["A", "B", "right.B"]) joined_data = joined.to_pydict() assert joined_data == { @@ -182,14 +188,14 @@ def test_multicol_dupes_join_key(join_strategy, join_type, make_df, n_partitions ) joined = df.join(df, on=["A", "B"], strategy=join_strategy, how=join_type) - joined = joined.sort(["A", "B", "C"]) + joined = joined.sort(["A", "B", "C", "right.C"]) joined_data = joined.to_pydict() assert joined_data == { "A": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3], "B": ["a"] * 4 + ["b"] * 4 + ["c", "d"], "C": [0, 0, 1, 1, 0, 0, 1, 1, 1, 0], - "right.C": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0], + "right.C": [0, 1, 0, 1, 0, 1, 0, 1, 1, 0], } @@ -213,7 +219,7 @@ def test_joins_all_same_key(join_strategy, join_type, make_df, n_partitions: int ) joined = df.join(df, on="A", strategy=join_strategy, how=join_type) - joined = joined.sort(["A", "B"]) + joined = joined.sort(["A", "B", "right.B"]) joined_data = joined.to_pydict() assert joined_data == { diff --git a/tests/dataframe/test_map_groups.py b/tests/dataframe/test_map_groups.py index 4f0f2e29ec..76ec443ec3 100644 --- a/tests/dataframe/test_map_groups.py +++ b/tests/dataframe/test_map_groups.py @@ -5,6 +5,12 @@ import daft +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) def test_map_groups(make_df, repartition_nparts): daft_df = make_df( diff --git a/tests/dataframe/test_pivot.py b/tests/dataframe/test_pivot.py index fcd88c9c51..6fe7989adc 100644 --- a/tests/dataframe/test_pivot.py +++ b/tests/dataframe/test_pivot.py @@ -1,5 +1,13 @@ import pytest +import daft + + +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + @pytest.mark.parametrize("repartition_nparts", [1, 2, 5]) def test_pivot(make_df, repartition_nparts): diff --git a/tests/dataframe/test_sort.py b/tests/dataframe/test_sort.py index 
8a831a2bcf..9ffd6497c6 100644 --- a/tests/dataframe/test_sort.py +++ b/tests/dataframe/test_sort.py @@ -5,6 +5,7 @@ import pyarrow as pa import pytest +import daft from daft.datatype import DataType from daft.errors import ExpressionTypeError @@ -13,6 +14,12 @@ ### +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + def test_disallowed_sort_null(make_df): df = make_df({"A": [None, None]}) diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py index 464d20bd41..c745263bf5 100644 --- a/tests/dataframe/test_stddev.py +++ b/tests/dataframe/test_stddev.py @@ -8,6 +8,12 @@ import daft +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + def grouped_stddev(rows) -> Tuple[List[Any], List[Any]]: map = {} for key, data in rows: diff --git a/tests/dataframe/test_unpivot.py b/tests/dataframe/test_unpivot.py index b4c7a84cc5..ef20c453cc 100644 --- a/tests/dataframe/test_unpivot.py +++ b/tests/dataframe/test_unpivot.py @@ -1,9 +1,16 @@ import pytest +import daft from daft import col from daft.datatype import DataType +@pytest.fixture(scope="function", autouse=True) +def set_default_morsel_size(): + with daft.context.execution_config_ctx(default_morsel_size=1): + yield + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) def test_unpivot(make_df, n_partitions): df = make_df( From f1f3c6a7eb3545e5d53efc129bfec30c97e8b99b Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 24 Oct 2024 12:34:15 -0700 Subject: [PATCH 02/10] fix test --- tests/dataframe/test_iter.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/dataframe/test_iter.py b/tests/dataframe/test_iter.py index c0604568a3..7d4ede270d 100644 --- a/tests/dataframe/test_iter.py +++ b/tests/dataframe/test_iter.py @@ -61,20 +61,21 @@ def echo_or_trigger(s): else: return s - df = make_df({"a": list(range(200))}).into_partitions(100).with_column("b", echo_or_trigger(daft.col("a"))) + with daft.execution_config_ctx(default_morsel_size=2): + df = make_df({"a": list(range(200))}).into_partitions(100).with_column("b", echo_or_trigger(daft.col("a"))) - it = iter(df) - assert next(it) == {"a": 0, "b": 0} + it = iter(df) + assert next(it) == {"a": 0, "b": 0} - # Ensure the exception does trigger if execution continues. - with pytest.raises(RuntimeError) as exc_info: - list(it) + # Ensure the exception does trigger if execution continues. 
+ with pytest.raises(RuntimeError) as exc_info: + list(it) - # Ray's wrapping of the exception loses information about the `.cause`, but preserves it in the string error message - if daft.context.get_context().runner_config.name == "ray": - assert "MockException" in str(exc_info.value) - else: - assert isinstance(exc_info.value.__cause__, MockException) + # Ray's wrapping of the exception loses information about the `.cause`, but preserves it in the string error message + if daft.context.get_context().runner_config.name == "ray": + assert "MockException" in str(exc_info.value) + else: + assert isinstance(exc_info.value.__cause__, MockException) def test_iter_partitions_exception(make_df): From 67b429ddcd221b59a391dec96bd615f4f59b4e76 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 24 Oct 2024 19:50:36 -0700 Subject: [PATCH 03/10] loole --- Cargo.lock | 11 ++ src/daft-local-execution/Cargo.toml | 1 + src/daft-local-execution/src/channel.rs | 107 +++++------------- .../src/intermediate_ops/intermediate_op.rs | 63 +++++++---- src/daft-local-execution/src/pipeline.rs | 4 +- src/daft-local-execution/src/run.rs | 8 +- src/daft-local-execution/src/runtime_stats.rs | 15 ++- .../src/sinks/blocking_sink.rs | 33 +++--- .../src/sinks/streaming_sink.rs | 52 +++++---- .../src/sources/scan_task.rs | 7 +- .../src/sources/source.rs | 12 +- tests/benchmarks/test_local_tpch.py | 4 +- 12 files changed, 153 insertions(+), 164 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c146000729..3186bf148c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1990,6 +1990,7 @@ dependencies = [ "indexmap 2.5.0", "lazy_static", "log", + "loole", "num-format", "pyo3", "snafu", @@ -3405,6 +3406,16 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "loole" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2998397c725c822c6b2ba605fd9eb4c6a7a0810f1629ba3cc232ef4f0308d96" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "lz4" version = "1.26.0" diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 8da3b93325..3516de9079 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -21,6 +21,7 @@ futures = {workspace = true} indexmap = {workspace = true} lazy_static = {workspace = true} log = {workspace = true} +loole = "0.4.0" num-format = "0.4.4" pyo3 = {workspace = true, optional = true} snafu = {workspace = true} diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index f16e3bd061..c518de28e6 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -1,95 +1,42 @@ -use std::sync::Arc; - -use crate::{ - pipeline::PipelineResultType, - runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, -}; - -pub type Sender = tokio::sync::mpsc::Sender; -pub type Receiver = tokio::sync::mpsc::Receiver; +pub type Sender = loole::Sender; +pub type Receiver = loole::Receiver; pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { - tokio::sync::mpsc::channel(buffer_size) -} - -pub struct PipelineChannel { - sender: PipelineSender, - receiver: PipelineReceiver, + loole::bounded(buffer_size) } -impl PipelineChannel { - pub fn new(buffer_size: usize, in_order: bool) -> Self { - if in_order { - let (senders, receivers) = (0..buffer_size).map(|_| 
create_channel(1)).unzip(); - let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); - let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); - Self { sender, receiver } - } else { - let (sender, receiver) = create_channel(buffer_size); - let sender = PipelineSender::OutOfOrder(sender); - let receiver = PipelineReceiver::OutOfOrder(receiver); - Self { sender, receiver } - } - } - - fn get_next_sender(&mut self) -> Sender { - match &mut self.sender { - PipelineSender::InOrder(rr) => rr.get_next_sender(), - PipelineSender::OutOfOrder(sender) => sender.clone(), +pub fn create_ordering_aware_channel( + ordered: bool, + buffer_size: usize, +) -> (Vec>, OrderingAwareReceiver) { + match ordered { + true => { + let (senders, receiver) = (0..buffer_size).map(|_| create_channel::(1)).unzip(); + ( + senders, + OrderingAwareReceiver::InOrder(RoundRobinReceiver::new(receiver)), + ) } - } - - pub(crate) fn get_next_sender_with_stats( - &mut self, - rt: &Arc, - ) -> CountingSender { - CountingSender::new(self.get_next_sender(), rt.clone()) - } - - pub fn get_receiver(self) -> PipelineReceiver { - self.receiver - } - - pub(crate) fn get_receiver_with_stats(self, rt: &Arc) -> CountingReceiver { - CountingReceiver::new(self.get_receiver(), rt.clone()) - } -} - -pub enum PipelineSender { - InOrder(RoundRobinSender), - OutOfOrder(Sender), -} - -pub struct RoundRobinSender { - senders: Vec>, - curr_sender_idx: usize, -} - -impl RoundRobinSender { - pub fn new(senders: Vec>) -> Self { - Self { - senders, - curr_sender_idx: 0, + false => { + let (sender, receiver) = create_channel::(buffer_size); + ( + (0..buffer_size).map(|_| sender.clone()).collect(), + OrderingAwareReceiver::OutOfOrder(receiver), + ) } } - - pub fn get_next_sender(&mut self) -> Sender { - let next_idx = self.curr_sender_idx; - self.curr_sender_idx = (next_idx + 1) % self.senders.len(); - self.senders[next_idx].clone() - } } -pub enum PipelineReceiver { - InOrder(RoundRobinReceiver), - OutOfOrder(Receiver), +pub enum OrderingAwareReceiver { + InOrder(RoundRobinReceiver), + OutOfOrder(Receiver), } -impl PipelineReceiver { - pub async fn recv(&mut self) -> Option { +impl OrderingAwareReceiver { + pub async fn recv(&mut self) -> Option { match self { Self::InOrder(rr) => rr.recv().await, - Self::OutOfOrder(r) => r.recv().await, + Self::OutOfOrder(r) => r.recv_async().await.ok(), } } } @@ -115,7 +62,7 @@ impl RoundRobinReceiver { } for i in 0..self.receivers.len() { let next_idx = (i + self.curr_receiver_idx) % self.receivers.len(); - if let Some(val) = self.receivers[next_idx].recv().await { + if let Some(val) = self.receivers[next_idx].recv_async().await.ok() { self.curr_receiver_idx = (next_idx + 1) % self.receivers.len(); return Some(val); } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index b53cfdbd34..1b672add3f 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -9,7 +9,7 @@ use tracing::{info_span, instrument}; use crate::{ buffer::RowBasedBuffer, - channel::{create_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, create_ordering_aware_channel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, NUM_CPUS, @@ -106,14 +106,14 @@ impl IntermediateNode { #[instrument(level = 
"info", skip_all, name = "IntermediateOperator::run_worker")] pub async fn run_worker( op: Arc, - mut receiver: Receiver<(usize, PipelineResultType)>, - sender: CountingSender, + receiver: Receiver<(usize, PipelineResultType)>, + sender: Sender, rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); let compute_runtime = get_compute_runtime(); let state_wrapper = IntermediateOperatorState::new(op.make_state()); - while let Some((idx, morsel)) = receiver.recv().await { + while let Some((idx, morsel)) = receiver.recv_async().await.ok() { loop { let op = op.clone(); let morsel = morsel.clone(); @@ -126,14 +126,14 @@ impl IntermediateNode { let result = compute_runtime.await_on(fut).await??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { - let _ = sender.send(mp.into()).await; + let _ = sender.send_async(mp.into()).await; break; } IntermediateOperatorResult::NeedMoreInput(None) => { break; } IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = sender.send(mp.into()).await; + let _ = sender.send_async(mp.into()).await; } } } @@ -144,24 +144,26 @@ impl IntermediateNode { pub fn spawn_workers( &self, num_workers: usize, - destination_channel: &mut PipelineChannel, + output_senders: Vec>, runtime_handle: &mut ExecutionRuntimeHandle, + maintain_order: bool, ) -> Vec> { - let mut worker_senders = Vec::with_capacity(num_workers); - for _ in 0..num_workers { - let (worker_sender, worker_receiver) = create_channel(1); - let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let (worker_senders, worker_receivers): (Vec<_>, Vec<_>) = if maintain_order { + (0..num_workers).map(|_| create_channel(1)).unzip() + } else { + let (sender, receiver) = create_channel(num_workers); + (vec![sender; 1], vec![receiver; num_workers]) + }; + for (receiver, destination_channel) in worker_receivers.into_iter().zip(output_senders) { runtime_handle.spawn( Self::run_worker( self.intermediate_op.clone(), - worker_receiver, - destination_sender, + receiver, + destination_channel, self.runtime_stats.clone(), ), self.intermediate_op.name(), ); - worker_senders.push(worker_sender); } worker_senders } @@ -171,12 +173,11 @@ impl IntermediateNode { worker_senders: Vec>, morsel_size: Option, ) -> DaftResult<()> { - println!("morsel_size: {:?}", morsel_size); let mut next_worker_idx = 0; let mut send_to_next_worker = |idx, data: PipelineResultType| { let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send((idx, data)) + next_worker_sender.send_async((idx, data)) }; for (idx, mut receiver) in receivers.into_iter().enumerate() { @@ -184,7 +185,7 @@ impl IntermediateNode { while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { - let _ = worker_sender.send((idx, morsel.clone())).await; + let _ = worker_sender.send_async((idx, morsel.clone())).await; } } else if let Some(buffer) = buffer.as_mut() { buffer.push(morsel.as_data().clone()); @@ -243,23 +244,37 @@ impl PipelineNode for IntermediateNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in &mut self.children { let child_result_channel = child.start(maintain_order, runtime_handle)?; - child_result_receivers - 
.push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); + child_result_receivers.push(CountingReceiver::new( + child_result_channel, + self.runtime_stats.clone(), + )); } - let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); + let (destination_sender, destination_receiver) = create_channel(1); + let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone()); + let (output_senders, mut output_receiver) = + create_ordering_aware_channel(maintain_order, *NUM_CPUS); let worker_senders = - self.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle); + self.spawn_workers(*NUM_CPUS, output_senders, runtime_handle, maintain_order); let morsel_size = runtime_handle.determine_morsel_size(self.intermediate_op.morsel_size()); runtime_handle.spawn( Self::send_to_workers(child_result_receivers, worker_senders, morsel_size), self.intermediate_op.name(), ); - Ok(destination_channel) + runtime_handle.spawn( + async move { + while let Some(val) = output_receiver.recv().await { + let _ = counting_sender.send(val).await; + } + Ok(()) + }, + self.intermediate_op.name(), + ); + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index fc95182a65..7f92f2d2ed 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -19,7 +19,7 @@ use indexmap::IndexSet; use snafu::ResultExt; use crate::{ - channel::PipelineChannel, + channel::Receiver, intermediate_ops::{ anti_semi_hash_join_probe::AntiSemiProbeOperator, explode::ExplodeOperator, filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, @@ -81,7 +81,7 @@ pub trait PipelineNode: Sync + Send + TreeDisplay { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result; + ) -> crate::Result>; fn as_tree_display(&self) -> &dyn TreeDisplay; } diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index ae0939ef8a..62d10b381b 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -127,10 +127,10 @@ pub fn run_local( .expect("Failed to create tokio runtime"); let execution_task = async { let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); - let mut receiver = pipeline.start(true, &mut runtime_handle)?.get_receiver(); + let receiver = pipeline.start(true, &mut runtime_handle)?; - while let Some(val) = receiver.recv().await { - let _ = tx.send(val.as_data().clone()).await; + while let Some(val) = receiver.recv_async().await.ok() { + let _ = tx.send_async(val.as_data().clone()).await; } while let Some(result) = runtime_handle.join_next().await { @@ -180,7 +180,7 @@ pub fn run_local( type Item = DaftResult>; fn next(&mut self) -> Option { - match self.receiver.blocking_recv() { + match self.receiver.recv().ok() { Some(part) => Some(Ok(part)), None => { if self.handle.is_some() { diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index 566d253e9c..23af82d59a 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -5,10 +5,10 @@ use std::{ time::Instant, }; -use tokio::sync::mpsc::error::SendError; +use loole::SendError; use crate::{ - channel::{PipelineReceiver, Sender}, + channel::{Receiver, Sender}, pipeline::PipelineResultType, }; @@ -128,24 +128,27 @@ impl 
CountingSender { state.get_tables().iter().map(|t| t.len()).sum() } }; - self.sender.send(v).await?; + self.sender.send_async(v).await?; self.rt.mark_rows_emitted(len as u64); Ok(()) } } pub struct CountingReceiver { - receiver: PipelineReceiver, + receiver: Receiver, rt: Arc, } impl CountingReceiver { - pub(crate) fn new(receiver: PipelineReceiver, rt: Arc) -> Self { + pub(crate) fn new( + receiver: Receiver, + rt: Arc, + ) -> Self { Self { receiver, rt } } #[inline] pub(crate) async fn recv(&mut self) -> Option { - let v = self.receiver.recv().await; + let v = self.receiver.recv_async().await.ok(); if let Some(ref v) = v { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 80438c8173..e3901770b1 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -10,9 +10,9 @@ use tracing::{info_span, instrument}; use crate::{ buffer::RowBasedBuffer, - channel::{create_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, RuntimeStatsContext}, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, JoinSnafu, TaskSet, }; pub trait DynBlockingSinkState: Send + Sync { @@ -91,13 +91,13 @@ impl BlockingSinkNode { #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] async fn run_worker( op: Arc, - mut input_receiver: Receiver, + input_receiver: Receiver, rt_context: Arc, ) -> DaftResult> { let span = info_span!("BlockingSink::Sink"); let compute_runtime = get_compute_runtime(); let state_wrapper = BlockingSinkState::new(op.make_state()?); - while let Some(morsel) = input_receiver.recv().await { + while let Some(morsel) = input_receiver.recv_async().await.ok() { let op = op.clone(); let morsel = morsel.clone(); let span = span.clone(); @@ -145,14 +145,14 @@ impl BlockingSinkNode { let mut send_to_next_worker = |data: PipelineResultType| { let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send(data) + next_worker_sender.send_async(data) }; let mut buffer = morsel_size.map(RowBasedBuffer::new); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { - let _ = worker_sender.send(morsel.clone()).await; + let _ = worker_sender.send_async(morsel.clone()).await; } } else if let Some(buffer) = buffer.as_mut() { buffer.push(morsel.as_data().clone()); @@ -205,21 +205,24 @@ impl PipelineNode for BlockingSinkNode { fn start( &mut self, - maintain_order: bool, + _maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result> { let child = self.child.as_mut(); - let child_results_receiver = child - .start(false, runtime_handle)? 
- .get_receiver_with_stats(&self.runtime_stats); + let child_results_receiver = child.start(false, runtime_handle)?; + let child_results_receiver = + CountingReceiver::new(child_results_receiver, self.runtime_stats.clone()); - let mut destination_channel = PipelineChannel::new(1, maintain_order); + let (destination_sender, destination_receiver) = create_channel(1); let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + CountingSender::new(destination_sender, self.runtime_stats.clone()); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + let (input_senders, input_receivers) = { + let (tx, rx) = create_channel(num_workers); + (0..num_workers).map(|_| (tx.clone(), rx.clone())).unzip() + }; let morsel_size = runtime_handle.determine_morsel_size(op.morsel_size()); runtime_handle.spawn( Self::forward_input_to_workers(child_results_receiver, input_senders, morsel_size), @@ -256,7 +259,7 @@ impl PipelineNode for BlockingSinkNode { }, self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 112c6d5905..ccee886b25 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -9,9 +9,9 @@ use tracing::{info_span, instrument}; use crate::{ buffer::RowBasedBuffer, - channel::{create_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, create_ordering_aware_channel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, RuntimeStatsContext}, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, }; @@ -109,7 +109,7 @@ impl StreamingSinkNode { #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")] async fn run_worker( op: Arc, - mut input_receiver: Receiver<(usize, PipelineResultType)>, + input_receiver: Receiver<(usize, PipelineResultType)>, output_sender: Sender>, rt_context: Arc, ) -> DaftResult> { @@ -117,7 +117,7 @@ impl StreamingSinkNode { let compute_runtime = get_compute_runtime(); let state_wrapper = StreamingSinkState::new(op.make_state()); let mut finished = false; - while let Some((idx, morsel)) = input_receiver.recv().await { + while let Some((idx, morsel)) = input_receiver.recv_async().await.ok() { if finished { break; } @@ -134,16 +134,16 @@ impl StreamingSinkNode { match result { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { - let _ = output_sender.send(mp).await; + let _ = output_sender.send_async(mp).await; } break; } StreamingSinkOutput::HasMoreOutput(mp) => { - let _ = output_sender.send(mp).await; + let _ = output_sender.send_async(mp).await; } StreamingSinkOutput::Finished(mp) => { if let Some(mp) = mp { - let _ = output_sender.send(mp).await; + let _ = output_sender.send_async(mp).await; } finished = true; break; @@ -164,19 +164,18 @@ impl StreamingSinkNode { fn spawn_workers( op: Arc, input_receivers: Vec>, + output_senders: Vec>>, task_set: &mut TaskSet>>, stats: Arc, - ) -> Receiver> { - let (output_sender, output_receiver) = create_channel(input_receivers.len()); - for input_receiver in input_receivers { + ) { + for (input_receiver, output_sender) 
in input_receivers.into_iter().zip(output_senders) { task_set.spawn(Self::run_worker( op.clone(), input_receiver, - output_sender.clone(), + output_sender, stats.clone(), )); } - output_receiver } // Forwards input from the children to the workers in a round-robin fashion. @@ -190,7 +189,7 @@ impl StreamingSinkNode { let mut send_to_next_worker = |idx, data: PipelineResultType| { let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send((idx, data)) + next_worker_sender.send_async((idx, data)) }; for (idx, mut receiver) in receivers.into_iter().enumerate() { @@ -198,7 +197,7 @@ impl StreamingSinkNode { while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { - let _ = worker_sender.send((idx, morsel.clone())).await; + let _ = worker_sender.send_async((idx, morsel.clone())).await; } } else if let Some(buffer) = buffer.as_mut() { buffer.push(morsel.as_data().clone()); @@ -259,22 +258,30 @@ impl PipelineNode for StreamingSinkNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in &mut self.children { let child_result_channel = child.start(maintain_order, runtime_handle)?; - child_result_receivers - .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); + let counting_receiver = + CountingReceiver::new(child_result_channel, self.runtime_stats.clone()); + child_result_receivers.push(counting_receiver); } - let mut destination_channel = PipelineChannel::new(1, maintain_order); + let (destination_sender, destination_receiver) = create_channel(1); let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + CountingSender::new(destination_sender, self.runtime_stats.clone()); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + let (input_senders, input_receivers) = if maintain_order { + (0..num_workers).map(|_| create_channel(1)).unzip() + } else { + let (tx, rx) = create_channel(num_workers); + (vec![tx; 1], vec![rx; num_workers]) + }; + let (output_senders, mut output_receiver) = + create_ordering_aware_channel(maintain_order, num_workers); let morsel_size = runtime_handle.determine_morsel_size(op.morsel_size()); runtime_handle.spawn( Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size), @@ -283,9 +290,10 @@ impl PipelineNode for StreamingSinkNode { runtime_handle.spawn( async move { let mut task_set = TaskSet::new(); - let mut output_receiver = Self::spawn_workers( + Self::spawn_workers( op.clone(), input_receivers, + output_senders, &mut task_set, runtime_stats.clone(), ); @@ -315,7 +323,7 @@ impl PipelineNode for StreamingSinkNode { }, self.name(), ); - Ok(destination_channel) + Ok(destination_receiver) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 3be2f61691..ffe1678f9e 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -14,7 +14,6 @@ use daft_parquet::read::{read_parquet_bulk_async, ParquetSchemaInferenceOptions} use 
daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask}; use futures::{Stream, StreamExt}; use snafu::ResultExt; -use tokio_stream::wrappers::ReceiverStream; use tracing::instrument; use crate::{ @@ -49,12 +48,12 @@ impl ScanTaskSource { stream_scan_task(scan_task, Some(io_stats), delete_map, maintain_order).await?; let mut has_data = false; while let Some(partition) = stream.next().await { - let _ = sender.send(partition?).await; + let _ = sender.send_async(partition?).await; has_data = true; } if !has_data { let empty = Arc::new(MicroPartition::empty(Some(schema.clone()))); - let _ = sender.send(empty).await; + let _ = sender.send_async(empty).await; } Ok(()) } @@ -105,7 +104,7 @@ impl Source for ScanTaskSource { self.name(), ); - let stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new)); + let stream = futures::stream::iter(receivers.into_iter().map(|r| r.into_stream())); Ok(Box::pin(stream.flatten())) } diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 8c55401db2..8e34edaddd 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -6,7 +6,9 @@ use daft_micropartition::MicroPartition; use futures::{stream::BoxStream, StreamExt}; use crate::{ - channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, + channel::{create_channel, Receiver}, + pipeline::{PipelineNode, PipelineResultType}, + runtime_stats::{CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, }; @@ -69,13 +71,13 @@ impl PipelineNode for SourceNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result> { let mut source_stream = self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; - let mut channel = PipelineChannel::new(1, maintain_order); - let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); + let (tx, rx) = create_channel(1); + let counting_sender = CountingSender::new(tx, self.runtime_stats.clone()); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { @@ -85,7 +87,7 @@ impl PipelineNode for SourceNode { }, self.name(), ); - Ok(channel) + Ok(rx) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/tests/benchmarks/test_local_tpch.py b/tests/benchmarks/test_local_tpch.py index 07165d9ebc..f5fb87f35b 100644 --- a/tests/benchmarks/test_local_tpch.py +++ b/tests/benchmarks/test_local_tpch.py @@ -20,7 +20,7 @@ IS_CI = True if os.getenv("CI") else False -SCALE_FACTOR = 0.2 +SCALE_FACTOR = 1.0 ENGINES = ["native"] if IS_CI else ["native", "python"] NUM_PARTS = [1] if IS_CI else [1, 2] SOURCE_TYPES = ["in-memory"] if IS_CI else ["parquet", "in-memory"] @@ -93,7 +93,7 @@ def _get_df(tbl_name: str): return _get_df, num_parts -TPCH_QUESTIONS = list(range(1, 11)) +TPCH_QUESTIONS = [1] @pytest.mark.skipif( From 9ebdccf8ec6053c2937915499fa97ef7137c7948 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 25 Oct 2024 21:41:37 -0700 Subject: [PATCH 04/10] loole channel, dispatcher, probe bridge --- Cargo.lock | 1 + src/daft-local-execution/Cargo.toml | 1 + src/daft-local-execution/src/channel.rs | 78 +++++++++++-- src/daft-local-execution/src/dispatcher.rs | 63 +++++++++++ .../anti_semi_hash_join_probe.rs | 60 +++++----- .../src/intermediate_ops/explode.rs | 10 +- .../src/intermediate_ops/filter.rs | 6 +- .../intermediate_ops/inner_hash_join_probe.rs | 54 ++++----- 
.../src/intermediate_ops/intermediate_op.rs | 96 +++++----------- .../src/intermediate_ops/pivot.rs | 6 +- .../src/intermediate_ops/project.rs | 7 +- .../src/intermediate_ops/sample.rs | 9 +- .../src/intermediate_ops/unpivot.rs | 6 +- src/daft-local-execution/src/lib.rs | 48 +++++++- src/daft-local-execution/src/pipeline.rs | 60 +++------- src/daft-local-execution/src/run.rs | 6 +- src/daft-local-execution/src/runtime_stats.rs | 48 +++----- .../src/sinks/aggregate.rs | 6 +- .../src/sinks/blocking_sink.rs | 72 +++--------- src/daft-local-execution/src/sinks/concat.rs | 42 +++---- .../src/sinks/hash_join_build.rs | 10 +- src/daft-local-execution/src/sinks/limit.rs | 8 +- .../src/sinks/outer_hash_join_probe.rs | 107 +++++++----------- src/daft-local-execution/src/sinks/sort.rs | 6 +- .../src/sinks/streaming_sink.rs | 83 ++++---------- .../src/sources/source.rs | 10 +- 26 files changed, 425 insertions(+), 478 deletions(-) create mode 100644 src/daft-local-execution/src/dispatcher.rs diff --git a/Cargo.lock b/Cargo.lock index 3186bf148c..e90546b274 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1968,6 +1968,7 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ + "async-trait", "common-daft-config", "common-display", "common-error", diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 3516de9079..d8bbe28174 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,4 +1,5 @@ [dependencies] +async-trait = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index c518de28e6..4d37a4985a 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -1,11 +1,13 @@ -pub type Sender = loole::Sender; -pub type Receiver = loole::Receiver; +use loole::SendError; -pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { +pub(crate) type Sender = loole::Sender; +pub(crate) type Receiver = loole::Receiver; + +pub(crate) fn create_channel(buffer_size: usize) -> (Sender, Receiver) { loole::bounded(buffer_size) } -pub fn create_ordering_aware_channel( +pub(crate) fn create_ordering_aware_receiver_channel( ordered: bool, buffer_size: usize, ) -> (Vec>, OrderingAwareReceiver) { @@ -27,13 +29,69 @@ pub fn create_ordering_aware_channel( } } -pub enum OrderingAwareReceiver { +pub(crate) fn create_ordering_aware_sender_channel( + ordered: bool, + buffer_size: usize, +) -> (OrderingAwareSender, Vec>) { + match ordered { + true => { + let (sender, receivers) = (0..buffer_size).map(|_| create_channel::(1)).unzip(); + ( + OrderingAwareSender::InOrder(RoundRobinSender::new(sender)), + receivers, + ) + } + false => { + let (sender, receiver) = create_channel::(buffer_size); + ( + OrderingAwareSender::OutOfOrder(sender), + (0..buffer_size).map(|_| receiver.clone()).collect(), + ) + } + } +} + +pub(crate) enum OrderingAwareSender { + InOrder(RoundRobinSender), + OutOfOrder(Sender), +} + +impl OrderingAwareSender { + pub(crate) async fn send(&mut self, val: T) -> Result<(), SendError> { + match self { + Self::InOrder(rr) => rr.send(val).await, + Self::OutOfOrder(s) => s.send_async(val).await, + } + } +} + +pub(crate) struct RoundRobinSender { + senders: Vec>, + next_sender_idx: usize, +} + 
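For readers skimming the channel changes, here is a minimal usage sketch of how the two ordering-aware halves introduced above are meant to pair up. It is illustrative only and not part of the patch: it assumes a tokio runtime, assumes the helper functions are generic over the item type, and uses a doubling closure as placeholder "work".

async fn ordering_aware_round_trip(num_workers: usize) -> Vec<u64> {
    // One dispatch handle fanning out to `num_workers` input channels,
    // and one output handle draining `num_workers` result channels.
    let (mut dispatch_tx, worker_rxs) = create_ordering_aware_sender_channel(true, num_workers);
    let (worker_txs, mut output_rx) = create_ordering_aware_receiver_channel(true, num_workers);

    // One task per worker: pull an item, "process" it, forward the result.
    for (rx, tx) in worker_rxs.into_iter().zip(worker_txs) {
        tokio::spawn(async move {
            while let Ok(v) = rx.recv_async().await {
                let _ = tx.send_async(v * 2).await;
            }
        });
    }

    // Producer: round-robins 0..8 across the workers, then closes their inputs.
    tokio::spawn(async move {
        for v in 0..8u64 {
            let _ = dispatch_tx.send(v).await;
        }
    });

    // Both halves walk the workers in the same order, so results come back
    // in submission order (0, 2, 4, ...) even though the workers run in parallel.
    let mut results = Vec::new();
    while let Some(v) = output_rx.recv().await {
        results.push(v);
    }
    results
}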
+impl RoundRobinSender { + fn new(senders: Vec>) -> Self { + Self { + senders, + next_sender_idx: 0, + } + } + + async fn send(&mut self, val: T) -> Result<(), SendError> { + let next_sender_idx = self.next_sender_idx; + self.next_sender_idx = (next_sender_idx + 1) % self.senders.len(); + self.senders[next_sender_idx].send_async(val).await + } +} + +pub(crate) enum OrderingAwareReceiver { InOrder(RoundRobinReceiver), OutOfOrder(Receiver), } impl OrderingAwareReceiver { - pub async fn recv(&mut self) -> Option { + pub(crate) async fn recv(&mut self) -> Option { match self { Self::InOrder(rr) => rr.recv().await, Self::OutOfOrder(r) => r.recv_async().await.ok(), @@ -41,14 +99,14 @@ impl OrderingAwareReceiver { } } -pub struct RoundRobinReceiver { +pub(crate) struct RoundRobinReceiver { receivers: Vec>, curr_receiver_idx: usize, is_done: bool, } impl RoundRobinReceiver { - pub fn new(receivers: Vec>) -> Self { + fn new(receivers: Vec>) -> Self { Self { receivers, curr_receiver_idx: 0, @@ -56,13 +114,13 @@ impl RoundRobinReceiver { } } - pub async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { if self.is_done { return None; } for i in 0..self.receivers.len() { let next_idx = (i + self.curr_receiver_idx) % self.receivers.len(); - if let Some(val) = self.receivers[next_idx].recv_async().await.ok() { + if let Ok(val) = self.receivers[next_idx].recv_async().await { self.curr_receiver_idx = (next_idx + 1) % self.receivers.len(); return Some(val); } diff --git a/src/daft-local-execution/src/dispatcher.rs b/src/daft-local-execution/src/dispatcher.rs new file mode 100644 index 0000000000..38e4c4aa89 --- /dev/null +++ b/src/daft-local-execution/src/dispatcher.rs @@ -0,0 +1,63 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; + +use crate::{ + buffer::RowBasedBuffer, channel::OrderingAwareSender, runtime_stats::CountingReceiver, +}; + +pub(crate) async fn dispatch( + receivers: Vec, + worker_sender: OrderingAwareSender<(usize, Arc)>, + morsel_size: Option, +) -> DaftResult<()> { + let mut dispatcher = Dispatcher::new(worker_sender, morsel_size); + for (idx, receiver) in receivers.into_iter().enumerate() { + while let Some(morsel) = receiver.recv().await { + dispatcher.handle_morsel(idx, morsel).await; + } + dispatcher.finalize(idx).await; + } + Ok(()) +} + +pub(crate) struct Dispatcher { + worker_sender: OrderingAwareSender<(usize, Arc)>, + buffer: Option, +} + +impl Dispatcher { + fn new( + worker_sender: OrderingAwareSender<(usize, Arc)>, + morsel_size: Option, + ) -> Self { + let buffer = morsel_size.map(RowBasedBuffer::new); + Self { + worker_sender, + buffer, + } + } + + async fn handle_morsel(&mut self, idx: usize, morsel: Arc) { + if let Some(buffer) = self.buffer.as_mut() { + buffer.push(morsel); + if let Some(ready) = buffer.pop_enough().unwrap() { + for r in ready { + let _ = self.worker_sender.send((idx, r)).await; + } + } + } else { + let _ = self.worker_sender.send((idx, morsel)).await; + } + } + + async fn finalize(&mut self, idx: usize) { + if let Some(buffer) = self.buffer.as_mut() { + // Clear all remaining morsels + if let Some(last_morsel) = buffer.pop_all().unwrap() { + let _ = self.worker_sender.send((idx, last_morsel)).await; + } + } + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 45c71d75df..f80dafe02c 100644 --- 
a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -12,28 +12,19 @@ use super::intermediate_op::{ DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; +use crate::ProbeStateBridgeRef; -enum AntiSemiProbeState { - Building, - ReadyToProbe(Arc), +struct AntiSemiProbeState { + probeable: Arc, } impl AntiSemiProbeState { - fn set_table(&mut self, table: &Arc) { - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe(table.clone()); - } else { - panic!("AntiSemiProbeState should only be in Building state when setting table") - } + fn new(probeable: Arc) -> Self { + Self { probeable } } fn get_probeable(&self) -> &Arc { - if let Self::ReadyToProbe(probeable) = self { - probeable - } else { - panic!("AntiSemiProbeState should only be in ReadyToProbe state when getting probeable") - } + &self.probeable } } @@ -43,20 +34,27 @@ impl DynIntermediateOpState for AntiSemiProbeState { } } -pub struct AntiSemiProbeOperator { +pub(crate) struct AntiSemiProbeOperator { probe_on: Vec, is_semi: bool, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl AntiSemiProbeOperator { const DEFAULT_GROWABLE_SIZE: usize = 20; - pub fn new(probe_on: Vec, join_type: &JoinType, output_schema: &SchemaRef) -> Self { + pub fn new( + probe_on: Vec, + join_type: &JoinType, + output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, + ) -> Self { Self { probe_on, is_semi: *join_type == JoinType::Semi, output_schema: output_schema.clone(), + probe_state_bridge, } } @@ -103,28 +101,22 @@ impl AntiSemiProbeOperator { } } +#[async_trait::async_trait] impl IntermediateOperator for AntiSemiProbeOperator { #[instrument(skip_all, name = "AntiSemiOperator::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, + _idx: usize, + input: &Arc, state: &IntermediateOperatorState, ) -> DaftResult { state.with_state_mut::(|state| { - if idx == 0 { - let probe_state = input.as_probe_state(); - state.set_table(probe_state.get_probeable()); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } else { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); - } - let out = self.probe_anti_semi(input, state)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); } + let out = self.probe_anti_semi(input, state)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) }) } @@ -132,7 +124,9 @@ impl IntermediateOperator for AntiSemiProbeOperator { "AntiSemiProbeOperator" } - fn make_state(&self) -> Box { - Box::new(AntiSemiProbeState::Building) + async fn make_state(&self) -> Box { + let probe_state = self.probe_state_bridge.get_probe_state().await; + let probe_table = probe_state.get_probeable(); + Box::new(AntiSemiProbeState::new(probe_table.clone())) } } diff --git a/src/daft-local-execution/src/intermediate_ops/explode.rs b/src/daft-local-execution/src/intermediate_ops/explode.rs index 30ec8b02b5..dadd3a6097 100644 --- a/src/daft-local-execution/src/intermediate_ops/explode.rs +++ b/src/daft-local-execution/src/intermediate_ops/explode.rs @@ -3,19 +3,19 @@ 
use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; use daft_functions::list::explode; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; -pub struct ExplodeOperator { +pub(crate) struct ExplodeOperator { to_explode: Vec, } impl ExplodeOperator { - pub fn new(to_explode: Vec) -> Self { + pub(crate) fn new(to_explode: Vec) -> Self { Self { to_explode: to_explode.into_iter().map(explode).collect(), } @@ -27,10 +27,10 @@ impl IntermediateOperator for ExplodeOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: &Arc, _state: &IntermediateOperatorState, ) -> DaftResult { - let out = input.as_data().explode(&self.to_explode)?; + let out = input.explode(&self.to_explode)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index aad3bd7e7d..cca2b236a3 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; pub struct FilterOperator { predicate: ExprRef, @@ -24,10 +24,10 @@ impl IntermediateOperator for FilterOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: &Arc, _state: &IntermediateOperatorState, ) -> DaftResult { - let out = input.as_data().filter(&[self.predicate.clone()])?; + let out = input.filter(&[self.predicate.clone()])?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index c257b1e618..1a70b5056e 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -12,28 +12,19 @@ use super::intermediate_op::{ DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; +use crate::ProbeStateBridgeRef; -enum InnerHashJoinProbeState { - Building, - ReadyToProbe(Arc), +struct InnerHashJoinProbeState { + probe_state: Arc, } impl InnerHashJoinProbeState { - fn set_probe_state(&mut self, probe_state: Arc) { - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe(probe_state); - } else { - panic!("InnerHashJoinProbeState should only be in Building state when setting table") - } + fn new(probe_state: Arc) -> Self { + Self { probe_state } } fn get_probe_state(&self) -> &Arc { - if let Self::ReadyToProbe(probe_state) = self { - probe_state - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") - } + &self.probe_state } } @@ -50,6 +41,7 @@ pub struct InnerHashJoinProbeOperator { right_non_join_columns: Vec, build_on_left: bool, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl InnerHashJoinProbeOperator { @@ -62,6 +54,7 @@ impl InnerHashJoinProbeOperator { build_on_left: bool, common_join_keys: IndexSet, 
output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, ) -> Self { let left_non_join_columns = left_schema .fields @@ -83,6 +76,7 @@ impl InnerHashJoinProbeOperator { right_non_join_columns, build_on_left, output_schema: output_schema.clone(), + probe_state_bridge, } } @@ -159,29 +153,22 @@ impl InnerHashJoinProbeOperator { } } +#[async_trait::async_trait] impl IntermediateOperator for InnerHashJoinProbeOperator { #[instrument(skip_all, name = "InnerHashJoinOperator::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, + _idx: usize, + input: &Arc, state: &IntermediateOperatorState, ) -> DaftResult { - state.with_state_mut::(|state| match idx { - 0 => { - let probe_state = input.as_probe_state(); - state.set_probe_state(probe_state.clone()); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } - _ => { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); - } - let out = self.probe_inner(input, state)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) + state.with_state_mut::(|state| { + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); } + let out = self.probe_inner(input, state)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) }) } @@ -189,7 +176,8 @@ impl IntermediateOperator for InnerHashJoinProbeOperator { "InnerHashJoinProbeOperator" } - fn make_state(&self) -> Box { - Box::new(InnerHashJoinProbeState::Building) + async fn make_state(&self) -> Box { + let probe_state = self.probe_state_bridge.get_probe_state().await; + Box::new(InnerHashJoinProbeState::new(probe_state)) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 1b672add3f..f04f2a2cdf 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -1,5 +1,6 @@ use std::sync::{Arc, Mutex}; +use async_trait::async_trait; use common_daft_config::DaftExecutionConfig; use common_display::tree::TreeDisplay; use common_error::DaftResult; @@ -8,9 +9,12 @@ use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; use crate::{ - buffer::RowBasedBuffer, - channel::{create_channel, create_ordering_aware_channel, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultType}, + channel::{ + create_channel, create_ordering_aware_receiver_channel, + create_ordering_aware_sender_channel, OrderingAwareSender, Receiver, Sender, + }, + dispatcher::dispatch, + pipeline::PipelineNode, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, NUM_CPUS, }; @@ -56,15 +60,16 @@ pub enum IntermediateOperatorResult { HasMoreOutput(Arc), } +#[async_trait] pub trait IntermediateOperator: Send + Sync { fn execute( &self, idx: usize, - input: &PipelineResultType, + input: &Arc, state: &IntermediateOperatorState, ) -> DaftResult; fn name(&self) -> &'static str; - fn make_state(&self) -> Box { + async fn make_state(&self) -> Box { Box::new(DefaultIntermediateOperatorState {}) } fn morsel_size(&self) -> Option { @@ -106,14 +111,14 @@ impl IntermediateNode { #[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")] pub async fn run_worker( op: Arc, - 
receiver: Receiver<(usize, PipelineResultType)>, - sender: Sender, + receiver: Receiver<(usize, Arc)>, + sender: Sender>, rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); let compute_runtime = get_compute_runtime(); - let state_wrapper = IntermediateOperatorState::new(op.make_state()); - while let Some((idx, morsel)) = receiver.recv_async().await.ok() { + let state_wrapper = IntermediateOperatorState::new(op.make_state().await); + while let Ok((idx, morsel)) = receiver.recv_async().await { loop { let op = op.clone(); let morsel = morsel.clone(); @@ -126,14 +131,14 @@ impl IntermediateNode { let result = compute_runtime.await_on(fut).await??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { - let _ = sender.send_async(mp.into()).await; + let _ = sender.send_async(mp).await; break; } IntermediateOperatorResult::NeedMoreInput(None) => { break; } IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = sender.send_async(mp.into()).await; + let _ = sender.send_async(mp).await; } } } @@ -144,16 +149,12 @@ impl IntermediateNode { pub fn spawn_workers( &self, num_workers: usize, - output_senders: Vec>, + output_senders: Vec>>, runtime_handle: &mut ExecutionRuntimeHandle, maintain_order: bool, - ) -> Vec> { - let (worker_senders, worker_receivers): (Vec<_>, Vec<_>) = if maintain_order { - (0..num_workers).map(|_| create_channel(1)).unzip() - } else { - let (sender, receiver) = create_channel(num_workers); - (vec![sender; 1], vec![receiver; num_workers]) - }; + ) -> OrderingAwareSender<(usize, Arc)> { + let (worker_sender, worker_receivers) = + create_ordering_aware_sender_channel(maintain_order, num_workers); for (receiver, destination_channel) in worker_receivers.into_iter().zip(output_senders) { runtime_handle.spawn( Self::run_worker( @@ -165,47 +166,7 @@ impl IntermediateNode { self.intermediate_op.name(), ); } - worker_senders - } - - pub async fn send_to_workers( - receivers: Vec, - worker_senders: Vec>, - morsel_size: Option, - ) -> DaftResult<()> { - let mut next_worker_idx = 0; - let mut send_to_next_worker = |idx, data: PipelineResultType| { - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send_async((idx, data)) - }; - - for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = morsel_size.map(RowBasedBuffer::new); - while let Some(morsel) = receiver.recv().await { - if morsel.should_broadcast() { - for worker_sender in &worker_senders { - let _ = worker_sender.send_async((idx, morsel.clone())).await; - } - } else if let Some(buffer) = buffer.as_mut() { - buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.pop_enough()? { - for r in ready { - let _ = send_to_next_worker(idx, r.into()).await; - } - } - } else { - let _ = send_to_next_worker(idx, morsel).await; - } - } - if let Some(buffer) = buffer.as_mut() { - // Clear all remaining morsels - if let Some(last_morsel) = buffer.pop_all()? 
{ - let _ = send_to_next_worker(idx, last_morsel.into()).await; - } - } - } - Ok(()) + worker_sender } } @@ -244,12 +205,13 @@ impl PipelineNode for IntermediateNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result> { + ) -> crate::Result>> { + let num_workers = *NUM_CPUS; let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in &mut self.children { - let child_result_channel = child.start(maintain_order, runtime_handle)?; + let child_result_receiver = child.start(maintain_order, runtime_handle)?; child_result_receivers.push(CountingReceiver::new( - child_result_channel, + child_result_receiver, self.runtime_stats.clone(), )); } @@ -257,12 +219,12 @@ impl PipelineNode for IntermediateNode { let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone()); let (output_senders, mut output_receiver) = - create_ordering_aware_channel(maintain_order, *NUM_CPUS); - let worker_senders = - self.spawn_workers(*NUM_CPUS, output_senders, runtime_handle, maintain_order); + create_ordering_aware_receiver_channel(maintain_order, num_workers); + let worker_sender = + self.spawn_workers(num_workers, output_senders, runtime_handle, maintain_order); let morsel_size = runtime_handle.determine_morsel_size(self.intermediate_op.morsel_size()); runtime_handle.spawn( - Self::send_to_workers(child_result_receivers, worker_senders, morsel_size), + dispatch(child_result_receivers, worker_sender, morsel_size), self.intermediate_op.name(), ); runtime_handle.spawn( diff --git a/src/daft-local-execution/src/intermediate_ops/pivot.rs b/src/daft-local-execution/src/intermediate_ops/pivot.rs index c54ea3e69b..bd03c835e5 100644 --- a/src/daft-local-execution/src/intermediate_ops/pivot.rs +++ b/src/daft-local-execution/src/intermediate_ops/pivot.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; pub struct PivotOperator { group_by: Vec, @@ -37,10 +37,10 @@ impl IntermediateOperator for PivotOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: &Arc, _state: &IntermediateOperatorState, ) -> DaftResult { - let out = input.as_data().pivot( + let out = input.pivot( &self.group_by, self.pivot_col.clone(), self.values_col.clone(), diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index 370de989aa..e5f17ea2d3 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; pub struct ProjectOperator { projection: Vec, @@ -24,10 +24,11 @@ impl IntermediateOperator for ProjectOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: &Arc, _state: &IntermediateOperatorState, ) -> DaftResult { - let out = input.as_data().eval_expression_list(&self.projection)?; + println!("ProjectOperator::execute"); + let out = input.eval_expression_list(&self.projection)?; 
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/sample.rs b/src/daft-local-execution/src/intermediate_ops/sample.rs index b0e4610292..573069ea71 100644 --- a/src/daft-local-execution/src/intermediate_ops/sample.rs +++ b/src/daft-local-execution/src/intermediate_ops/sample.rs @@ -1,12 +1,12 @@ use std::sync::Arc; use common_error::DaftResult; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; pub struct SampleOperator { fraction: f64, @@ -29,13 +29,10 @@ impl IntermediateOperator for SampleOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: &Arc, _state: &IntermediateOperatorState, ) -> DaftResult { - let out = - input - .as_data() - .sample_by_fraction(self.fraction, self.with_replacement, self.seed)?; + let out = input.sample_by_fraction(self.fraction, self.with_replacement, self.seed)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/unpivot.rs b/src/daft-local-execution/src/intermediate_ops/unpivot.rs index 5171f9ad42..615ac523ee 100644 --- a/src/daft-local-execution/src/intermediate_ops/unpivot.rs +++ b/src/daft-local-execution/src/intermediate_ops/unpivot.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; use tracing::instrument; use super::intermediate_op::{ IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, }; -use crate::pipeline::PipelineResultType; pub struct UnpivotOperator { ids: Vec, @@ -37,10 +37,10 @@ impl IntermediateOperator for UnpivotOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: &Arc, _state: &IntermediateOperatorState, ) -> DaftResult { - let out = input.as_data().unpivot( + let out = input.unpivot( &self.ids, &self.values, &self.variable_name, diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 399363c1fd..183cd37d30 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,6 +1,7 @@ #![feature(let_chains)] mod buffer; mod channel; +mod dispatcher; mod intermediate_ops; mod pipeline; mod run; @@ -8,8 +9,11 @@ mod runtime_stats; mod sinks; mod sources; +use std::sync::{Arc, OnceLock}; + use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; +use daft_table::ProbeState; use lazy_static::lazy_static; pub use run::NativeExecutor; use snafu::{futures::TryFutureExt, Snafu}; @@ -18,6 +22,38 @@ lazy_static! 
{ pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); } +pub(crate) type ProbeStateBridgeRef = Arc; +pub(crate) struct ProbeStateBridge { + inner: OnceLock>, + notify: tokio::sync::Notify, +} + +impl ProbeStateBridge { + fn new() -> Arc { + Arc::new(Self { + inner: OnceLock::new(), + notify: tokio::sync::Notify::new(), + }) + } + + fn set_probe_state(&self, state: Arc) { + assert!( + !self.inner.set(state).is_err(), + "ProbeStateBridge should be set only once" + ); + self.notify.notify_waiters(); + } + + async fn get_probe_state(&self) -> Arc { + loop { + if let Some(state) = self.inner.get() { + return state.clone(); + } + self.notify.notified().await; + } + } +} + pub(crate) struct TaskSet { inner: tokio::task::JoinSet, } @@ -45,20 +81,20 @@ impl TaskSet { } } -pub struct ExecutionRuntimeHandle { +pub(crate) struct ExecutionRuntimeHandle { worker_set: TaskSet>, default_morsel_size: usize, } impl ExecutionRuntimeHandle { #[must_use] - pub fn new(default_morsel_size: usize) -> Self { + fn new(default_morsel_size: usize) -> Self { Self { worker_set: TaskSet::new(), default_morsel_size, } } - pub fn spawn( + fn spawn( &mut self, task: impl std::future::Future> + 'static, node_name: &str, @@ -68,15 +104,15 @@ impl ExecutionRuntimeHandle { .spawn(task.with_context(|_| PipelineExecutionSnafu { node_name })); } - pub async fn join_next(&mut self) -> Option, tokio::task::JoinError>> { + async fn join_next(&mut self) -> Option, tokio::task::JoinError>> { self.worker_set.join_next().await } - pub async fn shutdown(&mut self) { + async fn shutdown(&mut self) { self.worker_set.shutdown().await; } - pub fn determine_morsel_size(&self, operator_morsel_size: Option) -> Option { + fn determine_morsel_size(&self, operator_morsel_size: Option) -> Option { match operator_morsel_size { None => None, Some(_) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 7f92f2d2ed..dd6e257267 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -14,7 +14,6 @@ use daft_physical_plan::{ LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; use daft_plan::JoinType; -use daft_table::ProbeState; use indexmap::IndexSet; use snafu::ResultExt; @@ -33,60 +32,22 @@ use crate::{ streaming_sink::StreamingSinkNode, }, sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource}, - ExecutionRuntimeHandle, PipelineCreationSnafu, + ExecutionRuntimeHandle, PipelineCreationSnafu, ProbeStateBridge, }; -#[derive(Clone)] -pub enum PipelineResultType { - Data(Arc), - ProbeState(Arc), -} - -impl From> for PipelineResultType { - fn from(data: Arc) -> Self { - Self::Data(data) - } -} - -impl From> for PipelineResultType { - fn from(probe_state: Arc) -> Self { - Self::ProbeState(probe_state) - } -} - -impl PipelineResultType { - pub fn as_data(&self) -> &Arc { - match self { - Self::Data(data) => data, - _ => panic!("Expected data"), - } - } - - pub fn as_probe_state(&self) -> &Arc { - match self { - Self::ProbeState(probe_state) => probe_state, - _ => panic!("Expected probe table"), - } - } - - pub fn should_broadcast(&self) -> bool { - matches!(self, Self::ProbeState(_)) - } -} - -pub trait PipelineNode: Sync + Send + TreeDisplay { +pub(crate) trait PipelineNode: Sync + Send + TreeDisplay { fn children(&self) -> Vec<&dyn PipelineNode>; fn name(&self) -> &'static str; fn start( &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result>; + 
) -> crate::Result>>; fn as_tree_display(&self) -> &dyn TreeDisplay; } -pub fn viz_pipeline(root: &dyn PipelineNode) -> String { +pub(crate) fn viz_pipeline(root: &dyn PipelineNode) -> String { let mut output = String::new(); let mut visitor = MermaidDisplayVisitor::new( &mut output, @@ -98,7 +59,7 @@ pub fn viz_pipeline(root: &dyn PipelineNode) -> String { output } -pub fn physical_plan_to_pipeline( +pub(crate) fn physical_plan_to_pipeline( physical_plan: &LocalPhysicalPlan, psets: &HashMap>>, ) -> crate::Result> { @@ -296,7 +257,13 @@ pub fn physical_plan_to_pipeline( .collect::>(); // we should move to a builder pattern - let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; + let probe_state_bridge = ProbeStateBridge::new(); + let build_sink = HashJoinBuildSink::new( + key_schema, + casted_build_on, + join_type, + probe_state_bridge.clone(), + )?; let build_child_node = physical_plan_to_pipeline(build_child, psets)?; let build_node = BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); @@ -309,6 +276,7 @@ pub fn physical_plan_to_pipeline( casted_probe_on, join_type, schema, + probe_state_bridge, )), vec![build_node, probe_child_node], ) @@ -321,6 +289,7 @@ pub fn physical_plan_to_pipeline( build_on_left, common_join_keys, schema, + probe_state_bridge, )), vec![build_node, probe_child_node], ) @@ -334,6 +303,7 @@ pub fn physical_plan_to_pipeline( *join_type, common_join_keys, schema, + probe_state_bridge, )), vec![build_node, probe_child_node], ) diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 62d10b381b..9e4a564335 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -111,7 +111,7 @@ fn should_enable_explain_analyze() -> bool { } } -pub fn run_local( +fn run_local( physical_plan: &LocalPhysicalPlan, psets: HashMap>>, cfg: Arc, @@ -129,8 +129,8 @@ pub fn run_local( let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); let receiver = pipeline.start(true, &mut runtime_handle)?; - while let Some(val) = receiver.recv_async().await.ok() { - let _ = tx.send_async(val.as_data().clone()).await; + while let Ok(val) = receiver.recv_async().await { + let _ = tx.send_async(val).await; } while let Some(result) = runtime_handle.join_next().await { diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index 23af82d59a..a23bdfe030 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -5,25 +5,23 @@ use std::{ time::Instant, }; +use daft_micropartition::MicroPartition; use loole::SendError; -use crate::{ - channel::{Receiver, Sender}, - pipeline::PipelineResultType, -}; +use crate::channel::{Receiver, Sender}; #[derive(Default)] -pub struct RuntimeStatsContext { +pub(crate) struct RuntimeStatsContext { rows_received: AtomicU64, rows_emitted: AtomicU64, cpu_us: AtomicU64, } #[derive(Debug)] -pub struct RuntimeStats { - pub rows_received: u64, - pub rows_emitted: u64, - pub cpu_us: u64, +pub(crate) struct RuntimeStats { + rows_received: u64, + rows_emitted: u64, + cpu_us: u64, } impl RuntimeStats { @@ -108,54 +106,44 @@ impl RuntimeStatsContext { } } -pub struct CountingSender { - sender: Sender, +pub(crate) struct CountingSender { + sender: Sender>, rt: Arc, } impl CountingSender { - pub(crate) fn new(sender: Sender, rt: Arc) -> Self { + pub(crate) fn new(sender: Sender>, rt: Arc) -> Self { Self { sender, rt } } #[inline] pub(crate) async 
fn send( &self, - v: PipelineResultType, - ) -> Result<(), SendError> { - let len = match v { - PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeState(ref state) => { - state.get_tables().iter().map(|t| t.len()).sum() - } - }; + v: Arc, + ) -> Result<(), SendError>> { + let len = v.len(); self.sender.send_async(v).await?; self.rt.mark_rows_emitted(len as u64); Ok(()) } } -pub struct CountingReceiver { - receiver: Receiver, +pub(crate) struct CountingReceiver { + receiver: Receiver>, rt: Arc, } impl CountingReceiver { pub(crate) fn new( - receiver: Receiver, + receiver: Receiver>, rt: Arc, ) -> Self { Self { receiver, rt } } #[inline] - pub(crate) async fn recv(&mut self) -> Option { + pub(crate) async fn recv(&self) -> Option> { let v = self.receiver.recv_async().await.ok(); if let Some(ref v) = v { - let len = match v { - PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeState(state) => { - state.get_tables().iter().map(|t| t.len()).sum() - } - }; + let len = v.len(); self.rt.mark_rows_received(len as u64); } v diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index 8c38c81a4f..d76f5b1e2c 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -10,7 +10,7 @@ use tracing::instrument; use super::blocking_sink::{ BlockingSink, BlockingSinkState, BlockingSinkStatus, DynBlockingSinkState, }; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use crate::NUM_CPUS; enum AggregateState { Accumulating(Vec>), @@ -102,7 +102,7 @@ impl BlockingSink for AggregateSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { let mut all_parts = vec![]; for mut state in states { let state = state @@ -123,7 +123,7 @@ impl BlockingSink for AggregateSink { )?; let agged = Arc::new(concated.agg(&self.finalize_aggs, &self.finalize_group_by)?); let projected = Arc::new(agged.eval_expression_list(&self.final_projections)?); - Ok(Some(projected.into())) + Ok(Some(projected)) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index e3901770b1..d0416be997 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -9,9 +9,9 @@ use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - buffer::RowBasedBuffer, - channel::{create_channel, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultType}, + channel::{create_channel, create_ordering_aware_sender_channel, Receiver}, + dispatcher::dispatch, + pipeline::PipelineNode, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, JoinSnafu, TaskSet, }; @@ -58,7 +58,7 @@ pub trait BlockingSink: Send + Sync { fn finalize( &self, states: Vec>, - ) -> DaftResult>; + ) -> DaftResult>>; fn name(&self) -> &'static str; fn make_state(&self) -> DaftResult>; fn max_concurrency(&self) -> usize; @@ -91,21 +91,18 @@ impl BlockingSinkNode { #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] async fn run_worker( op: Arc, - input_receiver: Receiver, + input_receiver: Receiver<(usize, Arc)>, rt_context: Arc, ) -> DaftResult> { let span = info_span!("BlockingSink::Sink"); let compute_runtime = get_compute_runtime(); let state_wrapper = BlockingSinkState::new(op.make_state()?); - while let Some(morsel) = input_receiver.recv_async().await.ok() { + 
while let Ok((_, morsel)) = input_receiver.recv_async().await { let op = op.clone(); - let morsel = morsel.clone(); let span = span.clone(); let rt_context = rt_context.clone(); let state_wrapper = state_wrapper.clone(); - let fut = async move { - rt_context.in_span(&span, || op.sink(morsel.as_data(), &state_wrapper)) - }; + let fut = async move { rt_context.in_span(&span, || op.sink(&morsel, &state_wrapper)) }; let result = compute_runtime.await_on(fut).await??; match result { BlockingSinkStatus::NeedMoreInput => {} @@ -126,53 +123,14 @@ impl BlockingSinkNode { fn spawn_workers( op: Arc, - input_receivers: Vec>, + input_receivers: Vec)>>, task_set: &mut TaskSet>>, stats: Arc, ) { - for input_receiver in input_receivers { - task_set.spawn(Self::run_worker(op.clone(), input_receiver, stats.clone())); + for receiver in input_receivers { + task_set.spawn(Self::run_worker(op.clone(), receiver, stats.clone())); } } - - // Forwards input from the child to the workers in a round-robin fashion. - pub async fn forward_input_to_workers( - mut receiver: CountingReceiver, - worker_senders: Vec>, - morsel_size: Option, - ) -> DaftResult<()> { - let mut next_worker_idx = 0; - let mut send_to_next_worker = |data: PipelineResultType| { - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send_async(data) - }; - - let mut buffer = morsel_size.map(RowBasedBuffer::new); - while let Some(morsel) = receiver.recv().await { - if morsel.should_broadcast() { - for worker_sender in &worker_senders { - let _ = worker_sender.send_async(morsel.clone()).await; - } - } else if let Some(buffer) = buffer.as_mut() { - buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.pop_enough()? { - for r in ready { - let _ = send_to_next_worker(r.into()).await; - } - } - } else { - let _ = send_to_next_worker(morsel).await; - } - } - if let Some(buffer) = buffer.as_mut() { - // Clear all remaining morsels - if let Some(last_morsel) = buffer.pop_all()? 
{ - let _ = send_to_next_worker(last_morsel.into()).await; - } - } - Ok(()) - } } impl TreeDisplay for BlockingSinkNode { @@ -207,7 +165,7 @@ impl PipelineNode for BlockingSinkNode { &mut self, _maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result> { + ) -> crate::Result>> { let child = self.child.as_mut(); let child_results_receiver = child.start(false, runtime_handle)?; let child_results_receiver = @@ -219,13 +177,11 @@ impl PipelineNode for BlockingSinkNode { let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = { - let (tx, rx) = create_channel(num_workers); - (0..num_workers).map(|_| (tx.clone(), rx.clone())).unzip() - }; + let (input_senders, input_receivers) = + create_ordering_aware_sender_channel(false, num_workers); let morsel_size = runtime_handle.determine_morsel_size(op.morsel_size()); runtime_handle.spawn( - Self::forward_input_to_workers(child_results_receiver, input_senders, morsel_size), + dispatch(vec![child_results_receiver], input_senders, morsel_size), self.name(), ); runtime_handle.spawn( diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index a178d287e1..a48b7ba566 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -1,18 +1,16 @@ use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +use async_trait::async_trait; +use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::instrument; use super::streaming_sink::{ DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, }; -use crate::pipeline::PipelineResultType; +use crate::NUM_CPUS; -struct ConcatSinkState { - // The index of the last morsel of data that was received, which should be strictly non-decreasing. - pub curr_idx: usize, -} +struct ConcatSinkState {} impl DynStreamingSinkState for ConcatSinkState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self @@ -21,28 +19,19 @@ impl DynStreamingSinkState for ConcatSinkState { pub struct ConcatSink {} +#[async_trait] impl StreamingSink for ConcatSink { - /// Execute for the ConcatSink operator does not do any computation and simply returns the input data. - /// It only expects that the indices of the input data are strictly non-decreasing. - /// TODO(Colin): If maintain_order is false, technically we could accept any index. Make this optimization later. + /// By default, if the streaming_sink is called with maintain_order = true, input is distributed round-robin to the workers, + /// and the output is received in the same order. Therefore, the 'execute' method does not need to do anything. + /// If maintain_order = false, the input is distributed randomly to the workers, and the output is received in random order. #[instrument(skip_all, name = "ConcatSink::sink")] fn execute( &self, - index: usize, - input: &PipelineResultType, - state_handle: &StreamingSinkState, + _index: usize, + input: &Arc, + _state_handle: &StreamingSinkState, ) -> DaftResult { - state_handle.with_state_mut::(|state| { - // If the index is the same as the current index or one more than the current index, then we can accept the morsel. 
- if state.curr_idx == index || state.curr_idx + 1 == index { - state.curr_idx = index; - Ok(StreamingSinkOutput::NeedMoreInput(Some( - input.as_data().clone(), - ))) - } else { - Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index))) - } - }) + Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) } fn name(&self) -> &'static str { @@ -56,13 +45,12 @@ impl StreamingSink for ConcatSink { Ok(None) } - fn make_state(&self) -> Box { - Box::new(ConcatSinkState { curr_idx: 0 }) + async fn make_state(&self) -> Box { + Box::new(ConcatSinkState {}) } - /// Since the ConcatSink does not do any computation, it does not need to spawn multiple workers. fn max_concurrency(&self) -> usize { - 1 + *NUM_CPUS } /// The ConcatSink does not do any computation in the sink method, so no need to buffer. diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index c8c584d268..ba54a712ff 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -10,7 +10,7 @@ use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; use super::blocking_sink::{ BlockingSink, BlockingSinkState, BlockingSinkStatus, DynBlockingSinkState, }; -use crate::pipeline::PipelineResultType; +use crate::ProbeStateBridgeRef; enum ProbeTableState { Building { @@ -86,6 +86,7 @@ pub struct HashJoinBuildSink { key_schema: SchemaRef, projection: Vec, join_type: JoinType, + probe_state_bridge: ProbeStateBridgeRef, } impl HashJoinBuildSink { @@ -93,11 +94,13 @@ impl HashJoinBuildSink { key_schema: SchemaRef, projection: Vec, join_type: &JoinType, + probe_state_bridge: ProbeStateBridgeRef, ) -> DaftResult { Ok(Self { key_schema, projection, join_type: *join_type, + probe_state_bridge, }) } } @@ -121,7 +124,7 @@ impl BlockingSink for HashJoinBuildSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { assert_eq!(states.len(), 1); let mut state = states.into_iter().next().unwrap(); let probe_table_state = state @@ -130,10 +133,11 @@ impl BlockingSink for HashJoinBuildSink { .expect("State type mismatch"); probe_table_state.finalize()?; if let ProbeTableState::Done { probe_state } = probe_table_state { - Ok(Some(probe_state.clone().into())) + self.probe_state_bridge.set_probe_state(probe_state.clone()); } else { panic!("finalize should only be called after the probe table is built") } + Ok(None) } fn max_concurrency(&self) -> usize { diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index e65a4afb1c..0f9c6a90c9 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use async_trait::async_trait; use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::instrument; @@ -7,7 +8,6 @@ use tracing::instrument; use super::streaming_sink::{ DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, }; -use crate::pipeline::PipelineResultType; struct LimitSinkState { remaining: usize, @@ -39,16 +39,16 @@ impl LimitSink { } } +#[async_trait] impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( &self, index: usize, - input: &PipelineResultType, + input: &Arc, state_handle: &StreamingSinkState, ) -> DaftResult { assert_eq!(index, 0); - let 
input = input.as_data(); let input_num_rows = input.len(); state_handle.with_state_mut::(|state| { @@ -83,7 +83,7 @@ impl StreamingSink for LimitSink { Ok(None) } - fn make_state(&self) -> Box { + async fn make_state(&self) -> Box { Box::new(LimitSinkState::new(self.limit)) } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 23cefecf92..66fec77e3c 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use async_trait::async_trait; use common_error::DaftResult; use daft_core::{ prelude::{ @@ -18,7 +19,7 @@ use tracing::{info_span, instrument}; use super::streaming_sink::{ DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, }; -use crate::pipeline::PipelineResultType; +use crate::ProbeStateBridgeRef; struct IndexBitmapBuilder { mutable_bitmaps: Vec, @@ -69,42 +70,30 @@ impl IndexBitmap { } } -enum OuterHashJoinProbeState { - Building, - ReadyToProbe(Arc, Option), +struct OuterHashJoinProbeState { + probe_state: Arc, + index_bitmap_builder: Option, } impl OuterHashJoinProbeState { - fn initialize_probe_state(&mut self, probe_state: Arc, needs_bitmap: bool) { - let tables = probe_state.get_tables(); - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe( - probe_state.clone(), - if needs_bitmap { - Some(IndexBitmapBuilder::new(tables)) - } else { - None - }, - ); + fn new(probe_state: Arc, needs_bitmap: bool) -> Self { + let index_bitmap_builder = if needs_bitmap { + Some(IndexBitmapBuilder::new(probe_state.get_tables())) } else { - panic!("OuterHashJoinProbeState should only be in Building state when setting table") + None + }; + Self { + probe_state, + index_bitmap_builder, } } fn get_probe_state(&self) -> &ProbeState { - if let Self::ReadyToProbe(probe_state, _) = self { - probe_state - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") - } + &self.probe_state } fn get_bitmap_builder(&mut self) -> &mut Option { - if let Self::ReadyToProbe(_, bitmap_builder) = self { - bitmap_builder - } else { - panic!("get_bitmap can only be used during the ReadyToProbe Phase") - } + &mut self.index_bitmap_builder } } @@ -122,6 +111,7 @@ pub(crate) struct OuterHashJoinProbeSink { right_non_join_schema: SchemaRef, join_type: JoinType, output_schema: SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, } impl OuterHashJoinProbeSink { @@ -132,6 +122,7 @@ impl OuterHashJoinProbeSink { join_type: JoinType, common_join_keys: IndexSet, output_schema: &SchemaRef, + probe_state_bridge: ProbeStateBridgeRef, ) -> Self { let left_non_join_columns = left_schema .fields @@ -157,6 +148,7 @@ impl OuterHashJoinProbeSink { right_non_join_schema, join_type, output_schema: output_schema.clone(), + probe_state_bridge, } } @@ -315,14 +307,10 @@ impl OuterHashJoinProbeSink { let merged_bitmap = { let bitmaps = states.into_iter().map(|s| { - if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap) = s { - bitmap - .take() - .expect("bitmap should be present in outer join") - .build() - } else { - panic!("OuterHashJoinProbeState should be in ReadyToProbe state") - } + s.get_bitmap_builder() + .take() + .expect("bitmap should be set in outer join") + .build() }); bitmaps.fold(None, |acc, x| match acc { None => Some(x), @@ -359,48 +347,41 @@ impl OuterHashJoinProbeSink { } } +#[async_trait] impl StreamingSink for OuterHashJoinProbeSink { 
#[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] fn execute( &self, - idx: usize, - input: &PipelineResultType, + _idx: usize, + input: &Arc, state_handle: &StreamingSinkState, ) -> DaftResult { - match idx { - 0 => { - state_handle.with_state_mut::(|state| { - state.initialize_probe_state( - input.as_probe_state().clone(), - self.join_type == JoinType::Outer, - ); - }); - Ok(StreamingSinkOutput::NeedMoreInput(None)) + state_handle.with_state_mut::(|state| { + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); } - _ => state_handle.with_state_mut::(|state| { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); - } - let out = match self.join_type { - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - JoinType::Outer => self.probe_outer(input, state), - _ => unreachable!( - "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" - ), - }?; - Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) - }), - } + let out = match self.join_type { + JoinType::Left | JoinType::Right => self.probe_left_right(input, state), + JoinType::Outer => self.probe_outer(input, state), + _ => unreachable!( + "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" + ), + }?; + Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) + }) } fn name(&self) -> &'static str { "OuterHashJoinProbeSink" } - fn make_state(&self) -> Box { - Box::new(OuterHashJoinProbeState::Building) + async fn make_state(&self) -> Box { + let probe_state = self.probe_state_bridge.get_probe_state().await; + Box::new(OuterHashJoinProbeState::new( + probe_state, + self.join_type == JoinType::Outer, + )) } fn finalize( diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 0e5a2a7e41..14ffd9e422 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -8,7 +8,7 @@ use tracing::instrument; use super::blocking_sink::{ BlockingSink, BlockingSinkState, BlockingSinkStatus, DynBlockingSinkState, }; -use crate::{pipeline::PipelineResultType, NUM_CPUS}; +use crate::NUM_CPUS; enum SortState { Building(Vec>), @@ -71,7 +71,7 @@ impl BlockingSink for SortSink { fn finalize( &self, states: Vec>, - ) -> DaftResult> { + ) -> DaftResult>> { let mut parts = Vec::new(); for mut state in states { let state = state @@ -91,7 +91,7 @@ impl BlockingSink for SortSink { .collect::>(), )?; let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); - Ok(Some(sorted.into())) + Ok(Some(sorted)) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index ccee886b25..89b8ceb9fc 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -1,5 +1,6 @@ use std::sync::{Arc, Mutex}; +use async_trait::async_trait; use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; @@ -8,9 +9,12 @@ use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - buffer::RowBasedBuffer, - channel::{create_channel, create_ordering_aware_channel, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultType}, + channel::{ + 
create_channel, create_ordering_aware_receiver_channel, + create_ordering_aware_sender_channel, Receiver, Sender, + }, + dispatcher::dispatch, + pipeline::PipelineNode, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, }; @@ -50,6 +54,7 @@ pub enum StreamingSinkOutput { Finished(Option>), } +#[async_trait] pub trait StreamingSink: Send + Sync { /// Execute the StreamingSink operator on the morsel of input data, /// received from the child with the given index, @@ -57,7 +62,7 @@ pub trait StreamingSink: Send + Sync { fn execute( &self, index: usize, - input: &PipelineResultType, + input: &Arc, state_handle: &StreamingSinkState, ) -> DaftResult; @@ -71,7 +76,7 @@ pub trait StreamingSink: Send + Sync { fn name(&self) -> &'static str; /// Create a new worker-local state for this StreamingSink. - fn make_state(&self) -> Box; + async fn make_state(&self) -> Box; /// The maximum number of concurrent workers that can be spawned for this sink. /// Each worker will has its own StreamingSinkState. @@ -109,15 +114,15 @@ impl StreamingSinkNode { #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")] async fn run_worker( op: Arc, - input_receiver: Receiver<(usize, PipelineResultType)>, + input_receiver: Receiver<(usize, Arc)>, output_sender: Sender>, rt_context: Arc, ) -> DaftResult> { let span = info_span!("StreamingSink::Execute"); let compute_runtime = get_compute_runtime(); - let state_wrapper = StreamingSinkState::new(op.make_state()); + let state_wrapper = StreamingSinkState::new(op.make_state().await); let mut finished = false; - while let Some((idx, morsel)) = input_receiver.recv_async().await.ok() { + while let Ok((idx, morsel)) = input_receiver.recv_async().await { if finished { break; } @@ -163,7 +168,7 @@ impl StreamingSinkNode { fn spawn_workers( op: Arc, - input_receivers: Vec>, + input_receivers: Vec)>>, output_senders: Vec>>, task_set: &mut TaskSet>>, stats: Arc, @@ -177,48 +182,6 @@ impl StreamingSinkNode { )); } } - - // Forwards input from the children to the workers in a round-robin fashion. - // Always exhausts the input from one child before moving to the next. - async fn forward_input_to_workers( - receivers: Vec, - worker_senders: Vec>, - morsel_size: Option, - ) -> DaftResult<()> { - let mut next_worker_idx = 0; - let mut send_to_next_worker = |idx, data: PipelineResultType| { - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send_async((idx, data)) - }; - - for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = morsel_size.map(RowBasedBuffer::new); - while let Some(morsel) = receiver.recv().await { - if morsel.should_broadcast() { - for worker_sender in &worker_senders { - let _ = worker_sender.send_async((idx, morsel.clone())).await; - } - } else if let Some(buffer) = buffer.as_mut() { - buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.pop_enough()? { - for r in ready { - let _ = send_to_next_worker(idx, r.into()).await; - } - } - } else { - let _ = send_to_next_worker(idx, morsel).await; - } - } - if let Some(buffer) = buffer.as_mut() { - // Clear all remaining morsels - if let Some(last_morsel) = buffer.pop_all()? 
{ - let _ = send_to_next_worker(idx, last_morsel.into()).await; - } - } - } - Ok(()) - } } impl TreeDisplay for StreamingSinkNode { @@ -258,7 +221,7 @@ impl PipelineNode for StreamingSinkNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result> { + ) -> crate::Result>> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in &mut self.children { let child_result_channel = child.start(maintain_order, runtime_handle)?; @@ -274,17 +237,13 @@ impl PipelineNode for StreamingSinkNode { let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = if maintain_order { - (0..num_workers).map(|_| create_channel(1)).unzip() - } else { - let (tx, rx) = create_channel(num_workers); - (vec![tx; 1], vec![rx; num_workers]) - }; + let (input_senders, input_receivers) = + create_ordering_aware_sender_channel(maintain_order, num_workers); let (output_senders, mut output_receiver) = - create_ordering_aware_channel(maintain_order, num_workers); + create_ordering_aware_receiver_channel(maintain_order, num_workers); let morsel_size = runtime_handle.determine_morsel_size(op.morsel_size()); runtime_handle.spawn( - Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size), + dispatch(child_result_receivers, input_senders, morsel_size), self.name(), ); runtime_handle.spawn( @@ -299,7 +258,7 @@ impl PipelineNode for StreamingSinkNode { ); while let Some(morsel) = output_receiver.recv().await { - let _ = destination_sender.send(morsel.into()).await; + let _ = destination_sender.send(morsel).await; } let mut finished_states = Vec::with_capacity(num_workers); @@ -317,7 +276,7 @@ impl PipelineNode for StreamingSinkNode { }) .await??; if let Some(res) = finalized_result { - let _ = destination_sender.send(res.into()).await; + let _ = destination_sender.send(res).await; } Ok(()) }, diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 8e34edaddd..5b10cfc202 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -7,14 +7,14 @@ use futures::{stream::BoxStream, StreamExt}; use crate::{ channel::{create_channel, Receiver}, - pipeline::{PipelineNode, PipelineResultType}, + pipeline::PipelineNode, runtime_stats::{CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, }; -pub type SourceStream<'a> = BoxStream<'a, Arc>; +pub(crate) type SourceStream<'a> = BoxStream<'a, Arc>; -pub trait Source: Send + Sync { +pub(crate) trait Source: Send + Sync { fn name(&self) -> &'static str; fn get_data( &self, @@ -71,7 +71,7 @@ impl PipelineNode for SourceNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result> { + ) -> crate::Result>> { let mut source_stream = self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; @@ -81,7 +81,7 @@ impl PipelineNode for SourceNode { runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { - let _ = counting_sender.send(part.into()).await; + let _ = counting_sender.send(part).await; } Ok(()) }, From fc2609886ec753dafef005507ba1da03abe00731 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 25 Oct 2024 21:43:23 -0700 Subject: [PATCH 05/10] undo test changes --- tests/benchmarks/test_local_tpch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/benchmarks/test_local_tpch.py b/tests/benchmarks/test_local_tpch.py index f5fb87f35b..07165d9ebc 100644 --- a/tests/benchmarks/test_local_tpch.py +++ b/tests/benchmarks/test_local_tpch.py @@ -20,7 +20,7 @@ IS_CI = True if os.getenv("CI") else False -SCALE_FACTOR = 1.0 +SCALE_FACTOR = 0.2 ENGINES = ["native"] if IS_CI else ["native", "python"] NUM_PARTS = [1] if IS_CI else [1, 2] SOURCE_TYPES = ["in-memory"] if IS_CI else ["parquet", "in-memory"] @@ -93,7 +93,7 @@ def _get_df(tbl_name: str): return _get_df, num_parts -TPCH_QUESTIONS = [1] +TPCH_QUESTIONS = list(range(1, 11)) @pytest.mark.skipif( From 8767751aaf86de6ae507ed351f866f10c5069839 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 25 Oct 2024 21:48:51 -0700 Subject: [PATCH 06/10] why are my tests not running --- .../src/intermediate_ops/intermediate_op.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index f04f2a2cdf..086600ab88 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -217,12 +217,12 @@ impl PipelineNode for IntermediateNode { } let (destination_sender, destination_receiver) = create_channel(1); let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone()); - + let morsel_size = runtime_handle.determine_morsel_size(self.intermediate_op.morsel_size()); let (output_senders, mut output_receiver) = create_ordering_aware_receiver_channel(maintain_order, num_workers); let worker_sender = self.spawn_workers(num_workers, output_senders, runtime_handle, maintain_order); - let morsel_size = runtime_handle.determine_morsel_size(self.intermediate_op.morsel_size()); + runtime_handle.spawn( dispatch(child_result_receivers, worker_sender, morsel_size), self.intermediate_op.name(), From 3dd1ff3c91b624ea962c781564b246c505772126 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Sun, 27 Oct 2024 18:33:20 -0700 Subject: [PATCH 07/10] pipeline node start &self --- .../src/intermediate_ops/intermediate_op.rs | 4 ++-- src/daft-local-execution/src/pipeline.rs | 2 +- src/daft-local-execution/src/run.rs | 2 +- src/daft-local-execution/src/sinks/blocking_sink.rs | 5 ++--- src/daft-local-execution/src/sinks/streaming_sink.rs | 4 ++-- src/daft-local-execution/src/sources/source.rs | 2 +- 6 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 086600ab88..b4b8ef709f 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -202,13 +202,13 @@ impl PipelineNode for IntermediateNode { } fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result>> { let num_workers = *NUM_CPUS; let mut child_result_receivers = Vec::with_capacity(self.children.len()); - for child in &mut self.children { + for child in &self.children { let child_result_receiver = child.start(maintain_order, runtime_handle)?; child_result_receivers.push(CountingReceiver::new( child_result_receiver, diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index dd6e257267..13a85872d5 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ 
b/src/daft-local-execution/src/pipeline.rs @@ -39,7 +39,7 @@ pub(crate) trait PipelineNode: Sync + Send + TreeDisplay { fn children(&self) -> Vec<&dyn PipelineNode>; fn name(&self) -> &'static str; fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result>>; diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 9e4a564335..355fae5346 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -118,7 +118,7 @@ fn run_local( results_buffer_size: Option, ) -> DaftResult>> + Send>> { refresh_chrome_trace(); - let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; + let pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_current_thread() diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index d0416be997..414e09e6a9 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -162,12 +162,11 @@ impl PipelineNode for BlockingSinkNode { } fn start( - &mut self, + &self, _maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result>> { - let child = self.child.as_mut(); - let child_results_receiver = child.start(false, runtime_handle)?; + let child_results_receiver = self.child.start(false, runtime_handle)?; let child_results_receiver = CountingReceiver::new(child_results_receiver, self.runtime_stats.clone()); diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 89b8ceb9fc..66368d0e63 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -218,12 +218,12 @@ impl PipelineNode for StreamingSinkNode { } fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result>> { let mut child_result_receivers = Vec::with_capacity(self.children.len()); - for child in &mut self.children { + for child in &self.children { let child_result_channel = child.start(maintain_order, runtime_handle)?; let counting_receiver = CountingReceiver::new(child_result_channel, self.runtime_stats.clone()); diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 5b10cfc202..0e7e740c3c 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -68,7 +68,7 @@ impl PipelineNode for SourceNode { vec![] } fn start( - &mut self, + &self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result>> { From d7bfa8f812ffb349f43931113b49730358d77da0 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 1 Nov 2024 16:42:48 -0700 Subject: [PATCH 08/10] no async trait --- Cargo.lock | 1 - src/daft-local-execution/Cargo.toml | 1 - .../src/intermediate_ops/anti_semi_hash_join_probe.rs | 1 - .../src/intermediate_ops/inner_hash_join_probe.rs | 1 - .../src/intermediate_ops/intermediate_op.rs | 2 -- src/daft-local-execution/src/sinks/concat.rs | 2 -- src/daft-local-execution/src/sinks/limit.rs | 2 -- src/daft-local-execution/src/sinks/outer_hash_join_probe.rs | 2 -- 8 files changed, 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 
3e5b97b54a..b85d3fcb72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2086,7 +2086,6 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ - "async-trait", "common-daft-config", "common-display", "common-error", diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index b7908501ae..06c250ee00 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,5 +1,4 @@ [dependencies] -async-trait = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 4b7dc10b27..0360f0dab5 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -106,7 +106,6 @@ impl AntiSemiProbeOperator { } } -#[async_trait::async_trait] impl IntermediateOperator for AntiSemiProbeOperator { #[instrument(skip_all, name = "AntiSemiOperator::execute")] fn execute( diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index 609e898209..b1373adee7 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -160,7 +160,6 @@ impl InnerHashJoinProbeOperator { } } -#[async_trait::async_trait] impl IntermediateOperator for InnerHashJoinProbeOperator { #[instrument(skip_all, name = "InnerHashJoinOperator::execute")] fn execute( diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 97185d6ddf..e89baddea3 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use async_trait::async_trait; use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::{get_compute_runtime, RuntimeRef}; @@ -35,7 +34,6 @@ pub(crate) enum IntermediateOperatorResultType { pub(crate) type IntermediateOperatorResult = MaybeFuture, IntermediateOperatorResultType)>>; -#[async_trait] pub trait IntermediateOperator: Send + Sync { fn execute( &self, diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 029b19d67f..9fb86c435c 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use async_trait::async_trait; use common_runtime::RuntimeRef; use daft_micropartition::MicroPartition; use tracing::instrument; @@ -23,7 +22,6 @@ impl StreamingSinkState for ConcatSinkState { pub struct ConcatSink {} -#[async_trait] impl StreamingSink for ConcatSink { /// By default, if the streaming_sink is called with maintain_order = true, input is distributed round-robin to the workers, /// and the output is received in the same order. Therefore, the 'execute' method does not need to do anything. 
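// A minimal standalone sketch of the round-robin dispatch that the ConcatSink doc
// comment above refers to, assuming std::sync::mpsc and String as stand-ins for this
// crate's channel wrappers and Arc<MicroPartition>; it is illustrative only, not the
// crate's actual dispatcher. Reading worker outputs back in the same rotation is what
// lets maintain_order = true preserve the input order downstream.
use std::sync::mpsc::Sender;

fn dispatch_round_robin(morsels: Vec<String>, workers: &[Sender<String>]) {
    let mut next_worker = 0;
    for morsel in morsels {
        // Rotate through the worker senders in a fixed order.
        let _ = workers[next_worker].send(morsel);
        next_worker = (next_worker + 1) % workers.len();
    }
}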
diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 090b2acae8..66a4af8122 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use async_trait::async_trait; use common_runtime::RuntimeRef; use daft_micropartition::MicroPartition; use tracing::instrument; @@ -44,7 +43,6 @@ impl LimitSink { } } -#[async_trait] impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 76226f6041..11a7da951b 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use async_trait::async_trait; use common_error::DaftResult; use common_runtime::RuntimeRef; use daft_core::{ @@ -365,7 +364,6 @@ impl OuterHashJoinProbeSink { } } -#[async_trait] impl StreamingSink for OuterHashJoinProbeSink { #[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] fn execute( From 3ace4741802b7061eafaf2c99082671d43bca1a4 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 1 Nov 2024 17:23:17 -0700 Subject: [PATCH 09/10] fix unpivot tests --- src/daft-local-execution/src/sinks/limit.rs | 10 ++- .../src/sinks/outer_hash_join_probe.rs | 77 +++++++++++-------- tests/dataframe/test_unpivot.py | 22 +++--- 3 files changed, 65 insertions(+), 44 deletions(-) diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 66a4af8122..2baa5cd359 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -62,11 +62,17 @@ impl StreamingSink for LimitSink { match input_num_rows.cmp(remaining) { Less => { *remaining -= input_num_rows; - MaybeFuture::Immediate(Ok((state, StreamingSinkOutputType::NeedMoreInput(None)))) + MaybeFuture::Immediate(Ok(( + state, + StreamingSinkOutputType::NeedMoreInput(Some(input.clone())), + ))) } Equal => { *remaining = 0; - MaybeFuture::Immediate(Ok((state, StreamingSinkOutputType::Finished(None)))) + MaybeFuture::Immediate(Ok(( + state, + StreamingSinkOutputType::Finished(Some(input.clone())), + ))) } Greater => { let input = input.clone(); diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 11a7da951b..e8dd7d2f09 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -13,6 +13,7 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; use daft_table::{GrowableTable, ProbeState, Table}; +use futures::{stream, StreamExt}; use indexmap::IndexSet; use tracing::{info_span, instrument}; @@ -77,12 +78,12 @@ impl IndexBitmap { } } -enum OuterHashJoinProbeState { +enum OuterHashJoinState { Building(ProbeStateBridgeRef, bool), Probing(Arc, Option), } -impl OuterHashJoinProbeState { +impl OuterHashJoinState { async fn get_or_build_probe_state(&mut self) -> Arc { match self { Self::Building(bridge, needs_bitmap) => { @@ -96,16 +97,23 @@ impl OuterHashJoinProbeState { } } - fn get_bitmap_builder(&mut self) -> &mut Option { - if let Self::Probing(_, builder) = self { - builder - } else { - panic!("Bitmap builder should be set in Probing state") + async fn get_or_build_bitmap(&mut 
self) -> &mut Option { + match self { + Self::Building(bridge, _) => { + let probe_state = bridge.get_probe_state().await; + let builder = IndexBitmapBuilder::new(probe_state.get_tables()); + *self = Self::Probing(probe_state, Some(builder)); + match self { + Self::Probing(_, builder) => builder, + _ => unreachable!(), + } + } + Self::Probing(_, builder) => builder, } } } -impl StreamingSinkState for OuterHashJoinProbeState { +impl StreamingSinkState for OuterHashJoinState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -165,7 +173,7 @@ impl OuterHashJoinProbeSink { fn probe_left_right( input: &Arc, - probe_state: Arc, + probe_state: &ProbeState, join_type: JoinType, probe_on: &[ExprRef], common_join_keys: &[String], @@ -234,7 +242,7 @@ impl OuterHashJoinProbeSink { fn probe_outer( input: &Arc, - probe_state: Arc, + probe_state: &ProbeState, bitmap_builder: &mut IndexBitmapBuilder, probe_on: &[ExprRef], common_join_keys: &[String], @@ -305,7 +313,7 @@ impl OuterHashJoinProbeSink { .next() .expect("at least one state should be present") .as_any_mut() - .downcast_mut::() + .downcast_mut::() .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); let tables = first_state .get_or_build_probe_state() @@ -313,24 +321,30 @@ impl OuterHashJoinProbeSink { .get_tables() .clone(); let first_bitmap = first_state - .get_bitmap_builder() + .get_or_build_bitmap() + .await .take() - .expect("bitmap should be set in outer join") + .expect("bitmap should be set") .build(); let merged_bitmap = { - let bitmaps = std::iter::once(first_bitmap).chain(states_iter.map(|s| { - let state = s - .as_any_mut() - .downcast_mut::() - .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); - state - .get_bitmap_builder() - .take() - .expect("bitmap should be set in outer join") - .build() - })); - bitmaps.fold(None, |acc, x| match acc { + let bitmaps = stream::once(async move { first_bitmap }) + .chain(stream::iter(states_iter).then(|s| async move { + let state = s + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + state + .get_or_build_bitmap() + .await + .take() + .expect("bitmap should be set") + .build() + })) + .collect::>() + .await; + + bitmaps.into_iter().fold(None, |acc, x| match acc { None => Some(x), Some(acc) => Some(acc.merge(&x)), }) @@ -389,13 +403,13 @@ impl StreamingSink for OuterHashJoinProbeSink { let fut = runtime_ref.spawn(async move { let outer_join_state = state .as_any_mut() - .downcast_mut::() + .downcast_mut::() .expect("OuterHashJoinProbeSink should have OuterHashJoinProbeState"); let probe_state = outer_join_state.get_or_build_probe_state().await; let out = match join_type { JoinType::Left | JoinType::Right => Self::probe_left_right( &input, - probe_state, + &probe_state, join_type, &probe_on, &common_join_keys, @@ -404,12 +418,13 @@ impl StreamingSink for OuterHashJoinProbeSink { ), JoinType::Outer => { let bitmap_builder = outer_join_state - .get_bitmap_builder() + .get_or_build_bitmap() + .await .as_mut() - .expect("bitmap builder should be set in Outer join"); + .expect("bitmap should be set"); Self::probe_outer( &input, - probe_state, + &probe_state, bitmap_builder, &probe_on, &common_join_keys, @@ -431,7 +446,7 @@ impl StreamingSink for OuterHashJoinProbeSink { } fn make_state(&self) -> Box { - Box::new(OuterHashJoinProbeState::Building( + Box::new(OuterHashJoinState::Building( self.probe_state_bridge.clone(), self.join_type == JoinType::Outer, )) diff --git 
a/tests/dataframe/test_unpivot.py b/tests/dataframe/test_unpivot.py index ef20c453cc..9a929bd594 100644 --- a/tests/dataframe/test_unpivot.py +++ b/tests/dataframe/test_unpivot.py @@ -11,7 +11,7 @@ def set_default_morsel_size(): yield -@pytest.mark.parametrize("n_partitions", [1, 2, 4]) +@pytest.mark.parametrize("n_partitions", [2]) def test_unpivot(make_df, n_partitions): df = make_df( { @@ -23,7 +23,7 @@ def test_unpivot(make_df, n_partitions): ) df = df.unpivot("id", ["a", "b"]) - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { @@ -47,7 +47,7 @@ def test_unpivot_no_values(make_df, n_partitions): ) df = df.unpivot("id") - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { @@ -71,7 +71,7 @@ def test_unpivot_different_types(make_df, n_partitions): ) df = df.unpivot("id", ["a", "b"]) - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { @@ -110,7 +110,7 @@ def test_unpivot_nulls(make_df, n_partitions): ) df = df.unpivot("id", ["a", "b"]) - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { @@ -134,7 +134,7 @@ def test_unpivot_null_column(make_df, n_partitions): ) df = df.unpivot("id", ["a", "b"]) - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { @@ -159,7 +159,7 @@ def test_unpivot_multiple_ids(make_df, n_partitions): ) df = df.unpivot(["id1", "id2"], ["a", "b"]) - df = df.sort("id1") + df = df.sort(["id1", "id2", "variable"]) df = df.collect() expected = { @@ -206,13 +206,13 @@ def test_unpivot_expr(make_df, n_partitions): ) df = df.unpivot("id", ["a", "b", (col("a") + col("b")).alias("a_plus_b")]) - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { "id": ["x", "x", "x", "y", "y", "y", "z", "z", "z"], - "variable": ["a", "b", "a_plus_b", "a", "b", "a_plus_b", "a", "b", "a_plus_b"], - "value": [1, 2, 3, 3, 4, 7, 5, 6, 11], + "variable": ["a", "a_plus_b", "b", "a", "a_plus_b", "b", "a", "a_plus_b", "b"], + "value": [1, 3, 2, 3, 7, 4, 5, 11, 6], } assert df.to_pydict() == expected @@ -251,7 +251,7 @@ def test_unpivot_empty_partition(make_df): df = df.into_partitions(4) df = df.unpivot("id", ["a", "b"]) - df = df.sort("id") + df = df.sort(["id", "variable"]) df = df.collect() expected = { From 7deb3194b80c6b9cf8ee16998cde2017252b568e Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 4 Nov 2024 17:07:13 -0800 Subject: [PATCH 10/10] fix tests --- tests/cookbook/conftest.py | 6 ------ tests/dataframe/test_aggregations.py | 6 ------ tests/dataframe/test_approx_count_distinct.py | 7 ------- tests/dataframe/test_approx_percentiles_aggregations.py | 7 ------- tests/dataframe/test_concat.py | 8 -------- tests/dataframe/test_distinct.py | 7 ------- tests/dataframe/test_joins.py | 6 ------ tests/dataframe/test_map_groups.py | 6 ------ tests/dataframe/test_pivot.py | 8 -------- tests/dataframe/test_sort.py | 7 ------- tests/dataframe/test_stddev.py | 6 ------ 11 files changed, 74 deletions(-) diff --git a/tests/cookbook/conftest.py b/tests/cookbook/conftest.py index d70108af04..f62b74203e 100644 --- a/tests/cookbook/conftest.py +++ b/tests/cookbook/conftest.py @@ -51,9 +51,3 @@ def repartition_nparts(request): partitions that the test case should repartition its dataset into for testing """ return request.param - - -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield diff --git 
a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py index 4f5075508e..f942410d77 100644 --- a/tests/dataframe/test_aggregations.py +++ b/tests/dataframe/test_aggregations.py @@ -15,12 +15,6 @@ from tests.utils import sort_arrow_table -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) def test_agg_global(make_df, repartition_nparts, with_morsel_size): daft_df = make_df( diff --git a/tests/dataframe/test_approx_count_distinct.py b/tests/dataframe/test_approx_count_distinct.py index e644aef22c..ac664009fb 100644 --- a/tests/dataframe/test_approx_count_distinct.py +++ b/tests/dataframe/test_approx_count_distinct.py @@ -4,13 +4,6 @@ import daft from daft import col - -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - TESTS = [ [[], 0], [[None] * 10, 0], diff --git a/tests/dataframe/test_approx_percentiles_aggregations.py b/tests/dataframe/test_approx_percentiles_aggregations.py index d993f428cf..d64f1a2381 100644 --- a/tests/dataframe/test_approx_percentiles_aggregations.py +++ b/tests/dataframe/test_approx_percentiles_aggregations.py @@ -4,16 +4,9 @@ import pyarrow as pa import pytest -import daft from daft import col -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) @pytest.mark.parametrize("percentiles_expected", [(0.5, [2.0]), ([0.5], [[2.0]]), ([0.5, 0.5], [[2.0, 2.0]])]) def test_approx_percentiles_global(make_df, repartition_nparts, percentiles_expected): diff --git a/tests/dataframe/test_concat.py b/tests/dataframe/test_concat.py index 8ed89bb387..3e18e1e80b 100644 --- a/tests/dataframe/test_concat.py +++ b/tests/dataframe/test_concat.py @@ -2,14 +2,6 @@ import pytest -import daft - - -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - def test_simple_concat(make_df, with_morsel_size): df1 = make_df({"foo": [1, 2, 3]}) diff --git a/tests/dataframe/test_distinct.py b/tests/dataframe/test_distinct.py index ac1c2a3c8d..d09e0cefcd 100644 --- a/tests/dataframe/test_distinct.py +++ b/tests/dataframe/test_distinct.py @@ -3,17 +3,10 @@ import pyarrow as pa import pytest -import daft from daft.datatype import DataType from tests.utils import sort_arrow_table -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - @pytest.mark.parametrize("repartition_nparts", [1, 2, 5]) def test_distinct_with_nulls(make_df, repartition_nparts, with_morsel_size): daft_df = make_df( diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 97506cb09e..ceac56283f 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -10,12 +10,6 @@ from tests.utils import sort_arrow_table -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with context.execution_config_ctx(default_morsel_size=1): - yield - - def skip_invalid_join_strategies(join_strategy, join_type): if context.get_context().daft_execution_config.enable_native_executor is True: if join_strategy not in [None, "hash"]: diff --git 
a/tests/dataframe/test_map_groups.py b/tests/dataframe/test_map_groups.py index 424a91efbe..0024c02797 100644 --- a/tests/dataframe/test_map_groups.py +++ b/tests/dataframe/test_map_groups.py @@ -5,12 +5,6 @@ import daft -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) def test_map_groups(make_df, repartition_nparts, with_morsel_size): daft_df = make_df( diff --git a/tests/dataframe/test_pivot.py b/tests/dataframe/test_pivot.py index 4eb65cc7f4..7c2b2d45a3 100644 --- a/tests/dataframe/test_pivot.py +++ b/tests/dataframe/test_pivot.py @@ -1,13 +1,5 @@ import pytest -import daft - - -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - @pytest.mark.parametrize("repartition_nparts", [1, 2, 5]) def test_pivot(make_df, repartition_nparts, with_morsel_size): diff --git a/tests/dataframe/test_sort.py b/tests/dataframe/test_sort.py index c59668057f..a6e325de7e 100644 --- a/tests/dataframe/test_sort.py +++ b/tests/dataframe/test_sort.py @@ -5,7 +5,6 @@ import pyarrow as pa import pytest -import daft from daft.datatype import DataType from daft.errors import ExpressionTypeError @@ -14,12 +13,6 @@ ### -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - def test_disallowed_sort_null(make_df): df = make_df({"A": [None, None]}) diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py index a8bbda9401..0c24cbc31c 100644 --- a/tests/dataframe/test_stddev.py +++ b/tests/dataframe/test_stddev.py @@ -8,12 +8,6 @@ import daft -@pytest.fixture(scope="function", autouse=True) -def set_default_morsel_size(): - with daft.context.execution_config_ctx(default_morsel_size=1): - yield - - def grouped_stddev(rows) -> Tuple[List[Any], List[Any]]: map = {} for key, data in rows: