diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs index c9e9fc53ef..96bbf90513 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -229,11 +229,20 @@ fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef { pub struct ExprResolver<'a> { #[builder(default)] allow_stateful_udf: bool, - - /// Set to Some when in an aggregation context, - /// with groupby expressions when relevant - #[builder(default, setter(strip_option))] - groupby: Option<&'a HashSet>, + #[builder(via_mutators, mutators( + pub fn in_agg_context(&mut self, in_agg_context: bool) { + // workaround since typed_builder can't have defaults for mutator requirements + self.in_agg_context = in_agg_context; + } + ))] + in_agg_context: bool, + #[builder(via_mutators, mutators( + pub fn groupby(&mut self, groupby: &'a Vec) { + self.groupby = HashSet::from_iter(groupby); + self.in_agg_context = true; + } + ))] + groupby: HashSet<&'a ExprRef>, } impl<'a> ExprResolver<'a> { @@ -244,7 +253,7 @@ impl<'a> ExprResolver<'a> { ))); } - let validated_expr = if self.groupby.is_some() { + let validated_expr = if self.in_agg_context { self.validate_expr_in_agg(expr) } else { self.validate_expr(expr) @@ -327,7 +336,7 @@ impl<'a> ExprResolver<'a> { /// - sum(col("a")) + col("b") when "b" is not a group by key /// - not all branches are aggregations, literals, or group by keys fn is_valid_expr_in_agg(&self, expr: &ExprRef) -> bool { - self.groupby.unwrap().contains(expr) + self.groupby.contains(expr) || match expr.as_ref() { Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg), Expr::Column(_) => false, diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index 53702d140d..be82d0d010 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; use daft_dsl::{ExprRef, ExprResolver}; use daft_schema::schema::{Schema, SchemaRef}; @@ -36,17 +36,15 @@ impl Aggregate { ) -> logical_plan::Result { let upstream_schema = input.schema(); - let groupby_set = HashSet::from_iter(groupby.clone()); + let agg_resolver = ExprResolver::builder().groupby(&groupby).build(); + let (aggregations, aggregation_fields) = agg_resolver + .resolve(aggregations, &upstream_schema) + .context(CreationSnafu)?; let groupby_resolver = ExprResolver::default(); - let agg_resolver = ExprResolver::builder().groupby(&groupby_set).build(); - let (groupby, groupby_fields) = groupby_resolver .resolve(groupby, &upstream_schema) .context(CreationSnafu)?; - let (aggregations, aggregation_fields) = agg_resolver - .resolve(aggregations, &upstream_schema) - .context(CreationSnafu)?; let fields = [groupby_fields, aggregation_fields].concat(); diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index f03b667be5..da204fdb34 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; @@ -34,11 +34,12 @@ impl Pivot { ) -> logical_plan::Result { let upstream_schema = input.schema(); - let groupby_set = HashSet::from_iter(group_by.clone()); + let agg_resolver = ExprResolver::builder().groupby(&group_by).build(); + let (aggregation, _) = agg_resolver + .resolve_single(aggregation, &upstream_schema) + .context(CreationSnafu)?; let expr_resolver = ExprResolver::default(); - let agg_resolver = ExprResolver::builder().groupby(&groupby_set).build(); - let (group_by, group_by_fields) = expr_resolver .resolve(group_by, &upstream_schema) .context(CreationSnafu)?; @@ -49,10 +50,6 @@ impl Pivot { .resolve_single(value_column, &upstream_schema) .context(CreationSnafu)?; - let (aggregation, _) = agg_resolver - .resolve_single(aggregation, &upstream_schema) - .context(CreationSnafu)?; - let Expr::Agg(agg_expr) = aggregation.as_ref() else { return Err(DaftError::ValueError(format!( "Pivot only supports using top level aggregation expressions, received {aggregation}",