diff --git a/daft/context.py b/daft/context.py
index 0b071a431b..dbec115099 100644
--- a/daft/context.py
+++ b/daft/context.py
@@ -344,6 +344,8 @@ def set_execution_config(
     csv_target_filesize: int | None = None,
     csv_inflation_factor: float | None = None,
     shuffle_aggregation_default_partitions: int | None = None,
+    partial_aggregation_threshold: int | None = None,
+    high_cardinality_aggregation_threshold: float | None = None,
     read_sql_partition_size_bytes: int | None = None,
     enable_aqe: bool | None = None,
     enable_native_executor: bool | None = None,
@@ -384,7 +386,9 @@ def set_execution_config(
         parquet_inflation_factor: Inflation Factor of parquet files (In-Memory-Size / File-Size) ratio. Defaults to 3.0
         csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB
         csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5
-        shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200.
+        shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations on the Ray Runner. Defaults to 200, unless the number of input partitions is less than 200.
+        partial_aggregation_threshold: Maximum number of rows to buffer before applying partial aggregations on the Native Runner. Defaults to 10000 rows.
+        high_cardinality_aggregation_threshold: Threshold ratio of estimated distinct groups to input rows above which the Native Runner treats an aggregation as high-cardinality. Defaults to 0.8.
         read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB
         enable_aqe: Enables Adaptive Query Execution, Defaults to False
         enable_native_executor: Enables the native executor, Defaults to False
@@ -413,6 +417,8 @@ def set_execution_config(
         csv_target_filesize=csv_target_filesize,
         csv_inflation_factor=csv_inflation_factor,
         shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions,
+        partial_aggregation_threshold=partial_aggregation_threshold,
+        high_cardinality_aggregation_threshold=high_cardinality_aggregation_threshold,
         read_sql_partition_size_bytes=read_sql_partition_size_bytes,
         enable_aqe=enable_aqe,
         enable_native_executor=enable_native_executor,
diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index 47e88d9afd..da292d2df1 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -1730,6 +1730,8 @@ class PyDaftExecutionConfig:
         csv_target_filesize: int | None = None,
         csv_inflation_factor: float | None = None,
         shuffle_aggregation_default_partitions: int | None = None,
+        partial_aggregation_threshold: int | None = None,
+        high_cardinality_aggregation_threshold: float | None = None,
         read_sql_partition_size_bytes: int | None = None,
         enable_aqe: bool | None = None,
         enable_native_executor: bool | None = None,
@@ -1765,6 +1767,10 @@ class PyDaftExecutionConfig:
     @property
     def shuffle_aggregation_default_partitions(self) -> int: ...
     @property
+    def partial_aggregation_threshold(self) -> int: ...
+    @property
+    def high_cardinality_aggregation_threshold(self) -> float: ...
+    @property
     def read_sql_partition_size_bytes(self) -> int: ...
     @property
    def enable_aqe(self) -> bool: ...
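Both knobs are settable through the Python API extended above. A minimal usage sketch (the threshold values below are illustrative, not recommendations):

    import daft

    # Flush partial aggregations once roughly 5000 rows are buffered per hash
    # partition, and only switch to the high-cardinality strategy when distinct
    # groups make up at least 90% of input rows.
    daft.context.set_execution_config(
        partial_aggregation_threshold=5000,
        high_cardinality_aggregation_threshold=0.9,
    )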
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs
index 3ffa4ea158..590fd5cf6c 100644
--- a/src/common/daft-config/src/lib.rs
+++ b/src/common/daft-config/src/lib.rs
@@ -44,6 +44,8 @@ pub struct DaftExecutionConfig {
     pub csv_target_filesize: usize,
     pub csv_inflation_factor: f64,
     pub shuffle_aggregation_default_partitions: usize,
+    pub partial_aggregation_threshold: usize,
+    pub high_cardinality_aggregation_threshold: f64,
     pub read_sql_partition_size_bytes: usize,
     pub enable_aqe: bool,
     pub enable_native_executor: bool,
@@ -70,6 +72,8 @@ impl Default for DaftExecutionConfig {
             csv_target_filesize: 512 * 1024 * 1024, // 512MB
             csv_inflation_factor: 0.5,
             shuffle_aggregation_default_partitions: 200,
+            partial_aggregation_threshold: 10000,
+            high_cardinality_aggregation_threshold: 0.8,
             read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB
             enable_aqe: false,
             enable_native_executor: false,
diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs
index aceefd63d2..3228263b07 100644
--- a/src/common/daft-config/src/python.rs
+++ b/src/common/daft-config/src/python.rs
@@ -90,6 +90,8 @@ impl PyDaftExecutionConfig {
         csv_target_filesize: Option<usize>,
         csv_inflation_factor: Option<f64>,
         shuffle_aggregation_default_partitions: Option<usize>,
+        partial_aggregation_threshold: Option<usize>,
+        high_cardinality_aggregation_threshold: Option<f64>,
         read_sql_partition_size_bytes: Option<usize>,
         enable_aqe: Option<bool>,
         enable_native_executor: Option<bool>,
@@ -146,6 +148,13 @@ impl PyDaftExecutionConfig {
         {
             config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions;
         }
+        if let Some(partial_aggregation_threshold) = partial_aggregation_threshold {
+            config.partial_aggregation_threshold = partial_aggregation_threshold;
+        }
+        if let Some(high_cardinality_aggregation_threshold) = high_cardinality_aggregation_threshold
+        {
+            config.high_cardinality_aggregation_threshold = high_cardinality_aggregation_threshold;
+        }
         if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes {
             config.read_sql_partition_size_bytes = read_sql_partition_size_bytes;
         }
@@ -245,6 +254,16 @@ impl PyDaftExecutionConfig {
         Ok(self.config.shuffle_aggregation_default_partitions)
     }

+    #[getter]
+    fn get_partial_aggregation_threshold(&self) -> PyResult<usize> {
+        Ok(self.config.partial_aggregation_threshold)
+    }
+
+    #[getter]
+    fn get_high_cardinality_aggregation_threshold(&self) -> PyResult<f64> {
+        Ok(self.config.high_cardinality_aggregation_threshold)
+    }
+
     #[getter]
     fn get_read_sql_partition_size_bytes(&self) -> PyResult<usize> {
         Ok(self.config.read_sql_partition_size_bytes)
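The new getters make the Rust defaults observable from Python. A small sketch, assuming the context object exposes its execution config as `daft_execution_config` (as daft/context.py does):

    import daft

    cfg = daft.context.get_context().daft_execution_config
    print(cfg.partial_aggregation_threshold)           # 10000
    print(cfg.high_cardinality_aggregation_threshold)  # 0.8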
"AggregateOperator::execute")] - fn execute( - &self, - input: Arc, - state: Box, - runtime: &RuntimeRef, - ) -> IntermediateOpExecuteResult { - let params = self.params.clone(); - runtime - .spawn(async move { - let out = input.agg(¶ms.agg_exprs, ¶ms.group_by)?; - Ok(( - state, - IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))), - )) - }) - .into() - } - - fn name(&self) -> &'static str { - "AggregateOperator" - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index f0af96b763..ca01bafaa4 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,5 +1,4 @@ pub mod actor_pool_project; -pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod cross_join; pub mod explode; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index ec6020eaef..b1b5d99bb0 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,4 +1,6 @@ #![feature(let_chains)] +#![feature(option_get_or_insert_default)] + mod buffer; mod channel; mod dispatcher; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 0c75de9be6..8881cb37de 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -10,7 +10,7 @@ use daft_core::{ prelude::{Schema, SchemaRef}, utils::supertype, }; -use daft_dsl::{col, join::get_common_join_keys, Expr}; +use daft_dsl::{col, join::get_common_join_keys}; use daft_local_plan::{ ActorPoolProject, Concat, CrossJoin, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, MonotonicallyIncreasingId, PhysicalWrite, Pivot, @@ -18,7 +18,6 @@ use daft_local_plan::{ }; use daft_logical_plan::{stats::StatsState, JoinType}; use daft_micropartition::{partitioning::PartitionSet, MicroPartition}; -use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages}; use daft_scan::ScanTaskRef; use daft_writers::make_physical_writer_factory; use indexmap::IndexSet; @@ -27,7 +26,7 @@ use snafu::ResultExt; use crate::{ channel::Receiver, intermediate_ops::{ - actor_pool_project::ActorPoolProjectOperator, aggregate::AggregateOperator, + actor_pool_project::ActorPoolProjectOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, cross_join::CrossJoinOperator, explode::ExplodeOperator, filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, @@ -38,6 +37,7 @@ use crate::{ blocking_sink::BlockingSinkNode, concat::ConcatSink, cross_join_collect::CrossJoinCollectSink, + grouped_aggregate::GroupedAggregateSink, hash_join_build::HashJoinBuildSink, limit::LimitSink, monotonically_increasing_id::MonotonicallyIncreasingIdSink, @@ -170,42 +170,13 @@ pub fn physical_plan_to_pipeline( schema, .. 
diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs
index 4ac74c8fb5..148fb7e2e4 100644
--- a/src/daft-local-execution/src/sinks/aggregate.rs
+++ b/src/daft-local-execution/src/sinks/aggregate.rs
@@ -2,8 +2,10 @@ use std::sync::Arc;

 use common_error::DaftResult;
 use common_runtime::RuntimeRef;
-use daft_dsl::ExprRef;
+use daft_core::prelude::SchemaRef;
+use daft_dsl::{Expr, ExprRef};
 use daft_micropartition::MicroPartition;
+use daft_physical_plan::extract_agg_expr;
 use tracing::instrument;

 use super::blocking_sink::{
@@ -44,8 +46,9 @@ impl BlockingSinkState for AggregateState {
 }

 struct AggParams {
-    agg_exprs: Vec<ExprRef>,
-    group_by: Vec<ExprRef>,
+    sink_agg_exprs: Vec<ExprRef>,
+    finalize_agg_exprs: Vec<ExprRef>,
+    final_projections: Vec<ExprRef>,
 }

 pub struct AggregateSink {
@@ -53,13 +56,31 @@ pub struct AggregateSink {
 }

 impl AggregateSink {
-    pub fn new(agg_exprs: Vec<ExprRef>, group_by: Vec<ExprRef>) -> Self {
-        Self {
+    pub fn new(aggregations: &[ExprRef], schema: &SchemaRef) -> DaftResult<Self> {
+        let aggregations = aggregations
+            .iter()
+            .map(extract_agg_expr)
+            .collect::<DaftResult<Vec<_>>>()?;
+        let (sink_aggs, finalize_aggs, final_projections) =
+            daft_physical_plan::populate_aggregation_stages(&aggregations, schema, &[]);
+        let sink_agg_exprs = sink_aggs
+            .values()
+            .cloned()
+            .map(|e| Arc::new(Expr::Agg(e)))
+            .collect();
+        let finalize_agg_exprs = finalize_aggs
+            .values()
+            .cloned()
+            .map(|e| Arc::new(Expr::Agg(e)))
+            .collect();
+
+        Ok(Self {
             agg_sink_params: Arc::new(AggParams {
-                agg_exprs,
-                group_by,
+                sink_agg_exprs,
+                finalize_agg_exprs,
+                final_projections,
             }),
-        }
+        })
     }
 }

@@ -69,14 +90,20 @@ impl BlockingSink for AggregateSink {
         &self,
         input: Arc<MicroPartition>,
         mut state: Box<dyn BlockingSinkState>,
-        _runtime: &RuntimeRef,
+        runtime: &RuntimeRef,
     ) -> BlockingSinkSinkResult {
-        state
-            .as_any_mut()
-            .downcast_mut::<AggregateState>()
-            .expect("AggregateSink should have AggregateState")
-            .push(input);
-        Ok(BlockingSinkStatus::NeedMoreInput(state)).into()
+        let params = self.agg_sink_params.clone();
+        runtime
+            .spawn(async move {
+                let agg_state = state
+                    .as_any_mut()
+                    .downcast_mut::<AggregateState>()
+                    .expect("AggregateSink should have AggregateState");
+                let agged = Arc::new(input.agg(&params.sink_agg_exprs, &[])?);
+                agg_state.push(agged);
+                Ok(BlockingSinkStatus::NeedMoreInput(state))
+            })
+            .into()
     }

     #[instrument(skip_all, name = "AggregateSink::finalize")]
@@ -96,8 +123,9 @@ impl BlockingSink for AggregateSink {
                     .finalize()
             });
             let concated = MicroPartition::concat(all_parts)?;
-            let agged = Arc::new(concated.agg(&params.agg_exprs, &params.group_by)?);
-            Ok(Some(agged))
+            let agged = concated.agg(&params.finalize_agg_exprs, &[])?;
+            let projected = agged.eval_expression_list(&params.final_projections)?;
+            Ok(Some(Arc::new(projected)))
         })
         .into()
     }
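AggregateSink now does real work on every sink call: each incoming micropartition is eagerly reduced to a partial aggregate, and finalize merges the partials and applies the final projections. A runnable Python model of that two-phase flow for a mean (a sketch of the idea, not Daft's implementation):

    from dataclasses import dataclass

    @dataclass
    class Partial:
        total: float
        count: int

    def sink(chunks):
        # Mirrors AggregateSink::sink: reduce each batch to a small partial
        # state instead of buffering raw rows.
        return [Partial(total=sum(c), count=len(c)) for c in chunks]

    def finalize(partials):
        # Mirrors AggregateSink::finalize: merge the partials (finalize_agg_exprs),
        # then apply the final projection (mean = total / count).
        total = sum(p.total for p in partials)
        count = sum(p.count for p in partials)
        return total / count

    chunks = [[1.0, 2.0], [3.0, 4.0, 5.0]]
    assert finalize(sink(chunks)) == 3.0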
diff --git a/src/daft-local-execution/src/sinks/grouped_aggregate.rs b/src/daft-local-execution/src/sinks/grouped_aggregate.rs
new file mode 100644
index 0000000000..c240480eed
--- /dev/null
+++ b/src/daft-local-execution/src/sinks/grouped_aggregate.rs
@@ -0,0 +1,438 @@
+use std::{
+    collections::HashSet,
+    sync::{Arc, Mutex},
+};
+
+use common_daft_config::DaftExecutionConfig;
+use common_error::DaftResult;
+use common_runtime::RuntimeRef;
+use daft_core::prelude::SchemaRef;
+use daft_dsl::{col, Expr, ExprRef};
+use daft_micropartition::MicroPartition;
+use daft_physical_plan::extract_agg_expr;
+use tracing::{info_span, instrument, Instrument};
+
+use super::blocking_sink::{
+    BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState,
+    BlockingSinkStatus,
+};
+use crate::NUM_CPUS;
+
+#[derive(Clone)]
+enum AggStrategy {
+    // TODO: This would probably benefit from doing sharded aggs.
+    AggThenPartition,
+    PartitionThenAgg(usize),
+    PartitionOnly,
+}
+
+impl AggStrategy {
+    fn execute_strategy(
+        &self,
+        inner_states: &mut [Option<SinglePartitionAggregateState>],
+        input: Arc<MicroPartition>,
+        params: &GroupedAggregateParams,
+    ) -> DaftResult<()> {
+        match self {
+            Self::AggThenPartition => Self::execute_agg_then_partition(inner_states, input, params),
+            Self::PartitionThenAgg(threshold) => {
+                Self::execute_partition_then_agg(inner_states, input, params, *threshold)
+            }
+            Self::PartitionOnly => Self::execute_partition_only(inner_states, input, params),
+        }
+    }
+
+    fn execute_agg_then_partition(
+        inner_states: &mut [Option<SinglePartitionAggregateState>],
+        input: Arc<MicroPartition>,
+        params: &GroupedAggregateParams,
+    ) -> DaftResult<()> {
+        let agged = input.agg(
+            params.partial_agg_exprs.as_slice(),
+            params.group_by.as_slice(),
+        )?;
+        let partitioned =
+            agged.partition_by_hash(params.final_group_by.as_slice(), inner_states.len())?;
+        for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
+            let state = state.get_or_insert_default();
+            state.partially_aggregated.push(p);
+        }
+        Ok(())
+    }
+
+    fn execute_partition_then_agg(
+        inner_states: &mut [Option<SinglePartitionAggregateState>],
+        input: Arc<MicroPartition>,
+        params: &GroupedAggregateParams,
+        partial_agg_threshold: usize,
+    ) -> DaftResult<()> {
+        let partitioned =
+            input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?;
+        for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
+            let state = state.get_or_insert_default();
+            if state.unaggregated_size + p.len() >= partial_agg_threshold {
+                let unaggregated = std::mem::take(&mut state.unaggregated);
+                let aggregated =
+                    MicroPartition::concat(unaggregated.iter().chain(std::iter::once(&p)))?.agg(
+                        params.partial_agg_exprs.as_slice(),
+                        params.group_by.as_slice(),
+                    )?;
+                state.partially_aggregated.push(aggregated);
+                state.unaggregated_size = 0;
+            } else {
+                state.unaggregated_size += p.len();
+                state.unaggregated.push(p);
+            }
+        }
+        Ok(())
+    }
+
+    fn execute_partition_only(
+        inner_states: &mut [Option<SinglePartitionAggregateState>],
+        input: Arc<MicroPartition>,
+        params: &GroupedAggregateParams,
+    ) -> DaftResult<()> {
+        let partitioned =
+            input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?;
+        for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
+            let state = state.get_or_insert_default();
+            state.unaggregated_size += p.len();
+            state.unaggregated.push(p);
+        }
+        Ok(())
+    }
+}
+
+#[derive(Default)]
+struct SinglePartitionAggregateState {
+    partially_aggregated: Vec<MicroPartition>,
+    unaggregated: Vec<MicroPartition>,
+    unaggregated_size: usize,
+}
+
+enum GroupedAggregateState {
+    Accumulating {
+        inner_states: Vec<Option<SinglePartitionAggregateState>>,
+        strategy: Option<AggStrategy>,
+        partial_agg_threshold: usize,
+        high_cardinality_threshold_ratio: f64,
+    },
+    Done,
+}
+
+impl GroupedAggregateState {
+    fn new(
+        num_partitions: usize,
+        partial_agg_threshold: usize,
+        high_cardinality_threshold_ratio: f64,
+    ) -> Self {
+        let inner_states = (0..num_partitions).map(|_| None).collect::<Vec<_>>();
+        Self::Accumulating {
+            inner_states,
+            strategy: None,
+            partial_agg_threshold,
+            high_cardinality_threshold_ratio,
+        }
+    }
+
+    fn push(
+        &mut self,
+        input: Arc<MicroPartition>,
+        params: &GroupedAggregateParams,
+        global_strategy_lock: &Arc<Mutex<Option<AggStrategy>>>,
+    ) -> DaftResult<()> {
+        let Self::Accumulating {
+            ref mut inner_states,
+            strategy,
+            partial_agg_threshold,
+            high_cardinality_threshold_ratio,
+        } = self
+        else {
+            panic!("GroupedAggregateSink should be in Accumulating state");
+        };
+
+        // If we have already determined a strategy, execute it.
+        if let Some(strategy) = strategy {
+            strategy.execute_strategy(inner_states, input, params)?;
+        } else {
+            // Otherwise, determine the strategy, then execute it.
+            let decided_strategy = Self::determine_agg_strategy(
+                &input,
+                params,
+                *high_cardinality_threshold_ratio,
+                *partial_agg_threshold,
+                strategy,
+                global_strategy_lock,
+            )?;
+            decided_strategy.execute_strategy(inner_states, input, params)?;
+        }
+        Ok(())
+    }
+
+    fn determine_agg_strategy(
+        input: &Arc<MicroPartition>,
+        params: &GroupedAggregateParams,
+        high_cardinality_threshold_ratio: f64,
+        partial_agg_threshold: usize,
+        local_strategy_cache: &mut Option<AggStrategy>,
+        global_strategy_lock: &Arc<Mutex<Option<AggStrategy>>>,
+    ) -> DaftResult<AggStrategy> {
+        let mut global_strategy = global_strategy_lock.lock().unwrap();
+        // If some other worker has already determined a strategy, use that.
+        if let Some(global_strat) = global_strategy.as_ref() {
+            *local_strategy_cache = Some(global_strat.clone());
+            return Ok(global_strat.clone());
+        }
+
+        // Otherwise, determine the strategy.
+        let groupby = input.eval_expression_list(params.group_by.as_slice())?;
+        let groupkey_hashes = groupby
+            .get_tables()?
+            .iter()
+            .map(|t| t.hash_rows())
+            .collect::<DaftResult<Vec<_>>>()?;
+        let estimated_num_groups = groupkey_hashes
+            .iter()
+            .flatten()
+            .collect::<HashSet<_>>()
+            .len();
+
+        let decided_strategy = if estimated_num_groups as f64 / input.len() as f64
+            >= high_cardinality_threshold_ratio
+        {
+            AggStrategy::PartitionThenAgg(partial_agg_threshold)
+        } else {
+            AggStrategy::AggThenPartition
+        };
+
+        *local_strategy_cache = Some(decided_strategy.clone());
+        *global_strategy = Some(decided_strategy.clone());
+        Ok(decided_strategy)
+    }
+
+    fn finalize(&mut self) -> Vec<Option<SinglePartitionAggregateState>> {
+        let res = if let Self::Accumulating {
+            ref mut inner_states,
+            ..
+        } = self
+        {
+            std::mem::take(inner_states)
+        } else {
+            panic!("GroupedAggregateSink should be in Accumulating state");
+        };
+        *self = Self::Done;
+        res
+    }
+}
+
+impl BlockingSinkState for GroupedAggregateState {
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+struct GroupedAggregateParams {
+    // The original aggregations and group-by expressions.
+    original_aggregations: Vec<ExprRef>,
+    group_by: Vec<ExprRef>,
+    // The expressions to be used for partial aggregation.
+    partial_agg_exprs: Vec<ExprRef>,
+    // The expressions for the final aggregation.
+    final_agg_exprs: Vec<ExprRef>,
+    final_group_by: Vec<ExprRef>,
+    final_projections: Vec<ExprRef>,
+}
+
+pub struct GroupedAggregateSink {
+    grouped_aggregate_params: Arc<GroupedAggregateParams>,
+    partial_agg_threshold: usize,
+    high_cardinality_threshold_ratio: f64,
+    global_strategy_lock: Arc<Mutex<Option<AggStrategy>>>,
+}
+
+impl GroupedAggregateSink {
+    pub fn new(
+        aggregations: &[ExprRef],
+        group_by: &[ExprRef],
+        schema: &SchemaRef,
+        cfg: &DaftExecutionConfig,
+    ) -> DaftResult<Self> {
+        let aggregations = aggregations
+            .iter()
+            .map(extract_agg_expr)
+            .collect::<DaftResult<Vec<_>>>()?;
+        let (partial_aggs, final_aggs, final_projections) =
+            daft_physical_plan::populate_aggregation_stages(&aggregations, schema, group_by);
+        let partial_agg_exprs = partial_aggs
+            .into_values()
+            .map(|e| Arc::new(Expr::Agg(e)))
+            .collect::<Vec<_>>();
+        let final_agg_exprs = final_aggs
+            .into_values()
+            .map(|e| Arc::new(Expr::Agg(e)))
+            .collect::<Vec<_>>();
+        let final_group_by = if !partial_agg_exprs.is_empty() {
+            group_by.iter().map(|e| col(e.name())).collect::<Vec<_>>()
+        } else {
+            group_by.to_vec()
+        };
+        let strategy = if partial_agg_exprs.is_empty() && !final_agg_exprs.is_empty() {
+            Some(AggStrategy::PartitionOnly)
+        } else {
+            None
+        };
+        Ok(Self {
+            grouped_aggregate_params: Arc::new(GroupedAggregateParams {
+                original_aggregations: aggregations
+                    .into_iter()
+                    .map(|e| Expr::Agg(e).into())
+                    .collect(),
+                group_by: group_by.to_vec(),
+                partial_agg_exprs,
+                final_agg_exprs,
+                final_group_by,
+                final_projections,
+            }),
+            partial_agg_threshold: cfg.partial_aggregation_threshold,
+            high_cardinality_threshold_ratio: cfg.high_cardinality_aggregation_threshold,
+            global_strategy_lock: Arc::new(Mutex::new(strategy)),
+        })
+    }
+
+    fn num_partitions(&self) -> usize {
+        *NUM_CPUS
+    }
+}
+
+impl BlockingSink for GroupedAggregateSink {
+    #[instrument(skip_all, name = "GroupedAggregateSink::sink")]
+    fn sink(
+        &self,
+        input: Arc<MicroPartition>,
+        mut state: Box<dyn BlockingSinkState>,
+        runtime: &RuntimeRef,
+    ) -> BlockingSinkSinkResult {
+        let params = self.grouped_aggregate_params.clone();
+        let strategy_lock = self.global_strategy_lock.clone();
+        runtime
+            .spawn(
+                async move {
+                    let agg_state = state
+                        .as_any_mut()
+                        .downcast_mut::<GroupedAggregateState>()
+                        .expect("GroupedAggregateSink should have GroupedAggregateState");
+                    agg_state.push(input, &params, &strategy_lock)?;
+                    Ok(BlockingSinkStatus::NeedMoreInput(state))
+                }
+                .instrument(info_span!("GroupedAggregateSink::sink")),
+            )
+            .into()
+    }
+
+    #[instrument(skip_all, name = "GroupedAggregateSink::finalize")]
+    fn finalize(
+        &self,
+        states: Vec<Box<dyn BlockingSinkState>>,
+        runtime: &RuntimeRef,
+    ) -> BlockingSinkFinalizeResult {
+        let params = self.grouped_aggregate_params.clone();
+        let num_partitions = self.num_partitions();
+        runtime
+            .spawn(
+                async move {
+                    let mut state_iters = states
+                        .into_iter()
+                        .map(|mut state| {
+                            state
+                                .as_any_mut()
+                                .downcast_mut::<GroupedAggregateState>()
+                                .expect("GroupedAggregateSink should have GroupedAggregateState")
+                                .finalize()
+                                .into_iter()
+                        })
+                        .collect::<Vec<_>>();
+
+                    let mut per_partition_finalize_tasks = tokio::task::JoinSet::new();
+                    for _ in 0..num_partitions {
+                        let per_partition_state = state_iters
+                            .iter_mut()
+                            .map(|state| {
+                                state.next().expect(
+                                    "GroupedAggregateState should have SinglePartitionAggregateState",
+                                )
+                            })
+                            .collect::<Vec<_>>();
+                        let params = params.clone();
+                        per_partition_finalize_tasks.spawn(async move {
+                            let mut unaggregated = vec![];
+                            let mut partially_aggregated = vec![];
+                            for state in per_partition_state.into_iter().flatten() {
+                                unaggregated.extend(state.unaggregated);
+                                partially_aggregated.extend(state.partially_aggregated);
+                            }
+
+                            // If we have no partially aggregated partitions, aggregate the unaggregated partitions using the original aggregations.
+                            if partially_aggregated.is_empty() {
+                                let concated = MicroPartition::concat(&unaggregated)?;
+                                let agged = concated
+                                    .agg(&params.original_aggregations, &params.group_by)?;
+                                Ok(agged)
+                            }
+                            // If we have no unaggregated partitions, finalize the partially aggregated partitions.
+                            else if unaggregated.is_empty() {
+                                let concated = MicroPartition::concat(&partially_aggregated)?;
+                                let agged = concated
+                                    .agg(&params.final_agg_exprs, &params.final_group_by)?;
+                                let projected =
+                                    agged.eval_expression_list(&params.final_projections)?;
+                                Ok(projected)
+                            }
+                            // Otherwise, partially aggregate the unaggregated partitions, concatenate them with the partially aggregated partitions, and finalize the result.
+                            else {
+                                let leftover_partial_agg = MicroPartition::concat(&unaggregated)?
+                                    .agg(&params.partial_agg_exprs, &params.group_by)?;
+                                let concated = MicroPartition::concat(
+                                    partially_aggregated
+                                        .iter()
+                                        .chain(std::iter::once(&leftover_partial_agg)),
+                                )?;
+                                let agged = concated
+                                    .agg(&params.final_agg_exprs, &params.final_group_by)?;
+                                let projected =
+                                    agged.eval_expression_list(&params.final_projections)?;
+                                Ok(projected)
+                            }
+                        });
+                    }
+                    let results = per_partition_finalize_tasks
+                        .join_all()
+                        .await
+                        .into_iter()
+                        .collect::<DaftResult<Vec<_>>>()?;
+                    let concated = MicroPartition::concat(&results)?;
+                    Ok(Some(Arc::new(concated)))
+                }
+                .instrument(info_span!("GroupedAggregateSink::finalize")),
+            )
+            .into()
+    }
+
+    fn name(&self) -> &'static str {
+        "GroupedAggregateSink"
+    }
+
+    fn max_concurrency(&self) -> usize {
+        *NUM_CPUS
+    }
+
+    fn make_state(&self) -> DaftResult<Box<dyn BlockingSinkState>> {
+        Ok(Box::new(GroupedAggregateState::new(
+            self.num_partitions(),
+            self.partial_agg_threshold,
+            self.high_cardinality_threshold_ratio,
+        )))
+    }
+}
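The heart of the new sink is a one-shot strategy decision: on the first batch, it estimates the number of distinct groups from hashes of the group-by keys and compares the groups-to-rows ratio against high_cardinality_aggregation_threshold. A minimal Python model of that decision (using the built-in hash where Daft hashes table rows):

    def choose_strategy(group_keys: list, high_cardinality_threshold: float = 0.8) -> str:
        estimated_num_groups = len({hash(k) for k in group_keys})
        if estimated_num_groups / len(group_keys) >= high_cardinality_threshold:
            # Almost every row is its own group: partial aggregation would not
            # shrink the data, so hash-partition first and aggregate lazily.
            return "PartitionThenAgg"
        # Few groups relative to rows: aggregate each batch down first, then
        # partition the much smaller result.
        return "AggThenPartition"

    assert choose_strategy(["a", "b", "a", "a"]) == "AggThenPartition"
    assert choose_strategy(["a", "b", "c", "d"]) == "PartitionThenAgg"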
diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs
index 497e05e50e..8d36320128 100644
--- a/src/daft-local-execution/src/sinks/mod.rs
+++ b/src/daft-local-execution/src/sinks/mod.rs
@@ -2,6 +2,7 @@ pub mod aggregate;
 pub mod blocking_sink;
 pub mod concat;
 pub mod cross_join_collect;
+pub mod grouped_aggregate;
 pub mod hash_join_build;
 pub mod limit;
 pub mod monotonically_increasing_id;
diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py
index 24330ec84e..4061fe7f71 100644
--- a/tests/dataframe/test_intersect.py
+++ b/tests/dataframe/test_intersect.py
@@ -14,7 +14,7 @@ def test_simple_intersect(make_df):
 def test_intersect_with_duplicate(make_df):
     df1 = make_df({"foo": [1, 2, 2, 3]})
     df2 = make_df({"bar": [2, 3, 3]})
-    result = df1.intersect(df2)
+    result = df1.intersect(df2).sort(by="foo")
     assert result.to_pydict() == {"foo": [2, 3]}

@@ -37,7 +37,7 @@ def test_intersect_with_nulls(make_df):
     df2 = make_df({"bar": [2, 3, None]})
     df2_without_null = make_df({"bar": [2, 3]})

-    result = df1.intersect(df2)
+    result = df1.intersect(df2).sort(by="foo")
     assert result.to_pydict() == {"foo": [2, None]}

     result = df1_without_mull.intersect(df2)
diff --git a/tests/sql/test_aggs.py b/tests/sql/test_aggs.py
index 7fede4dc00..ee001afb2c 100644
--- a/tests/sql/test_aggs.py
+++ b/tests/sql/test_aggs.py
@@ -83,6 +83,7 @@ def test_having(agg, cond, expected):
         from df
         group by id
         having {cond}
+        order by id
         """,
         catalog,
     ).to_pydict()
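The test updates above follow from the new sink: grouped results are assembled from hash partitions whose finalize tasks complete in nondeterministic order, so output row order is no longer guaranteed and assertions must sort before comparing, as in:

    result = df1.intersect(df2).sort(by="foo")
    assert result.to_pydict() == {"foo": [2, 3]}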