diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs
index 6425f69e0f..c9e9fc53ef 100644
--- a/src/daft-dsl/src/resolve_expr/mod.rs
+++ b/src/daft-dsl/src/resolve_expr/mod.rs
@@ -3,7 +3,7 @@ mod tests;
 
 use std::{
     cmp::Ordering,
-    collections::{BinaryHeap, HashMap},
+    collections::{BinaryHeap, HashMap, HashSet},
     sync::Arc,
 };
 
@@ -207,31 +207,6 @@ fn expand_wildcards(
     }
 }
 
-/// Checks if an expression used in an aggregation is well formed.
-/// Expressions for aggregations must be in the form (optional) non-agg expr <- agg exprs or literals <- non-agg exprs
-///
-/// # Examples
-///
-/// Allowed:
-/// - lit("x")
-/// - sum(col("a"))
-/// - sum(col("a")) > 0
-/// - sum(col("a")) - sum(col("b")) > sum(col("c"))
-///
-/// Not allowed:
-/// - col("a")
-///     - not an aggregation
-/// - sum(col("a")) + col("b")
-///     - not all branches are aggregations
-fn has_single_agg_layer(expr: &ExprRef) -> bool {
-    match expr.as_ref() {
-        Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg),
-        Expr::Column(_) => false,
-        Expr::Literal(_) => true,
-        _ => expr.children().iter().all(has_single_agg_layer),
-    }
-}
-
 fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef {
     expr.clone()
         .transform(|e| match e.as_ref() {
@@ -247,40 +222,21 @@ fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef {
         .data
 }
 
-fn validate_expr(expr: ExprRef) -> DaftResult<ExprRef> {
-    if has_agg(&expr) {
-        return Err(DaftError::ValueError(format!(
-            "Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
-        )));
-    }
-
-    Ok(expr)
-}
-
-fn validate_expr_in_agg(expr: ExprRef) -> DaftResult<ExprRef> {
-    let converted_expr = convert_udfs_to_map_groups(&expr);
-
-    if !has_single_agg_layer(&converted_expr) {
-        return Err(DaftError::ValueError(format!(
-            "Expressions in aggregations must be composed of non-nested aggregation expressions, got {expr}"
-        )));
-    }
-
-    Ok(converted_expr)
-}
-
 /// Used for resolving and validating expressions.
 /// Specifically, makes sure the expression does not contain aggregations or stateful UDFs
 /// where they are not allowed, and resolves struct accessors and wildcards.
 #[derive(Default, TypedBuilder)]
-pub struct ExprResolver {
+pub struct ExprResolver<'a> {
     #[builder(default)]
     allow_stateful_udf: bool,
-    #[builder(default)]
-    in_agg_context: bool,
+
+    /// Set to Some when in an aggregation context,
+    /// with groupby expressions when relevant
+    #[builder(default, setter(strip_option))]
+    groupby: Option<&'a HashSet<ExprRef>>,
 }
 
-impl ExprResolver {
+impl<'a> ExprResolver<'a> {
     fn resolve_helper(&self, expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> {
         if !self.allow_stateful_udf && has_stateful_udf(&expr) {
             return Err(DaftError::ValueError(format!(
@@ -288,10 +244,10 @@ impl ExprResolver {
             )));
         }
 
-        let validated_expr = if self.in_agg_context {
-            validate_expr_in_agg(expr)
+        let validated_expr = if self.groupby.is_some() {
+            self.validate_expr_in_agg(expr)
         } else {
-            validate_expr(expr)
+            self.validate_expr(expr)
         }?;
 
         let struct_expr_map = calculate_struct_expr_map(schema);
@@ -330,6 +286,56 @@ impl ExprResolver {
             ))),
         }
     }
+
+    fn validate_expr(&self, expr: ExprRef) -> DaftResult<ExprRef> {
+        if has_agg(&expr) {
+            return Err(DaftError::ValueError(format!(
+                "Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
+            )));
+        }
+
+        Ok(expr)
+    }
+
+    fn validate_expr_in_agg(&self, expr: ExprRef) -> DaftResult<ExprRef> {
+        let converted_expr = convert_udfs_to_map_groups(&expr);
+
+        if !self.is_valid_expr_in_agg(&converted_expr) {
+            return Err(DaftError::ValueError(format!(
+                "Expressions in aggregations must be composed of non-nested aggregation expressions, got {expr}"
+            )));
+        }
+
+        Ok(converted_expr)
+    }
+
+    /// Checks if an expression used in an aggregation is well formed.
+    /// Expressions for aggregations must be in the form (optional) non-agg expr <- [(agg exprs <- non-agg exprs) or literals or group by keys]
+    ///
+    /// # Examples
+    ///
+    /// Allowed:
+    /// - lit("x")
+    /// - sum(col("a"))
+    /// - sum(col("a")) > 0
+    /// - sum(col("a")) - sum(col("b")) > sum(col("c"))
+    /// - sum(col("a")) + col("b") when "b" is a group by key
+    ///
+    /// Not allowed:
+    /// - col("a") when "a" is not a group by key
+    ///     - not an aggregation
+    /// - sum(col("a")) + col("b") when "b" is not a group by key
+    ///     - not all branches are aggregations, literals, or group by keys
+    fn is_valid_expr_in_agg(&self, expr: &ExprRef) -> bool {
+        // `groupby` is `None` outside an aggregation context; then nothing is a group by key.
+        self.groupby.is_some_and(|keys| keys.contains(expr))
+            || match expr.as_ref() {
+                Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg),
+                Expr::Column(_) => false,
+                Expr::Literal(_) => true,
+                _ => expr.children().iter().all(|e| self.is_valid_expr_in_agg(e)),
+            }
+    }
 }
 
 pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> {
diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs
index b87a8f2a68..53702d140d 100644
--- a/src/daft-logical-plan/src/ops/agg.rs
+++ b/src/daft-logical-plan/src/ops/agg.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{collections::HashSet, sync::Arc};
 
 use daft_dsl::{ExprRef, ExprResolver};
 use daft_schema::schema::{Schema, SchemaRef};
@@ -36,8 +36,10 @@ impl Aggregate {
     ) -> logical_plan::Result<Self> {
         let upstream_schema = input.schema();
 
+        let groupby_set = HashSet::from_iter(groupby.clone());
+
         let groupby_resolver = ExprResolver::default();
-        let agg_resolver = ExprResolver::builder().in_agg_context(true).build();
+        let agg_resolver = ExprResolver::builder().groupby(&groupby_set).build();
 
         let (groupby, groupby_fields) = groupby_resolver
             .resolve(groupby, &upstream_schema)
diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs
index 49176d8048..f03b667be5 100644
--- a/src/daft-logical-plan/src/ops/pivot.rs
+++ b/src/daft-logical-plan/src/ops/pivot.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{collections::HashSet, sync::Arc};
 
 use common_error::DaftError;
 use daft_core::prelude::*;
@@ -34,8 +34,10 @@ impl Pivot {
     ) -> logical_plan::Result<Self> {
         let upstream_schema = input.schema();
 
+        let groupby_set = HashSet::from_iter(group_by.clone());
+
         let expr_resolver = ExprResolver::default();
-        let agg_resolver = ExprResolver::builder().in_agg_context(true).build();
+        let agg_resolver = ExprResolver::builder().groupby(&groupby_set).build();
 
         let (group_by, group_by_fields) = expr_resolver
             .resolve(group_by, &upstream_schema)
diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py
index 27302c873b..b35d42e164 100644
--- a/tests/dataframe/test_aggregations.py
+++ b/tests/dataframe/test_aggregations.py
@@ -562,3 +562,33 @@ def test_agg_with_literal_groupby(make_df, repartition_nparts, with_morsel_size)
         "sum_plus_1": [7, 10, 13],
         "1_plus_sum": [9, 12, 15],
     }
+
+
+@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
+def test_agg_with_groupby_key_in_agg(make_df, repartition_nparts, with_morsel_size):
+    daft_df = make_df(
+        {
+            "group": [1, 1, 1, 2, 2, 2, 3, 3, 3],
+            "id": [1, 2, 3, 2, 3, 4, 3, 4, 5],
+            "values": [4, 5, 6, 5, 6, 7, 6, 7, 8],
+        },
+        repartition=repartition_nparts,
+    )
+
+    daft_df = (
+        daft_df.groupby("group")
+        .agg(
+            col("group").alias("group_alias"),
+            (col("group") + 1).alias("group_plus_1"),
+            (col("id").sum() + col("group")).alias("id_plus_group"),
+        )
+        .sort("group")
+    )
+
+    res = daft_df.to_pydict()
+    assert res == {
+        "group": [1, 2, 3],
+        "group_alias": [1, 2, 3],
+        "group_plus_1": [2, 3, 4],
+        "id_plus_group": [7, 11, 15],
+    }