Skip to content

Commit

Permalink
Respecting distinct for min and max in optimizer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edmondop committed Jul 22, 2024
1 parent 463df9a commit c2e5311
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ pub mod expr_fn {
pub use super::grouping::grouping;
pub use super::median::median;
pub use super::min_max::max;
pub use super::min_max::max_distinct;
pub use super::min_max::min;
pub use super::min_max::min_distinct;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
Expand Down
23 changes: 22 additions & 1 deletion datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ use arrow::datatypes::{
use arrow::datatypes::i256;

use datafusion_common::ScalarValue;
use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_expr::{Expr, GroupsAccumulator};

macro_rules! typed_min_max_float {
($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
Expand Down Expand Up @@ -784,6 +784,27 @@ make_udaf_expr_and_func!(
min_udaf
);

/// Builds an aggregate expression that applies the `max` UDAF to `expr`
/// with the distinct flag set.
///
/// The single argument is wrapped in a one-element argument list, and the
/// three trailing optional arguments of `AggregateFunction::new_udf` are
/// all left as `None`.
pub fn max_distinct(expr: Expr) -> Expr {
    use datafusion_expr::expr::AggregateFunction;

    // `true` is what distinguishes this from the plain `max` helper.
    let distinct = true;
    let aggregate =
        AggregateFunction::new_udf(max_udaf(), vec![expr], distinct, None, None, None);
    Expr::AggregateFunction(aggregate)
}

/// Builds an aggregate expression that applies the `min` UDAF to `expr`
/// with the distinct flag set.
///
/// The single argument is wrapped in a one-element argument list, and the
/// three trailing optional arguments of `AggregateFunction::new_udf` are
/// all left as `None`.
pub fn min_distinct(expr: Expr) -> Expr {
    use datafusion_expr::expr::AggregateFunction;

    // `true` is what distinguishes this from the plain `min` helper.
    let distinct = true;
    let aggregate =
        AggregateFunction::new_udf(min_udaf(), vec![expr], distinct, None, None, None);
    Expr::AggregateFunction(aggregate)
}
fn min_max_aggregate_data_type(input_type: DataType) -> DataType {
if let DataType::Dictionary(_, value_type) = input_type {
*value_type
Expand Down
18 changes: 11 additions & 7 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,13 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
if filter.is_some() || order_by.is_some() {
return Ok(false);
}
let fun_name = fun.name();
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e);
}
} else if fun.name() != "sum"
&& fun.name().to_lowercase() != "min"
&& fun.name().to_lowercase() != "max"
{
} else if fun.name() != "sum" {
return Ok(false);
}
} else {
Expand Down Expand Up @@ -360,7 +358,9 @@ mod tests {
use datafusion_expr::AggregateExt;
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum};
use datafusion_functions_aggregate::expr_fn::{
count, count_distinct, max, max_distinct, min, min_distinct, sum,
};
use datafusion_functions_aggregate::sum::sum_udaf;

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
Expand Down Expand Up @@ -521,7 +521,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![count_distinct(col("b")), max(col("b"))],
vec![count_distinct(col("b")), max_distinct(col("b"))],
)?
.build()?;
// Should work
Expand Down Expand Up @@ -575,7 +575,11 @@ mod tests {
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![sum(col("c")), count_distinct(col("b")), max(col("b"))],
vec![
sum(col("c")),
count_distinct(col("b")),
max_distinct(col("b")),
],
)?
.build()?;
// Should work
Expand Down

0 comments on commit c2e5311

Please sign in to comment.