From c2e5311c8c64dce0535f4ec7e06d4fd985cc2803 Mon Sep 17 00:00:00 2001
From: Edmondo Porcu
Date: Mon, 22 Jul 2024 12:00:48 +0000
Subject: [PATCH] Respecting distinct for min and max in optimizer tests

---
 datafusion/functions-aggregate/src/lib.rs     |  2 ++
 datafusion/functions-aggregate/src/min_max.rs | 26 ++++++++++++++++++-
 .../src/single_distinct_to_groupby.rs         | 18 ++++++++-----
 3 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index ab9bd6571e112..06b0f02dded63 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -112,7 +112,9 @@ pub mod expr_fn {
     pub use super::grouping::grouping;
     pub use super::median::median;
     pub use super::min_max::max;
+    pub use super::min_max::max_distinct;
     pub use super::min_max::min;
+    pub use super::min_max::min_distinct;
     pub use super::regr::regr_avgx;
     pub use super::regr::regr_avgy;
     pub use super::regr::regr_count;
diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs
index 4a03cc3203739..ac6bd1234f5f0 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -60,10 +60,10 @@ use arrow::datatypes::{
 use arrow::datatypes::i256;
 
 use datafusion_common::ScalarValue;
-use datafusion_expr::GroupsAccumulator;
 use datafusion_expr::{
     function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
 };
+use datafusion_expr::{Expr, GroupsAccumulator};
 
 macro_rules! typed_min_max_float {
     ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
@@ -784,6 +784,30 @@ make_udaf_expr_and_func!(
     min_udaf
 );
 
+/// Returns a `MAX(DISTINCT expr)` aggregate expression backed by [`max_udaf`].
+pub fn max_distinct(expr: Expr) -> Expr {
+    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+        max_udaf(),
+        vec![expr],
+        true, // distinct
+        None, // filter
+        None, // order_by
+        None, // null_treatment
+    ))
+}
+
+/// Returns a `MIN(DISTINCT expr)` aggregate expression backed by [`min_udaf`].
+pub fn min_distinct(expr: Expr) -> Expr {
+    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+        min_udaf(),
+        vec![expr],
+        true, // distinct
+        None, // filter
+        None, // order_by
+        None, // null_treatment
+    ))
+}
+
 fn min_max_aggregate_data_type(input_type: DataType) -> DataType {
     if let DataType::Dictionary(_, value_type) = input_type {
         *value_type
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index d651397278040..427f09eae018f 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -101,15 +101,13 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
             if filter.is_some() || order_by.is_some() {
                 return Ok(false);
             }
+            let fun_name = fun.name();
             aggregate_count += 1;
             if *distinct {
                 for e in args {
                     fields_set.insert(e);
                 }
-            } else if fun.name() != "sum"
-                && fun.name().to_lowercase() != "min"
-                && fun.name().to_lowercase() != "max"
-            {
+            } else if fun_name != "sum" {
                 return Ok(false);
             }
         } else {
@@ -360,7 +358,9 @@ mod tests {
     use datafusion_expr::AggregateExt;
     use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
     use datafusion_functions_aggregate::count::count_udaf;
-    use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum};
+    use datafusion_functions_aggregate::expr_fn::{
+        count, count_distinct, max, max_distinct, min, min_distinct, sum,
+    };
     use datafusion_functions_aggregate::sum::sum_udaf;
 
     fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -521,7 +521,7 @@ mod tests {
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(
                 vec![col("a")],
-                vec![count_distinct(col("b")), max(col("b"))],
+                vec![count_distinct(col("b")), max_distinct(col("b"))],
             )?
             .build()?;
         // Should work
@@ -575,7 +575,11 @@ mod tests {
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(
                 vec![col("a")],
-                vec![sum(col("c")), count_distinct(col("b")), max(col("b"))],
+                vec![
+                    sum(col("c")),
+                    count_distinct(col("b")),
+                    max_distinct(col("b")),
+                ],
             )?
             .build()?;
         // Should work