From e148248dae8af90c8993d2ec6b2f471521c0a7f2 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 16 Dec 2024 22:02:49 -0800 Subject: [PATCH] feat(swordfish): Optimize grouped aggregations (#3534) Optimize swordfish grouped aggs for high cardinality groups ### Approach There's 3 strategies for grouped aggs: 1. Partition each input morsel into `N` partitions, then do a partial agg. (good for high cardinality). 2. Do a partial agg, then partition into `N` partitions. (good for low cardinality). Can be optimized with https://github.com/Eventual-Inc/Daft/pull/3556 3. Partition only, no partial agg. (only for map_groups, which has no partial agg). ### Notes on alternative approaches - Distributing partitions across workers (i.e. having each worker being responsible for accumulating only one partition) is much slower for low cardinality aggs (TPCH Q1 would have been 1.5x slower). This is because most of the work will end up being on only a few workers, reducing parallelism. - Simply partitioning the input and then only aggregating at the end works well with higher cardinality, but low cardinality takes a hit. (TPCH Q1 would have been 2.5x slower). - Probe Table approach was much slower, due to many calls to the multi-table dyn comparator. It was also much more complex to implement. ### Benchmarks [MrPowers Benchmarks](https://github.com/MrPowers/mrpowers-benchmarks) results (seconds, lower is better). | Query | this PR | Pyrunner | Current swordfish | |-------|---------|----------|-------------------| | q1 | 0.285720| 0.768858 | 0.356499 | | q2 | 4.780064| 6.122199 | 53.340565 | | q3 | 2.201079| 3.922857 | 16.935125 | | q4 | 0.313106| 0.545192 | 0.335541 | | q5 | 1.618228| 2.889354 | 10.665339 | | q7 | 2.087872| 3.856998 | 16.072660 | | q10 | 6.306756| 8.173738 | 53.800501 | --------- Co-authored-by: EC2 Default User Co-authored-by: Colin Ho Co-authored-by: EC2 Default User --- daft/context.py | 8 +- daft/daft/__init__.pyi | 6 + src/common/daft-config/src/lib.rs | 4 + src/common/daft-config/src/python.rs | 19 + .../src/intermediate_ops/aggregate.rs | 56 --- .../src/intermediate_ops/mod.rs | 1 - src/daft-local-execution/src/lib.rs | 2 + src/daft-local-execution/src/pipeline.rs | 89 +--- .../src/sinks/aggregate.rs | 62 ++- .../src/sinks/grouped_aggregate.rs | 438 ++++++++++++++++++ src/daft-local-execution/src/sinks/mod.rs | 1 + tests/dataframe/test_intersect.py | 4 +- tests/sql/test_aggs.py | 1 + 13 files changed, 537 insertions(+), 154 deletions(-) delete mode 100644 src/daft-local-execution/src/intermediate_ops/aggregate.rs create mode 100644 src/daft-local-execution/src/sinks/grouped_aggregate.rs diff --git a/daft/context.py b/daft/context.py index 0b071a431b..dbec115099 100644 --- a/daft/context.py +++ b/daft/context.py @@ -344,6 +344,8 @@ def set_execution_config( csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + partial_aggregation_threshold: int | None = None, + high_cardinality_aggregation_threshold: float | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -384,7 +386,9 @@ def set_execution_config( parquet_inflation_factor: Inflation Factor of parquet files (In-Memory-Size / File-Size) ratio. Defaults to 3.0 csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5 - shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. + shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations on the Ray Runner. Defaults to 200, unless the number of input partitions is less than 200. + partial_aggregation_threshold: Threshold for performing partial aggregations on the Native Runner. Defaults to 10000 rows. + high_cardinality_aggregation_threshold: Threshold selectivity for performing high cardinality aggregations on the Native Runner. Defaults to 0.8. read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB enable_aqe: Enables Adaptive Query Execution, Defaults to False enable_native_executor: Enables the native executor, Defaults to False @@ -413,6 +417,8 @@ def set_execution_config( csv_target_filesize=csv_target_filesize, csv_inflation_factor=csv_inflation_factor, shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions, + partial_aggregation_threshold=partial_aggregation_threshold, + high_cardinality_aggregation_threshold=high_cardinality_aggregation_threshold, read_sql_partition_size_bytes=read_sql_partition_size_bytes, enable_aqe=enable_aqe, enable_native_executor=enable_native_executor, diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 47e88d9afd..da292d2df1 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1730,6 +1730,8 @@ class PyDaftExecutionConfig: csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + partial_aggregation_threshold: int | None = None, + high_cardinality_aggregation_threshold: float | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -1765,6 +1767,10 @@ class PyDaftExecutionConfig: @property def shuffle_aggregation_default_partitions(self) -> int: ... @property + def partial_aggregation_threshold(self) -> int: ... + @property + def high_cardinality_aggregation_threshold(self) -> float: ... + @property def read_sql_partition_size_bytes(self) -> int: ... @property def enable_aqe(self) -> bool: ... diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 3ffa4ea158..590fd5cf6c 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -44,6 +44,8 @@ pub struct DaftExecutionConfig { pub csv_target_filesize: usize, pub csv_inflation_factor: f64, pub shuffle_aggregation_default_partitions: usize, + pub partial_aggregation_threshold: usize, + pub high_cardinality_aggregation_threshold: f64, pub read_sql_partition_size_bytes: usize, pub enable_aqe: bool, pub enable_native_executor: bool, @@ -70,6 +72,8 @@ impl Default for DaftExecutionConfig { csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, shuffle_aggregation_default_partitions: 200, + partial_aggregation_threshold: 10000, + high_cardinality_aggregation_threshold: 0.8, read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB enable_aqe: false, enable_native_executor: false, diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index aceefd63d2..3228263b07 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -90,6 +90,8 @@ impl PyDaftExecutionConfig { csv_target_filesize: Option, csv_inflation_factor: Option, shuffle_aggregation_default_partitions: Option, + partial_aggregation_threshold: Option, + high_cardinality_aggregation_threshold: Option, read_sql_partition_size_bytes: Option, enable_aqe: Option, enable_native_executor: Option, @@ -146,6 +148,13 @@ impl PyDaftExecutionConfig { { config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions; } + if let Some(partial_aggregation_threshold) = partial_aggregation_threshold { + config.partial_aggregation_threshold = partial_aggregation_threshold; + } + if let Some(high_cardinality_aggregation_threshold) = high_cardinality_aggregation_threshold + { + config.high_cardinality_aggregation_threshold = high_cardinality_aggregation_threshold; + } if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes { config.read_sql_partition_size_bytes = read_sql_partition_size_bytes; } @@ -245,6 +254,16 @@ impl PyDaftExecutionConfig { Ok(self.config.shuffle_aggregation_default_partitions) } + #[getter] + fn get_partial_aggregation_threshold(&self) -> PyResult { + Ok(self.config.partial_aggregation_threshold) + } + + #[getter] + fn get_high_cardinality_aggregation_threshold(&self) -> PyResult { + Ok(self.config.high_cardinality_aggregation_threshold) + } + #[getter] fn get_read_sql_partition_size_bytes(&self) -> PyResult { Ok(self.config.read_sql_partition_size_bytes) diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs deleted file mode 100644 index cb9344b160..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::sync::Arc; - -use common_runtime::RuntimeRef; -use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; -use tracing::instrument; - -use super::intermediate_op::{ - IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator, - IntermediateOperatorResult, -}; - -struct AggParams { - agg_exprs: Vec, - group_by: Vec, -} - -pub struct AggregateOperator { - params: Arc, -} - -impl AggregateOperator { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { - Self { - params: Arc::new(AggParams { - agg_exprs, - group_by, - }), - } - } -} - -impl IntermediateOperator for AggregateOperator { - #[instrument(skip_all, name = "AggregateOperator::execute")] - fn execute( - &self, - input: Arc, - state: Box, - runtime: &RuntimeRef, - ) -> IntermediateOpExecuteResult { - let params = self.params.clone(); - runtime - .spawn(async move { - let out = input.agg(¶ms.agg_exprs, ¶ms.group_by)?; - Ok(( - state, - IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))), - )) - }) - .into() - } - - fn name(&self) -> &'static str { - "AggregateOperator" - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index f0af96b763..ca01bafaa4 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,5 +1,4 @@ pub mod actor_pool_project; -pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod cross_join; pub mod explode; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index ec6020eaef..b1b5d99bb0 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,4 +1,6 @@ #![feature(let_chains)] +#![feature(option_get_or_insert_default)] + mod buffer; mod channel; mod dispatcher; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 0c75de9be6..8881cb37de 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -10,7 +10,7 @@ use daft_core::{ prelude::{Schema, SchemaRef}, utils::supertype, }; -use daft_dsl::{col, join::get_common_join_keys, Expr}; +use daft_dsl::{col, join::get_common_join_keys}; use daft_local_plan::{ ActorPoolProject, Concat, CrossJoin, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, MonotonicallyIncreasingId, PhysicalWrite, Pivot, @@ -18,7 +18,6 @@ use daft_local_plan::{ }; use daft_logical_plan::{stats::StatsState, JoinType}; use daft_micropartition::{partitioning::PartitionSet, MicroPartition}; -use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages}; use daft_scan::ScanTaskRef; use daft_writers::make_physical_writer_factory; use indexmap::IndexSet; @@ -27,7 +26,7 @@ use snafu::ResultExt; use crate::{ channel::Receiver, intermediate_ops::{ - actor_pool_project::ActorPoolProjectOperator, aggregate::AggregateOperator, + actor_pool_project::ActorPoolProjectOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, cross_join::CrossJoinOperator, explode::ExplodeOperator, filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, @@ -38,6 +37,7 @@ use crate::{ blocking_sink::BlockingSinkNode, concat::ConcatSink, cross_join_collect::CrossJoinCollectSink, + grouped_aggregate::GroupedAggregateSink, hash_join_build::HashJoinBuildSink, limit::LimitSink, monotonically_increasing_id::MonotonicallyIncreasingIdSink, @@ -170,42 +170,13 @@ pub fn physical_plan_to_pipeline( schema, .. }) => { - let aggregations = aggregations - .iter() - .map(extract_agg_expr) - .collect::>>() - .with_context(|_| PipelineCreationSnafu { - plan_name: physical_plan.name(), - })?; - - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(&aggregations, schema, &[]); - let first_stage_agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; - let post_first_agg_node = - IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); - let second_stage_node = - BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + let agg_sink = AggregateSink::new(aggregations, schema).with_context(|_| { + PipelineCreationSnafu { + plan_name: physical_plan.name(), + } + })?; + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::HashAggregate(HashAggregate { input, @@ -214,48 +185,12 @@ pub fn physical_plan_to_pipeline( schema, .. }) => { - let aggregations = aggregations - .iter() - .map(extract_agg_expr) - .collect::>>() + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + let agg_sink = GroupedAggregateSink::new(aggregations, group_by, schema, cfg) .with_context(|_| PipelineCreationSnafu { plan_name: physical_plan.name(), })?; - - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(&aggregations, schema, group_by); - let child_node = physical_plan_to_pipeline(input, psets, cfg)?; - let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { - let agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - ( - IntermediateNode::new(Arc::new(agg_op), vec![child_node]).boxed(), - &group_by.iter().map(|e| col(e.name())).collect(), - ) - } else { - (child_node, group_by) - }; - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - let second_stage_node = - BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::Unpivot(Unpivot { input, diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index 4ac74c8fb5..148fb7e2e4 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -2,8 +2,10 @@ use std::sync::Arc; use common_error::DaftResult; use common_runtime::RuntimeRef; -use daft_dsl::ExprRef; +use daft_core::prelude::SchemaRef; +use daft_dsl::{Expr, ExprRef}; use daft_micropartition::MicroPartition; +use daft_physical_plan::extract_agg_expr; use tracing::instrument; use super::blocking_sink::{ @@ -44,8 +46,9 @@ impl BlockingSinkState for AggregateState { } struct AggParams { - agg_exprs: Vec, - group_by: Vec, + sink_agg_exprs: Vec, + finalize_agg_exprs: Vec, + final_projections: Vec, } pub struct AggregateSink { @@ -53,13 +56,31 @@ pub struct AggregateSink { } impl AggregateSink { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { - Self { + pub fn new(aggregations: &[ExprRef], schema: &SchemaRef) -> DaftResult { + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::>>()?; + let (sink_aggs, finalize_aggs, final_projections) = + daft_physical_plan::populate_aggregation_stages(&aggregations, schema, &[]); + let sink_agg_exprs = sink_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect(); + let finalize_agg_exprs = finalize_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect(); + + Ok(Self { agg_sink_params: Arc::new(AggParams { - agg_exprs, - group_by, + sink_agg_exprs, + finalize_agg_exprs, + final_projections, }), - } + }) } } @@ -69,14 +90,20 @@ impl BlockingSink for AggregateSink { &self, input: Arc, mut state: Box, - _runtime: &RuntimeRef, + runtime: &RuntimeRef, ) -> BlockingSinkSinkResult { - state - .as_any_mut() - .downcast_mut::() - .expect("AggregateSink should have AggregateState") - .push(input); - Ok(BlockingSinkStatus::NeedMoreInput(state)).into() + let params = self.agg_sink_params.clone(); + runtime + .spawn(async move { + let agg_state = state + .as_any_mut() + .downcast_mut::() + .expect("AggregateSink should have AggregateState"); + let agged = Arc::new(input.agg(¶ms.sink_agg_exprs, &[])?); + agg_state.push(agged); + Ok(BlockingSinkStatus::NeedMoreInput(state)) + }) + .into() } #[instrument(skip_all, name = "AggregateSink::finalize")] @@ -96,8 +123,9 @@ impl BlockingSink for AggregateSink { .finalize() }); let concated = MicroPartition::concat(all_parts)?; - let agged = Arc::new(concated.agg(¶ms.agg_exprs, ¶ms.group_by)?); - Ok(Some(agged)) + let agged = concated.agg(¶ms.finalize_agg_exprs, &[])?; + let projected = agged.eval_expression_list(¶ms.final_projections)?; + Ok(Some(Arc::new(projected))) }) .into() } diff --git a/src/daft-local-execution/src/sinks/grouped_aggregate.rs b/src/daft-local-execution/src/sinks/grouped_aggregate.rs new file mode 100644 index 0000000000..c240480eed --- /dev/null +++ b/src/daft-local-execution/src/sinks/grouped_aggregate.rs @@ -0,0 +1,438 @@ +use std::{ + collections::HashSet, + sync::{Arc, Mutex}, +}; + +use common_daft_config::DaftExecutionConfig; +use common_error::DaftResult; +use common_runtime::RuntimeRef; +use daft_core::prelude::SchemaRef; +use daft_dsl::{col, Expr, ExprRef}; +use daft_micropartition::MicroPartition; +use daft_physical_plan::extract_agg_expr; +use tracing::{info_span, instrument, Instrument}; + +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; +use crate::NUM_CPUS; + +#[derive(Clone)] +enum AggStrategy { + // TODO: This would probably benefit from doing sharded aggs. + AggThenPartition, + PartitionThenAgg(usize), + PartitionOnly, +} + +impl AggStrategy { + fn execute_strategy( + &self, + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + ) -> DaftResult<()> { + match self { + Self::AggThenPartition => Self::execute_agg_then_partition(inner_states, input, params), + Self::PartitionThenAgg(threshold) => { + Self::execute_partition_then_agg(inner_states, input, params, *threshold) + } + Self::PartitionOnly => Self::execute_partition_only(inner_states, input, params), + } + } + + fn execute_agg_then_partition( + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + ) -> DaftResult<()> { + let agged = input.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + let partitioned = + agged.partition_by_hash(params.final_group_by.as_slice(), inner_states.len())?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + state.partially_aggregated.push(p); + } + Ok(()) + } + + fn execute_partition_then_agg( + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + partial_agg_threshold: usize, + ) -> DaftResult<()> { + let partitioned = + input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + if state.unaggregated_size + p.len() >= partial_agg_threshold { + let unaggregated = std::mem::take(&mut state.unaggregated); + let aggregated = + MicroPartition::concat(unaggregated.iter().chain(std::iter::once(&p)))?.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + state.partially_aggregated.push(aggregated); + state.unaggregated_size = 0; + } else { + state.unaggregated_size += p.len(); + state.unaggregated.push(p); + } + } + Ok(()) + } + + fn execute_partition_only( + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + ) -> DaftResult<()> { + let partitioned = + input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + state.unaggregated_size += p.len(); + state.unaggregated.push(p); + } + Ok(()) + } +} + +#[derive(Default)] +struct SinglePartitionAggregateState { + partially_aggregated: Vec, + unaggregated: Vec, + unaggregated_size: usize, +} + +enum GroupedAggregateState { + Accumulating { + inner_states: Vec>, + strategy: Option, + partial_agg_threshold: usize, + high_cardinality_threshold_ratio: f64, + }, + Done, +} + +impl GroupedAggregateState { + fn new( + num_partitions: usize, + partial_agg_threshold: usize, + high_cardinality_threshold_ratio: f64, + ) -> Self { + let inner_states = (0..num_partitions).map(|_| None).collect::>(); + Self::Accumulating { + inner_states, + strategy: None, + partial_agg_threshold, + high_cardinality_threshold_ratio, + } + } + + fn push( + &mut self, + input: Arc, + params: &GroupedAggregateParams, + global_strategy_lock: &Arc>>, + ) -> DaftResult<()> { + let Self::Accumulating { + ref mut inner_states, + strategy, + partial_agg_threshold, + high_cardinality_threshold_ratio, + } = self + else { + panic!("GroupedAggregateSink should be in Accumulating state"); + }; + + // If we have determined a strategy, execute it. + if let Some(strategy) = strategy { + strategy.execute_strategy(inner_states, input, params)?; + } else { + // Otherwise, determine the strategy and execute + let decided_strategy = Self::determine_agg_strategy( + &input, + params, + *high_cardinality_threshold_ratio, + *partial_agg_threshold, + strategy, + global_strategy_lock, + )?; + decided_strategy.execute_strategy(inner_states, input, params)?; + } + Ok(()) + } + + fn determine_agg_strategy( + input: &Arc, + params: &GroupedAggregateParams, + high_cardinality_threshold_ratio: f64, + partial_agg_threshold: usize, + local_strategy_cache: &mut Option, + global_strategy_lock: &Arc>>, + ) -> DaftResult { + let mut global_strategy = global_strategy_lock.lock().unwrap(); + // If some other worker has determined a strategy, use that. + if let Some(global_strat) = global_strategy.as_ref() { + *local_strategy_cache = Some(global_strat.clone()); + return Ok(global_strat.clone()); + } + + // Else determine the strategy. + let groupby = input.eval_expression_list(params.group_by.as_slice())?; + + let groupkey_hashes = groupby + .get_tables()? + .iter() + .map(|t| t.hash_rows()) + .collect::>>()?; + let estimated_num_groups = groupkey_hashes + .iter() + .flatten() + .collect::>() + .len(); + + let decided_strategy = if estimated_num_groups as f64 / input.len() as f64 + >= high_cardinality_threshold_ratio + { + AggStrategy::PartitionThenAgg(partial_agg_threshold) + } else { + AggStrategy::AggThenPartition + }; + + *local_strategy_cache = Some(decided_strategy.clone()); + *global_strategy = Some(decided_strategy.clone()); + Ok(decided_strategy) + } + + fn finalize(&mut self) -> Vec> { + let res = if let Self::Accumulating { + ref mut inner_states, + .. + } = self + { + std::mem::take(inner_states) + } else { + panic!("GroupedAggregateSink should be in Accumulating state"); + }; + *self = Self::Done; + res + } +} + +impl BlockingSinkState for GroupedAggregateState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +struct GroupedAggregateParams { + // The original aggregations and group by expressions + original_aggregations: Vec, + group_by: Vec, + // The expressions for to be used for partial aggregation + partial_agg_exprs: Vec, + // The expressions for the final aggregation + final_agg_exprs: Vec, + final_group_by: Vec, + final_projections: Vec, +} + +pub struct GroupedAggregateSink { + grouped_aggregate_params: Arc, + partial_agg_threshold: usize, + high_cardinality_threshold_ratio: f64, + global_strategy_lock: Arc>>, +} + +impl GroupedAggregateSink { + pub fn new( + aggregations: &[ExprRef], + group_by: &[ExprRef], + schema: &SchemaRef, + cfg: &DaftExecutionConfig, + ) -> DaftResult { + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::>>()?; + let (partial_aggs, final_aggs, final_projections) = + daft_physical_plan::populate_aggregation_stages(&aggregations, schema, group_by); + let partial_agg_exprs = partial_aggs + .into_values() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let final_agg_exprs = final_aggs + .into_values() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let final_group_by = if !partial_agg_exprs.is_empty() { + group_by.iter().map(|e| col(e.name())).collect::>() + } else { + group_by.to_vec() + }; + let strategy = if partial_agg_exprs.is_empty() && !final_agg_exprs.is_empty() { + Some(AggStrategy::PartitionOnly) + } else { + None + }; + Ok(Self { + grouped_aggregate_params: Arc::new(GroupedAggregateParams { + original_aggregations: aggregations + .into_iter() + .map(|e| Expr::Agg(e).into()) + .collect(), + group_by: group_by.to_vec(), + partial_agg_exprs, + final_agg_exprs, + final_group_by, + final_projections, + }), + partial_agg_threshold: cfg.partial_aggregation_threshold, + high_cardinality_threshold_ratio: cfg.high_cardinality_aggregation_threshold, + global_strategy_lock: Arc::new(Mutex::new(strategy)), + }) + } + + fn num_partitions(&self) -> usize { + *NUM_CPUS + } +} + +impl BlockingSink for GroupedAggregateSink { + #[instrument(skip_all, name = "GroupedAggregateSink::sink")] + fn sink( + &self, + input: Arc, + mut state: Box, + runtime: &RuntimeRef, + ) -> BlockingSinkSinkResult { + let params = self.grouped_aggregate_params.clone(); + let strategy_lock = self.global_strategy_lock.clone(); + runtime + .spawn( + async move { + let agg_state = state + .as_any_mut() + .downcast_mut::() + .expect("GroupedAggregateSink should have GroupedAggregateState"); + + agg_state.push(input, ¶ms, &strategy_lock)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) + } + .instrument(info_span!("GroupedAggregateSink::sink")), + ) + .into() + } + + #[instrument(skip_all, name = "GroupedAggregateSink::finalize")] + fn finalize( + &self, + states: Vec>, + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { + let params = self.grouped_aggregate_params.clone(); + let num_partitions = self.num_partitions(); + runtime + .spawn( + async move { + let mut state_iters = states + .into_iter() + .map(|mut state| { + state + .as_any_mut() + .downcast_mut::() + .expect("GroupedAggregateSink should have GroupedAggregateState") + .finalize() + .into_iter() + }) + .collect::>(); + + let mut per_partition_finalize_tasks = tokio::task::JoinSet::new(); + for _ in 0..num_partitions { + let per_partition_state = state_iters + .iter_mut() + .map(|state| { + state.next().expect( + "GroupedAggregateState should have SinglePartitionAggregateState", + ) + }) + .collect::>(); + let params = params.clone(); + per_partition_finalize_tasks.spawn(async move { + let mut unaggregated = vec![]; + let mut partially_aggregated = vec![]; + for state in per_partition_state.into_iter().flatten() { + unaggregated.extend(state.unaggregated); + partially_aggregated.extend(state.partially_aggregated); + } + + // If we have no partially aggregated partitions, aggregate the unaggregated partitions using the original aggregations + if partially_aggregated.is_empty() { + let concated = MicroPartition::concat(&unaggregated)?; + let agged = concated + .agg(¶ms.original_aggregations, ¶ms.group_by)?; + Ok(agged) + } + // If we have no unaggregated partitions, finalize the partially aggregated partitions + else if unaggregated.is_empty() { + let concated = MicroPartition::concat(&partially_aggregated)?; + let agged = concated + .agg(¶ms.final_agg_exprs, ¶ms.final_group_by)?; + let projected = + agged.eval_expression_list(¶ms.final_projections)?; + Ok(projected) + } + // Otherwise, partially aggregate the unaggregated partitions, concatenate them with the partially aggregated partitions, and finalize the result. + else { + let leftover_partial_agg = + MicroPartition::concat(&unaggregated)? + .agg(¶ms.partial_agg_exprs, ¶ms.group_by)?; + let concated = MicroPartition::concat( + partially_aggregated + .iter() + .chain(std::iter::once(&leftover_partial_agg)), + )?; + let agged = concated + .agg(¶ms.final_agg_exprs, ¶ms.final_group_by)?; + let projected = + agged.eval_expression_list(¶ms.final_projections)?; + Ok(projected) + } + }); + } + let results = per_partition_finalize_tasks + .join_all() + .await + .into_iter() + .collect::>>()?; + let concated = MicroPartition::concat(&results)?; + Ok(Some(Arc::new(concated))) + } + .instrument(info_span!("GroupedAggregateSink::finalize")), + ) + .into() + } + + fn name(&self) -> &'static str { + "GroupedAggregateSink" + } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(GroupedAggregateState::new( + self.num_partitions(), + self.partial_agg_threshold, + self.high_cardinality_threshold_ratio, + ))) + } +} diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 497e05e50e..8d36320128 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -2,6 +2,7 @@ pub mod aggregate; pub mod blocking_sink; pub mod concat; pub mod cross_join_collect; +pub mod grouped_aggregate; pub mod hash_join_build; pub mod limit; pub mod monotonically_increasing_id; diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py index 24330ec84e..4061fe7f71 100644 --- a/tests/dataframe/test_intersect.py +++ b/tests/dataframe/test_intersect.py @@ -14,7 +14,7 @@ def test_simple_intersect(make_df): def test_intersect_with_duplicate(make_df): df1 = make_df({"foo": [1, 2, 2, 3]}) df2 = make_df({"bar": [2, 3, 3]}) - result = df1.intersect(df2) + result = df1.intersect(df2).sort(by="foo") assert result.to_pydict() == {"foo": [2, 3]} @@ -37,7 +37,7 @@ def test_intersect_with_nulls(make_df): df2 = make_df({"bar": [2, 3, None]}) df2_without_null = make_df({"bar": [2, 3]}) - result = df1.intersect(df2) + result = df1.intersect(df2).sort(by="foo") assert result.to_pydict() == {"foo": [2, None]} result = df1_without_mull.intersect(df2) diff --git a/tests/sql/test_aggs.py b/tests/sql/test_aggs.py index 7fede4dc00..ee001afb2c 100644 --- a/tests/sql/test_aggs.py +++ b/tests/sql/test_aggs.py @@ -83,6 +83,7 @@ def test_having(agg, cond, expected): from df group by id having {cond} + order by id """, catalog, ).to_pydict()