From e18b719f4836636a42f07afcbc3c4b44ea9f8817 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Thu, 14 Nov 2024 16:07:18 -0800 Subject: [PATCH] [FEAT] Support for aggregation expressions that use multiple AggExprs (#3296) This enables expressions such as `sum("a") + sum("b")` or `mean("a") / 100` in aggregations. This PR enables Q8 and Q14 of TPC-H and is also necessary for Q17 and Q20 (which additionally require subquery support). --- Cargo.lock | 23 +- Cargo.toml | 1 + src/daft-dsl/Cargo.toml | 1 + src/daft-dsl/src/lib.rs | 5 +- src/daft-dsl/src/resolve_expr/mod.rs | 262 +++++++---------- src/daft-local-execution/src/pipeline.rs | 22 +- src/daft-local-plan/src/plan.rs | 8 +- src/daft-logical-plan/src/logical_plan.rs | 2 +- .../src/ops/actor_pool_project.rs | 8 +- src/daft-logical-plan/src/ops/agg.rs | 22 +- src/daft-logical-plan/src/ops/explode.rs | 9 +- src/daft-logical-plan/src/ops/filter.rs | 9 +- src/daft-logical-plan/src/ops/join.rs | 13 +- src/daft-logical-plan/src/ops/pivot.rs | 37 ++- src/daft-logical-plan/src/ops/project.rs | 9 +- src/daft-logical-plan/src/ops/repartition.rs | 6 +- src/daft-logical-plan/src/ops/sink.rs | 8 +- src/daft-logical-plan/src/ops/sort.rs | 9 +- src/daft-logical-plan/src/ops/unpivot.rs | 13 +- .../src/optimization/optimizer.rs | 10 +- .../rules/lift_project_from_agg.rs | 278 ++++++++++++++++++ .../src/optimization/rules/mod.rs | 2 + .../rules/push_down_projection.rs | 2 +- src/daft-physical-plan/src/lib.rs | 4 +- .../src/physical_planner/mod.rs | 2 +- .../src/physical_planner/translate.rs | 61 +++- src/daft-scheduler/src/scheduler.rs | 2 +- tests/dataframe/test_aggregations.py | 98 ++++++ 28 files changed, 700 insertions(+), 226 deletions(-) create mode 100644 src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs diff --git a/Cargo.lock b/Cargo.lock index 8b4abf3608..a4aafc4a76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -453,7 +453,7 @@ dependencies = [ "strum 0.18.0", "strum_macros 0.18.0", "thiserror", - "typed-builder", + "typed-builder 0.5.1", "uuid 0.8.2", "zerocopy 0.3.2", ] @@ -2025,6 +2025,7 @@ dependencies = [ "log", "pyo3", "serde", + "typed-builder 0.20.0", "typetag", ] @@ -6639,6 +6640,26 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "typed-builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e14ed59dc8b7b26cacb2a92bad2e8b1f098806063898ab42a3bd121d7d45e75" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "560b82d656506509d43abe30e0ba64c56b1953ab3d4fe7ba5902747a7a3cedd5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "typeid" version = "1.0.0" diff --git a/Cargo.toml b/Cargo.toml index a83ef36a87..e050ab5368 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -248,6 +248,7 @@ tokio = {version = "1.37.0", features = [ tokio-stream = {version = "0.1.14", features = ["fs", "io-util", "time"]} tokio-util = "0.7.11" tracing = "0.1" +typed-builder = "0.20.0" typetag = "0.2.18" url = "2.4.0" xxhash-rust = "0.8.12" diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index 5d7992fba1..6e04a977aa 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -13,6 +13,7 @@ itertools = {workspace = true} log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true} +typed-builder = {workspace = true} typetag = {workspace = true} [features] diff --git
a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 2de36436b1..7b3df61e6a 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -21,10 +21,7 @@ pub use expr::{ pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] use pyo3::prelude::*; -pub use resolve_expr::{ - check_column_name_validity, resolve_aggexprs, resolve_exprs, resolve_single_aggexpr, - resolve_single_expr, -}; +pub use resolve_expr::{check_column_name_validity, ExprResolver}; #[cfg(feature = "python")] pub fn register_modules(parent: &Bound) -> PyResult<()> { diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs index 5888774fe4..6425f69e0f 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -10,8 +10,11 @@ use std::{ use common_error::{DaftError, DaftResult}; use common_treenode::{Transformed, TransformedResult, TreeNode}; use daft_core::prelude::*; +use typed_builder::TypedBuilder; -use crate::{col, expr::has_agg, has_stateful_udf, AggExpr, ApproxPercentileParams, Expr, ExprRef}; +use crate::{ + col, expr::has_agg, functions::FunctionExpr, has_stateful_udf, AggExpr, Expr, ExprRef, +}; // Calculates all the possible struct get expressions in a schema. // For each sugared string, calculates all possible corresponding expressions, in order of priority. @@ -204,192 +207,131 @@ fn expand_wildcards( } } -fn extract_agg_expr(expr: &Expr) -> DaftResult { - match expr { - Expr::Agg(agg_expr) => Ok(agg_expr.clone()), - Expr::Function { func, inputs } => Ok(AggExpr::MapGroups { - func: func.clone(), - inputs: inputs.clone(), - }), - Expr::Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { - // reorder expressions so that alias goes before agg - match agg_expr { - AggExpr::Count(e, count_mode) => { - AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode) - } - AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()), - AggExpr::ApproxPercentile(ApproxPercentileParams { - child: e, - percentiles, - force_list_output, - }) => AggExpr::ApproxPercentile(ApproxPercentileParams { - child: Expr::Alias(e, name.clone()).into(), - percentiles, - force_list_output, - }), - AggExpr::ApproxCountDistinct(e) => { - AggExpr::ApproxCountDistinct(Expr::Alias(e, name.clone()).into()) - } - AggExpr::ApproxSketch(e, sketch_type) => { - AggExpr::ApproxSketch(Expr::Alias(e, name.clone()).into(), sketch_type) - } - AggExpr::MergeSketch(e, sketch_type) => { - AggExpr::MergeSketch(Expr::Alias(e, name.clone()).into(), sketch_type) - } - AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), - AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), - AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), - AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), - AggExpr::AnyValue(e, ignore_nulls) => { - AggExpr::AnyValue(Expr::Alias(e, name.clone()).into(), ignore_nulls) - } - AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()), - AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()), - AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups { - func, - inputs: inputs - .into_iter() - .map(|input| input.alias(name.clone())) - .collect(), - }, - } - }), - // TODO(Kevin): Support a mix of aggregation and non-aggregation expressions - // as long as the final value always has a cardinality of 1. 
- _ => Err(DaftError::ValueError(format!( "Expected aggregation expression, but got: {expr}" ))), +/// Checks if an expression used in an aggregation is well-formed. +/// From root to leaf, an aggregation expression may layer optional non-agg exprs on top of agg exprs or literals, which may in turn contain non-agg exprs. +/// +/// # Examples +/// +/// Allowed: +/// - lit("x") +/// - sum(col("a")) +/// - sum(col("a")) > 0 +/// - sum(col("a")) - sum(col("b")) > sum(col("c")) +/// +/// Not allowed: +/// - col("a") +/// - not an aggregation +/// - sum(col("a")) + col("b") +/// - not all branches are aggregations +fn has_single_agg_layer(expr: &ExprRef) -> bool { + match expr.as_ref() { + Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg), + Expr::Column(_) => false, + Expr::Literal(_) => true, + _ => expr.children().iter().all(has_single_agg_layer), + } } -/// Resolves and validates the expression with a schema, returning the new expression and its field. -/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs when they are not allowed, -/// and resolves struct accessors and wildcards. -/// May return multiple expressions if the expr contains a wildcard. -/// -/// TODO: Use a builder pattern for this functionality -fn resolve_expr( - expr: ExprRef, - schema: &Schema, - allow_stateful_udf: bool, -) -> DaftResult<Vec<ExprRef>> { - // TODO(Kevin): Support aggregation expressions everywhere +fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef { + expr.clone() + .transform(|e| match e.as_ref() { + Expr::Function { func, inputs } if matches!(func, FunctionExpr::Python(_)) => { + Ok(Transformed::yes(Arc::new(Expr::Agg(AggExpr::MapGroups { + func: func.clone(), + inputs: inputs.clone(), + })))) + } + _ => Ok(Transformed::no(e)), + }) + .unwrap() + .data +} + +fn validate_expr(expr: ExprRef) -> DaftResult<ExprRef> { if has_agg(&expr) { return Err(DaftError::ValueError(format!( "Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383", ))); } - if !allow_stateful_udf && has_stateful_udf(&expr) { + Ok(expr) +} + +fn validate_expr_in_agg(expr: ExprRef) -> DaftResult<ExprRef> { + let converted_expr = convert_udfs_to_map_groups(&expr); + + if !has_single_agg_layer(&converted_expr) { return Err(DaftError::ValueError(format!( - "Stateful UDFs are only allowed in projections: {expr}" + "Expressions in aggregations must be composed of non-nested aggregation expressions, got {expr}" ))); } - let struct_expr_map = calculate_struct_expr_map(schema); - expand_wildcards(expr, schema, &struct_expr_map)? - .into_iter() - .map(|e| transform_struct_gets(e, &struct_expr_map)) - .collect() + Ok(converted_expr) }
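To make the validation rule above concrete in terms of the user-facing API: `has_single_agg_layer` is what decides which expressions `.agg()` will now accept. A minimal sketch, assuming the `daft.from_pydict` and `col` API used in the tests at the end of this patch:

import daft
from daft import col

df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})

df.agg(col("a").sum() + col("b").sum())  # accepted: every branch bottoms out in an agg expr
df.agg(col("a").sum() > 0)               # accepted: non-agg exprs may sit above agg exprs
df.agg(col("a").sum() + col("b"))        # rejected: the bare col("b") branch has no aggregation

The rejected case raises the "must be composed of non-nested aggregation expressions" error from `validate_expr_in_agg` above.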
-// Resolve a single expression, erroring if any kind of expansion happens. -pub fn resolve_single_expr( - expr: ExprRef, - schema: &Schema, +/// Used for resolving and validating expressions. +/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs +/// where they are not allowed, and resolves struct accessors and wildcards. +#[derive(Default, TypedBuilder)] +pub struct ExprResolver { + #[builder(default)] allow_stateful_udf: bool, -) -> DaftResult<(ExprRef, Field)> { - let resolved_exprs = resolve_expr(expr.clone(), schema, allow_stateful_udf)?; - match resolved_exprs.as_slice() { - [resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)), - _ => Err(DaftError::ValueError(format!( - "Error resolving expression {}: expanded into {} expressions when 1 was expected", - expr, - resolved_exprs.len() - ))), - } + #[builder(default)] + in_agg_context: bool, } -pub fn resolve_exprs( - exprs: Vec<ExprRef>, - schema: &Schema, - allow_stateful_udf: bool, -) -> DaftResult<(Vec<ExprRef>, Vec<Field>)> { - // can't flat map because we need to deal with errors - let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> = exprs - .into_iter() - .map(|e| resolve_expr(e, schema, allow_stateful_udf)) - .collect(); - let resolved_exprs: Vec<ExprRef> = resolved_exprs?.into_iter().flatten().collect(); - let resolved_fields: DaftResult<Vec<Field>> = - resolved_exprs.iter().map(|e| e.to_field(schema)).collect(); - Ok((resolved_exprs, resolved_fields?)) -} +impl ExprResolver { + fn resolve_helper(&self, expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> { + if !self.allow_stateful_udf && has_stateful_udf(&expr) { + return Err(DaftError::ValueError(format!( + "Stateful UDFs are only allowed in projections: {expr}" + ))); + } -/// Resolves and validates the expression with a schema, returning the extracted aggregation expression and its field. -/// Specifically, makes sure the expression does not contain nested aggregations or stateful UDFs, -/// and resolves struct accessors and wildcards. -/// May return multiple expressions if the expr contains a wildcard. -/// -/// TODO: Use a builder pattern for this functionality -fn resolve_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<AggExpr>> { - let has_nested_agg = extract_agg_expr(&expr)?.children().iter().any(has_agg); + let validated_expr = if self.in_agg_context { + validate_expr_in_agg(expr) + } else { + validate_expr(expr) + }?; - if has_nested_agg { - return Err(DaftError::ValueError(format!( - "Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" - ))); + let struct_expr_map = calculate_struct_expr_map(schema); + expand_wildcards(validated_expr, schema, &struct_expr_map)? + .into_iter() + .map(|e| transform_struct_gets(e, &struct_expr_map)) + .collect() } - if has_stateful_udf(&expr) { - return Err(DaftError::ValueError(format!( - "Stateful UDFs are only allowed in projections: {expr}" - ))); + /// Resolve multiple expressions. Due to wildcards, the output vec may contain more expressions than the input. + pub fn resolve( + &self, + exprs: Vec<ExprRef>, + schema: &Schema, + ) -> DaftResult<(Vec<ExprRef>, Vec<Field>)> { + // can't flat map because we need to deal with errors + let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> = exprs + .into_iter() + .map(|e| self.resolve_helper(e, schema)) + .collect(); + let resolved_exprs: Vec<ExprRef> = resolved_exprs?.into_iter().flatten().collect(); + let resolved_fields: DaftResult<Vec<Field>> = + resolved_exprs.iter().map(|e| e.to_field(schema)).collect(); + Ok((resolved_exprs, resolved_fields?)) } - let struct_expr_map = calculate_struct_expr_map(schema); - expand_wildcards(expr, schema, &struct_expr_map)?
- .into_iter() - .map(|expr| { - let agg_expr = extract_agg_expr(&expr)?; - - let resolved_children = agg_expr - .children() - .into_iter() - .map(|e| transform_struct_gets(e, &struct_expr_map)) - .collect::<DaftResult<Vec<_>>>()?; - Ok(agg_expr.with_new_children(resolved_children)) - }) - .collect() -} - -pub fn resolve_single_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<(AggExpr, Field)> { - let resolved_exprs = resolve_aggexpr(expr.clone(), schema)?; - match resolved_exprs.as_slice() { - [resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)), - _ => Err(DaftError::ValueError(format!( - "Error resolving expression {}: expanded into {} expressions when 1 was expected", - expr, - resolved_exprs.len() - ))), + /// Resolve a single expression, ensuring that the output is also a single expression. + pub fn resolve_single(&self, expr: ExprRef, schema: &Schema) -> DaftResult<(ExprRef, Field)> { + let resolved_exprs = self.resolve_helper(expr.clone(), schema)?; + match resolved_exprs.as_slice() { + [resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)), + _ => Err(DaftError::ValueError(format!( + "Error resolving expression {}: expanded into {} expressions when 1 was expected", + expr, + resolved_exprs.len() + ))), + } } } -pub fn resolve_aggexprs( - exprs: Vec<ExprRef>, - schema: &Schema, -) -> DaftResult<(Vec<AggExpr>, Vec<Field>)> { - // can't flat map because we need to deal with errors - let resolved_exprs: DaftResult<Vec<Vec<AggExpr>>> = exprs - .into_iter() - .map(|e| resolve_aggexpr(e, schema)) - .collect(); - let resolved_exprs: Vec<AggExpr> = resolved_exprs?.into_iter().flatten().collect(); - let resolved_fields: DaftResult<Vec<Field>> = - resolved_exprs.iter().map(|e| e.to_field(schema)).collect(); - Ok((resolved_exprs, resolved_fields?)) -} - pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> { let struct_expr_map = calculate_struct_expr_map(schema); diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index e213e77e0a..83aa411e09 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -17,7 +17,7 @@ use daft_local_plan::{ }; use daft_logical_plan::JoinType; use daft_micropartition::MicroPartition; -use daft_physical_plan::populate_aggregation_stages; +use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages}; use daft_scan::ScanTaskRef; use daft_table::ProbeState; use daft_writers::make_physical_writer_factory; @@ -203,8 +203,16 @@ pub fn physical_plan_to_pipeline( schema, .. }) => { + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::<DaftResult<Vec<_>>>() + .with_context(|_| PipelineCreationSnafu { + plan_name: physical_plan.name(), + })?; + let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, schema, &[]); + populate_aggregation_stages(&aggregations, schema, &[]); let first_stage_agg_op = AggregateOperator::new( first_stage_aggs .values() .map( @@ -239,8 +247,16 @@ schema, ..
}) => { + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::<DaftResult<Vec<_>>>() + .with_context(|_| PipelineCreationSnafu { + plan_name: physical_plan.name(), + })?; + let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, schema, group_by); + populate_aggregation_stages(&aggregations, schema, group_by); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { let agg_op = AggregateOperator::new( diff --git a/src/daft-local-plan/src/plan.rs b/src/daft-local-plan/src/plan.rs index 12dd17238a..d0c14a4985 100644 --- a/src/daft-local-plan/src/plan.rs +++ b/src/daft-local-plan/src/plan.rs @@ -152,7 +152,7 @@ impl LocalPhysicalPlan { pub(crate) fn ungrouped_aggregate( input: LocalPhysicalPlanRef, - aggregations: Vec<AggExpr>, + aggregations: Vec<ExprRef>, schema: SchemaRef, ) -> LocalPhysicalPlanRef { Self::UnGroupedAggregate(UnGroupedAggregate { @@ -166,7 +166,7 @@ impl LocalPhysicalPlan { pub(crate) fn hash_aggregate( input: LocalPhysicalPlanRef, - aggregations: Vec<AggExpr>, + aggregations: Vec<ExprRef>, group_by: Vec<ExprRef>, schema: SchemaRef, ) -> LocalPhysicalPlanRef { @@ -429,7 +429,7 @@ pub struct Sample { #[derive(Debug)] pub struct UnGroupedAggregate { pub input: LocalPhysicalPlanRef, - pub aggregations: Vec<AggExpr>, + pub aggregations: Vec<ExprRef>, pub schema: SchemaRef, pub plan_stats: PlanStats, } @@ -437,7 +437,7 @@ pub struct UnGroupedAggregate { #[derive(Debug)] pub struct HashAggregate { pub input: LocalPhysicalPlanRef, - pub aggregations: Vec<AggExpr>, + pub aggregations: Vec<ExprRef>, pub group_by: Vec<ExprRef>, pub schema: SchemaRef, pub plan_stats: PlanStats, diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 594c07e1a9..b9ca2e449c 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -263,7 +263,7 @@ impl LogicalPlan { Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone()).unwrap()), Self::Repartition(Repartition { repartition_spec: scheme_config, .. }) => Self::Repartition(Repartition::try_new(input.clone(), scheme_config.clone()).unwrap()), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), - Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.iter().map(|ae| ae.into()).collect(), groupby.clone()).unwrap()), + Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation, names, ..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.into(), names.clone()).unwrap()), Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, ..
}) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs index 97b511b238..fa1c8bb970 100644 --- a/src/daft-logical-plan/src/ops/actor_pool_project.rs +++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs @@ -8,7 +8,7 @@ use daft_dsl::{ python::{get_concurrency, get_resource_request, PythonUDF, StatefulPythonUDF}, FunctionExpr, }, - resolve_exprs, Expr, ExprRef, + Expr, ExprRef, ExprResolver, }; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; @@ -29,8 +29,10 @@ pub struct ActorPoolProject { impl ActorPoolProject { pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> { - let (projection, fields) = - resolve_exprs(projection, input.schema().as_ref(), true).context(CreationSnafu)?; + let expr_resolver = ExprResolver::builder().allow_stateful_udf(true).build(); + let (projection, fields) = expr_resolver + .resolve(projection, input.schema().as_ref()) + .context(CreationSnafu)?; let num_stateful_udf_exprs: usize = projection .iter() diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index 2a7be5337c..b87a8f2a68 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use daft_dsl::{resolve_aggexprs, resolve_exprs, AggExpr, ExprRef}; +use daft_dsl::{ExprRef, ExprResolver}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; use snafu::ResultExt; @@ -16,7 +16,11 @@ pub struct Aggregate { pub input: Arc<LogicalPlan>, /// Aggregations to apply. - pub aggregations: Vec<AggExpr>, + /// + /// Initially, the root-level expressions may not be aggregations, + /// but they should be factored out into a project by an optimization rule, + /// leaving only aliases and agg expressions by translation time. + pub aggregations: Vec<ExprRef>, /// Grouping to apply.
pub groupby: Vec, @@ -31,10 +35,16 @@ impl Aggregate { groupby: Vec, ) -> logical_plan::Result { let upstream_schema = input.schema(); - let (groupby, groupby_fields) = - resolve_exprs(groupby, &upstream_schema, false).context(CreationSnafu)?; - let (aggregations, aggregation_fields) = - resolve_aggexprs(aggregations, &upstream_schema).context(CreationSnafu)?; + + let groupby_resolver = ExprResolver::default(); + let agg_resolver = ExprResolver::builder().in_agg_context(true).build(); + + let (groupby, groupby_fields) = groupby_resolver + .resolve(groupby, &upstream_schema) + .context(CreationSnafu)?; + let (aggregations, aggregation_fields) = agg_resolver + .resolve(aggregations, &upstream_schema) + .context(CreationSnafu)?; let fields = [groupby_fields, aggregation_fields].concat(); diff --git a/src/daft-logical-plan/src/ops/explode.rs b/src/daft-logical-plan/src/ops/explode.rs index 66e5ad0a5c..daa7ca99b0 100644 --- a/src/daft-logical-plan/src/ops/explode.rs +++ b/src/daft-logical-plan/src/ops/explode.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use daft_dsl::{resolve_exprs, ExprRef}; +use daft_dsl::{ExprRef, ExprResolver}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; use snafu::ResultExt; @@ -26,8 +26,11 @@ impl Explode { ) -> logical_plan::Result { let upstream_schema = input.schema(); - let (to_explode, _) = - resolve_exprs(to_explode, &upstream_schema, false).context(CreationSnafu)?; + let expr_resolver = ExprResolver::default(); + + let (to_explode, _) = expr_resolver + .resolve(to_explode, &upstream_schema) + .context(CreationSnafu)?; let explode_exprs = to_explode .iter() diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index c5187fca4a..0f12bf9a49 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{resolve_single_expr, ExprRef}; +use daft_dsl::{ExprRef, ExprResolver}; use snafu::ResultExt; use crate::{ @@ -20,8 +20,11 @@ pub struct Filter { impl Filter { pub(crate) fn try_new(input: Arc, predicate: ExprRef) -> Result { - let (predicate, field) = - resolve_single_expr(predicate, &input.schema(), false).context(CreationSnafu)?; + let expr_resolver = ExprResolver::default(); + + let (predicate, field) = expr_resolver + .resolve_single(predicate, &input.schema()) + .context(CreationSnafu)?; if !matches!(field.dtype, DataType::Boolean) { return Err(DaftError::ValueError(format!( diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index 3d485f8997..db787cf3a2 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -9,7 +9,7 @@ use daft_dsl::{ col, join::{get_common_join_keys, infer_join_schema}, optimization::replace_columns_with_expressions, - resolve_exprs, Expr, ExprRef, + Expr, ExprRef, ExprResolver, }; use itertools::Itertools; use snafu::ResultExt; @@ -66,9 +66,14 @@ impl Join { // In SQL the join column is always kept, while in dataframes it is not keep_join_keys: bool, ) -> logical_plan::Result { - let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; - let (right_on, _) = - resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?; + let expr_resolver = ExprResolver::default(); + + let (left_on, _) = expr_resolver + .resolve(left_on, &left.schema()) + .context(CreationSnafu)?; + let (right_on, _) = expr_resolver + 
.resolve(right_on, &right.schema()) + .context(CreationSnafu)?; let (unique_left_on, unique_right_on) = Self::rename_join_keys(left_on.clone(), right_on.clone()); diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index cc223b707c..49176d8048 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -1,7 +1,8 @@ use std::sync::Arc; +use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{resolve_exprs, resolve_single_aggexpr, resolve_single_expr, AggExpr, ExprRef}; +use daft_dsl::{AggExpr, Expr, ExprRef, ExprResolver}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; use snafu::ResultExt; @@ -32,14 +33,30 @@ impl Pivot { names: Vec, ) -> logical_plan::Result { let upstream_schema = input.schema(); - let (group_by, group_by_fields) = - resolve_exprs(group_by, &upstream_schema, false).context(CreationSnafu)?; - let (pivot_column, _) = - resolve_single_expr(pivot_column, &upstream_schema, false).context(CreationSnafu)?; - let (value_column, value_col_field) = - resolve_single_expr(value_column, &upstream_schema, false).context(CreationSnafu)?; - let (aggregation, _) = - resolve_single_aggexpr(aggregation, &upstream_schema).context(CreationSnafu)?; + + let expr_resolver = ExprResolver::default(); + let agg_resolver = ExprResolver::builder().in_agg_context(true).build(); + + let (group_by, group_by_fields) = expr_resolver + .resolve(group_by, &upstream_schema) + .context(CreationSnafu)?; + let (pivot_column, _) = expr_resolver + .resolve_single(pivot_column, &upstream_schema) + .context(CreationSnafu)?; + let (value_column, value_col_field) = expr_resolver + .resolve_single(value_column, &upstream_schema) + .context(CreationSnafu)?; + + let (aggregation, _) = agg_resolver + .resolve_single(aggregation, &upstream_schema) + .context(CreationSnafu)?; + + let Expr::Agg(agg_expr) = aggregation.as_ref() else { + return Err(DaftError::ValueError(format!( + "Pivot only supports using top level aggregation expressions, received {aggregation}", + )) + .into()); + }; let output_schema = { let value_col_dtype = value_col_field.dtype; @@ -59,7 +76,7 @@ impl Pivot { group_by, pivot_column, value_column, - aggregation, + aggregation: agg_expr.clone(), names, output_schema, }) diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 163acba2b9..633ac9ec12 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_treenode::Transformed; use daft_core::prelude::*; -use daft_dsl::{optimization, resolve_exprs, AggExpr, ApproxPercentileParams, Expr, ExprRef}; +use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef, ExprResolver}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use snafu::ResultExt; @@ -22,8 +22,11 @@ pub struct Project { impl Project { pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { - let (projection, fields) = - resolve_exprs(projection, &input.schema(), true).context(CreationSnafu)?; + let expr_resolver = ExprResolver::builder().allow_stateful_udf(true).build(); + + let (projection, fields) = expr_resolver + .resolve(projection, &input.schema()) + .context(CreationSnafu)?; // Factor the projection and see if there are any substitutions to factor out. 
let (factored_input, factored_projection) = diff --git a/src/daft-logical-plan/src/ops/repartition.rs b/src/daft-logical-plan/src/ops/repartition.rs index 34a03159a1..1dce616d62 100644 --- a/src/daft-logical-plan/src/ops/repartition.rs +++ b/src/daft-logical-plan/src/ops/repartition.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; -use daft_dsl::resolve_exprs; +use daft_dsl::ExprResolver; use crate::{ partitioning::{HashRepartitionConfig, RepartitionSpec}, @@ -22,7 +22,9 @@ impl Repartition { ) -> DaftResult { let repartition_spec = match repartition_spec { RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => { - let (resolved_by, _) = resolve_exprs(by, &input.schema(), false)?; + let expr_resolver = ExprResolver::default(); + + let (resolved_by, _) = expr_resolver.resolve(by, &input.schema())?; RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by: resolved_by, diff --git a/src/daft-logical-plan/src/ops/sink.rs b/src/daft-logical-plan/src/ops/sink.rs index d84c654c84..1370ab91f4 100644 --- a/src/daft-logical-plan/src/ops/sink.rs +++ b/src/daft-logical-plan/src/ops/sink.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftResult; use daft_core::prelude::*; -use daft_dsl::resolve_exprs; +use daft_dsl::ExprResolver; #[cfg(feature = "python")] use crate::sink_info::CatalogType; @@ -30,10 +30,14 @@ impl Sink { compression, io_config, }) => { + let expr_resolver = ExprResolver::default(); + let resolved_partition_cols = partition_cols .clone() .map(|cols| { - resolve_exprs(cols, &schema, false).map(|(resolved_cols, _)| resolved_cols) + expr_resolver + .resolve(cols, &schema) + .map(|(resolved_cols, _)| resolved_cols) }) .transpose()?; diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index cab410da86..1c722a85f7 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{resolve_exprs, ExprRef}; +use daft_dsl::{ExprRef, ExprResolver}; use itertools::Itertools; use snafu::ResultExt; @@ -29,8 +29,11 @@ impl Sort { .context(CreationSnafu); } - let (sort_by, sort_by_fields) = - resolve_exprs(sort_by, &input.schema(), false).context(CreationSnafu)?; + let expr_resolver = ExprResolver::default(); + + let (sort_by, sort_by_fields) = expr_resolver + .resolve(sort_by, &input.schema()) + .context(CreationSnafu)?; let sort_by_resolved_schema = Schema::new(sort_by_fields).context(CreationSnafu)?; diff --git a/src/daft-logical-plan/src/ops/unpivot.rs b/src/daft-logical-plan/src/ops/unpivot.rs index 7f10f2ee0a..cec9cd1c00 100644 --- a/src/daft-logical-plan/src/ops/unpivot.rs +++ b/src/daft-logical-plan/src/ops/unpivot.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::{prelude::*, utils::supertype::try_get_supertype}; -use daft_dsl::{resolve_exprs, ExprRef}; +use daft_dsl::{ExprRef, ExprResolver}; use itertools::Itertools; use snafu::ResultExt; @@ -36,9 +36,12 @@ impl Unpivot { .context(CreationSnafu); } + let expr_resolver = ExprResolver::default(); + let input_schema = input.schema(); - let (values, values_fields) = - resolve_exprs(values, &input_schema, false).context(CreationSnafu)?; + let (values, values_fields) = expr_resolver + .resolve(values, &input_schema) + .context(CreationSnafu)?; let value_dtype = values_fields .iter() @@ -50,7 +53,9 @@ impl Unpivot { let variable_field = Field::new(variable_name, DataType::Utf8); let 
value_field = Field::new(value_name, value_dtype); - let (ids, ids_fields) = resolve_exprs(ids, &input_schema, false).context(CreationSnafu)?; + let (ids, ids_fields) = expr_resolver + .resolve(ids, &input_schema) + .context(CreationSnafu)?; let output_fields = ids_fields .into_iter() diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index 31d455f32b..084018522a 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -6,8 +6,8 @@ use common_treenode::Transformed; use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ - DropRepartition, EliminateCrossJoin, OptimizerRule, PushDownFilter, PushDownLimit, - PushDownProjection, SplitActorPoolProjects, + DropRepartition, EliminateCrossJoin, LiftProjectFromAgg, OptimizerRule, PushDownFilter, + PushDownLimit, PushDownProjection, SplitActorPoolProjects, }, }; use crate::LogicalPlan; @@ -106,6 +106,12 @@ impl Optimizer { )); } + // --- Rewrite rules --- + rule_batches.push(RuleBatch::new( + vec![Box::new(LiftProjectFromAgg::new())], + RuleExecutionStrategy::Once, + )); + // --- Bulk of our rules --- rule_batches.push(RuleBatch::new( vec![ diff --git a/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs b/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs new file mode 100644 index 0000000000..872859356f --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs @@ -0,0 +1,278 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; +use daft_dsl::{col, optimization::requires_computation, Expr}; +use indexmap::IndexSet; + +use super::OptimizerRule; +use crate::{ + ops::{Aggregate, Project}, + LogicalPlan, +}; + +/// Rewrite rule for lifting expressions that can be evaluated in a projection out of an aggregation. +/// After a pass of this rule, the top-level expressions in each aggregate should all be aliases or agg exprs. +/// +/// The logical-to-physical plan translation currently assumes that expressions are lifted out of aggregations, +/// so this rule must be run to rewrite the plan into a valid state.
+/// +/// # Examples +/// +/// ### Global Agg +/// Input: `Agg [sum("x") + sum("y")] <- Scan` +/// +/// Output: `Project [col("sum(x)") + col("sum(y)")] <- Agg [sum("x"), sum("y")] <- Scan` +/// +/// ### Groupby Agg +/// Input: `Agg [groupby="key", sum("x") + sum("y")] <- Scan` +/// +/// Output: `Project ["key", col("sum(x)") + col("sum(y)")] <- Agg [groupby="key", sum("x"), sum("y")] <- Scan` +#[derive(Default, Debug)] +pub struct LiftProjectFromAgg {} + +impl LiftProjectFromAgg { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for LiftProjectFromAgg { + fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> { + plan.transform(|node| { + let LogicalPlan::Aggregate(aggregate) = node.as_ref() else { + return Ok(Transformed::no(node)); + }; + + let schema = node.schema(); + + let mut agg_exprs = IndexSet::new(); + + let lifted_exprs = aggregate + .aggregations + .iter() + .map(|expr| { + let name = expr.name(); + let new_expr = expr + .clone() + .transform_down(|e| { + if matches!(e.as_ref(), Expr::Agg(_)) { + let id = e.semantic_id(schema.as_ref()).id; + agg_exprs.insert(e.alias(id.clone())); + Ok(Transformed::yes(col(id))) + } else { + Ok(Transformed::no(e)) + } + }) + .unwrap() + .data; + + if new_expr.name() != name { + new_expr.alias(name) + } else { + new_expr + } + }) + .collect::<Vec<_>>(); + + if lifted_exprs + .iter() + .any(|expr| requires_computation(expr.as_ref())) + { + let project_exprs = aggregate + .groupby + .iter() + .map(|e| col(e.name())) + .chain(lifted_exprs) + .collect::<Vec<_>>(); + + let new_aggregate = Arc::new(LogicalPlan::Aggregate(Aggregate::try_new( + aggregate.input.clone(), + agg_exprs.into_iter().collect(), + aggregate.groupby.clone(), + )?)); + + let new_project = Arc::new(LogicalPlan::Project(Project::try_new( + new_aggregate, + project_exprs, + )?)); + + Ok(Transformed::yes(new_project)) + } else { + Ok(Transformed::no(node)) + } + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_dsl::col; + use daft_schema::{dtype::DataType, field::Field}; + + use super::LiftProjectFromAgg; + use crate::{ + optimization::test::assert_optimized_plan_with_rules_eq, + test::{dummy_scan_node, dummy_scan_operator}, + LogicalPlan, + }; + + fn assert_optimized_plan_eq( + plan: Arc<LogicalPlan>, + expected: Arc<LogicalPlan>, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq( + plan, + expected, + vec![Box::new(LiftProjectFromAgg::new())], + ) + } + + #[test] + fn lift_exprs_global_agg() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + ]); + + let plan = dummy_scan_node(scan_op.clone()) + .aggregate( + vec![ + col("a").sum(), + col("a").sum().add(col("b").sum()).alias("a_plus_b"), + col("b").mean(), + col("b").mean().alias("c"), + ], + vec![], + )? + .build(); + + let schema = dummy_scan_node(scan_op.clone()).schema(); + + let a_sum_id = col("a").sum().semantic_id(schema.as_ref()).id; + let b_sum_id = col("b").sum().semantic_id(schema.as_ref()).id; + let b_mean_id = col("b").mean().semantic_id(schema.as_ref()).id; + + let expected = dummy_scan_node(scan_op) + .aggregate( + vec![ + col("a").sum().alias(a_sum_id.clone()), + col("b").sum().alias(b_sum_id.clone()), + col("b").mean().alias(b_mean_id.clone()), + ], + vec![], + )? + .select(vec![ + col(a_sum_id.clone()).alias("a"), + col(a_sum_id).add(col(b_sum_id)).alias("a_plus_b"), + col(b_mean_id.clone()).alias("b"), + col(b_mean_id).alias("c"), + ])?
+ .build(); + + assert_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + #[test] + fn lift_exprs_groupby_agg() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ + Field::new("groupby_key", DataType::Utf8), + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + ]); + + let plan = dummy_scan_node(scan_op.clone()) + .aggregate( + vec![ + col("a").sum(), + col("a").sum().add(col("b").sum()).alias("a_plus_b"), + col("b").mean(), + col("b").mean().alias("c"), + ], + vec![col("groupby_key")], + )? + .build(); + + let schema = dummy_scan_node(scan_op.clone()).schema(); + + let a_sum_id = col("a").sum().semantic_id(schema.as_ref()).id; + let b_sum_id = col("b").sum().semantic_id(schema.as_ref()).id; + let b_mean_id = col("b").mean().semantic_id(schema.as_ref()).id; + + let expected = dummy_scan_node(scan_op) + .aggregate( + vec![ + col("a").sum().alias(a_sum_id.clone()), + col("b").sum().alias(b_sum_id.clone()), + col("b").mean().alias(b_mean_id.clone()), + ], + vec![col("groupby_key")], + )? + .select(vec![ + col("groupby_key"), + col(a_sum_id.clone()).alias("a"), + col(a_sum_id).add(col(b_sum_id)).alias("a_plus_b"), + col(b_mean_id.clone()).alias("b"), + col(b_mean_id).alias("c"), + ])? + .build(); + + assert_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + #[test] + fn do_not_lift_exprs_global_agg() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + ]); + + let plan = dummy_scan_node(scan_op.clone()) + .aggregate( + vec![ + col("a").sum(), + col("a").add(col("b")).sum().alias("a_plus_b"), + col("b").mean(), + col("b").mean().alias("c"), + ], + vec![], + )? + .build(); + + let expected = plan.clone(); + + assert_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + #[test] + fn do_not_lift_exprs_groupby_agg() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ + Field::new("groupby_key", DataType::Utf8), + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + ]); + + let plan = dummy_scan_node(scan_op.clone()) + .aggregate( + vec![ + col("a").sum(), + col("a").add(col("b")).sum().alias("a_plus_b"), + col("b").mean(), + col("b").mean().alias("c"), + ], + vec![col("groupby_key")], + )? 
+ .build(); + + let expected = plan.clone(); + + assert_optimized_plan_eq(plan, expected)?; + Ok(()) + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs index 06f6382ea8..78ce51533a 100644 --- a/src/daft-logical-plan/src/optimization/rules/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/mod.rs @@ -1,5 +1,6 @@ mod drop_repartition; mod eliminate_cross_join; +mod lift_project_from_agg; mod push_down_filter; mod push_down_limit; mod push_down_projection; @@ -8,6 +9,7 @@ mod split_actor_pool_projects; pub use drop_repartition::DropRepartition; pub use eliminate_cross_join::EliminateCrossJoin; +pub use lift_project_from_agg::LiftProjectFromAgg; pub use push_down_filter::PushDownFilter; pub use push_down_limit::PushDownLimit; pub use push_down_projection::PushDownProjection; diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs index 95f770c0ce..7c2391ccce 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs @@ -210,7 +210,7 @@ impl PushDownProjection { .aggregations .iter() .filter(|&e| required_columns.contains(e.name())) - .map(|ae| ae.into()) + .cloned() .collect::>(); if pruned_aggregate_exprs.len() < aggregate.aggregations.len() { diff --git a/src/daft-physical-plan/src/lib.rs b/src/daft-physical-plan/src/lib.rs index 7e2bb89c9d..ee94ea9ff0 100644 --- a/src/daft-physical-plan/src/lib.rs +++ b/src/daft-physical-plan/src/lib.rs @@ -12,7 +12,7 @@ mod treenode; mod test; pub use physical_planner::{ - logical_to_physical, populate_aggregation_stages, AdaptivePlanner, MaterializedResults, - QueryStageOutput, + extract_agg_expr, logical_to_physical, populate_aggregation_stages, AdaptivePlanner, + MaterializedResults, QueryStageOutput, }; pub use plan::{PhysicalPlan, PhysicalPlanRef}; diff --git a/src/daft-physical-plan/src/physical_planner/mod.rs b/src/daft-physical-plan/src/physical_planner/mod.rs index 2e68ec1e14..0cae9e62d7 100644 --- a/src/daft-physical-plan/src/physical_planner/mod.rs +++ b/src/daft-physical-plan/src/physical_planner/mod.rs @@ -11,7 +11,7 @@ pub use planner::{AdaptivePlanner, MaterializedResults, QueryStageOutput}; use crate::{optimization::optimizer::PhysicalOptimizer, PhysicalPlanRef}; mod translate; -pub use translate::populate_aggregation_stages; +pub use translate::{extract_agg_expr, populate_aggregation_stages}; /// Translate a logical plan to a physical plan. 
pub fn logical_to_physical( diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 2ba75f1540..b3790111f2 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -11,7 +11,7 @@ use common_scan_info::PhysicalScanInfo; use daft_core::prelude::*; use daft_dsl::{ col, functions::agg::merge_mean, is_partition_compatible, AggExpr, ApproxPercentileParams, - ExprRef, SketchType, + Expr, ExprRef, SketchType, }; use daft_functions::numeric::sqrt; use daft_logical_plan::{ @@ -254,17 +254,22 @@ pub(super) fn translate_single_logical_node( let num_input_partitions = input_physical.clustering_spec().num_partitions(); + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::<DaftResult<Vec<_>>>()?; + let result_plan = match num_input_partitions { 1 => PhysicalPlan::Aggregate(Aggregate::new( input_physical, - aggregations.clone(), + aggregations, groupby.clone(), )), _ => { let schema = logical_plan.schema(); let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(aggregations, &schema, groupby); + populate_aggregation_stages(&aggregations, &schema, groupby); let (first_stage_agg, groupby) = if first_stage_aggs.is_empty() { (input_physical, groupby.clone()) @@ -750,6 +755,56 @@ pub(super) fn translate_single_logical_node( } } +pub fn extract_agg_expr(expr: &ExprRef) -> DaftResult<AggExpr> { + match expr.as_ref() { + Expr::Agg(agg_expr) => Ok(agg_expr.clone()), + Expr::Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { + // reorder expressions so that alias goes before agg + match agg_expr { + AggExpr::Count(e, count_mode) => { + AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode) + } + AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()), + AggExpr::ApproxPercentile(ApproxPercentileParams { + child: e, + percentiles, + force_list_output, + }) => AggExpr::ApproxPercentile(ApproxPercentileParams { + child: Expr::Alias(e, name.clone()).into(), + percentiles, + force_list_output, + }), + AggExpr::ApproxCountDistinct(e) => { + AggExpr::ApproxCountDistinct(Expr::Alias(e, name.clone()).into()) + } + AggExpr::ApproxSketch(e, sketch_type) => { + AggExpr::ApproxSketch(Expr::Alias(e, name.clone()).into(), sketch_type) + } + AggExpr::MergeSketch(e, sketch_type) => { + AggExpr::MergeSketch(Expr::Alias(e, name.clone()).into(), sketch_type) + } + AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), + AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), + AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), + AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), + AggExpr::AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(Expr::Alias(e, name.clone()).into(), ignore_nulls) + } + AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()), + AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()), + AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups { + func, + inputs: inputs + .into_iter() + .map(|input| input.alias(name.clone())) + .collect(), + }, + } + }), + _ => Err(DaftError::InternalError("Expected non-agg expressions in aggregation to be factored out before plan translation.".to_string())), + } +} + /// Given a list of aggregation expressions, return the aggregation expressions to apply in the first and second stages, /// as well as the final expressions to project.
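Before the implementation, a sketch of what this staging means operationally: a decomposable aggregation such as `mean` becomes partial aggregates per partition (first stage), merged partials after the shuffle (second stage), and a final projection combining the merged values. A toy model in plain Python, not Daft internals:

# Toy two-stage computation of mean(x) over partitioned data.
partitions = [[1, 2, 3], [4, 5]]

# First stage, per partition: partial sum and count.
partials = [(sum(p), len(p)) for p in partitions]

# Second stage, after the shuffle: merge the partials.
total_sum = sum(s for s, _ in partials)
total_count = sum(c for _, c in partials)

# Final expression: project the merged values into the mean.
assert total_sum / total_count == 3.0

This mirrors why `merge_mean` and `sqrt` are imported in the translate.rs hunk above: `mean` and `stddev` cannot be computed by simply re-applying themselves to partial results.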
#[allow(clippy::type_complexity)] diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index ec72245253..05ede0e498 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -28,7 +28,6 @@ use { daft_core::prelude::SchemaRef, daft_core::python::PySchema, daft_dsl::python::PyExpr, - daft_dsl::Expr, daft_logical_plan::{OutputFileInfo, PyLogicalPlanBuilder}, daft_scan::python::pylib::PyScanTask, pyo3::{pyclass, pymethods, types::PyAnyMethods, PyObject, PyRef, PyRefMut, PyResult, Python}, @@ -264,6 +263,7 @@ fn physical_plan_to_partition_tasks( psets: &HashMap>, actor_pool_manager: &PyObject, ) -> PyResult { + use daft_dsl::Expr; use daft_physical_plan::ops::{ShuffleExchange, ShuffleExchangeStrategy}; match physical_plan { diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py index f942410d77..27302c873b 100644 --- a/tests/dataframe/test_aggregations.py +++ b/tests/dataframe/test_aggregations.py @@ -464,3 +464,101 @@ def test_agg_any_value_ignore_nulls(make_df, repartition_nparts, with_morsel_siz res = daft_df.to_pydict() mapping = {res["group"][i]: res["any_value"][i] for i in range(len(res["group"]))} assert mapping == {1: 2, 2: 4, 3: None} + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +def test_agg_with_non_agg_expr_global(make_df, repartition_nparts, with_morsel_size): + daft_df = make_df( + { + "id": [1, 2, 3], + "values": [4, 5, 6], + }, + repartition=repartition_nparts, + ) + + daft_df = daft_df.agg( + col("id").sum(), + col("values").mean().alias("values_mean"), + (col("id").mean() + col("values").mean()).alias("sum_of_means"), + ) + + res = daft_df.to_pydict() + assert res == {"id": [6], "values_mean": [5], "sum_of_means": [7]} + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +def test_agg_with_non_agg_expr_groupby(make_df, repartition_nparts, with_morsel_size): + daft_df = make_df( + { + "group": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "id": [1, 2, 3, 2, 3, 4, 3, 4, 5], + "values": [4, 5, 6, 5, 6, 7, 6, 7, 8], + }, + repartition=repartition_nparts, + ) + + daft_df = ( + daft_df.groupby("group") + .agg( + col("id").sum(), + col("values").mean().alias("values_mean"), + (col("id").mean() + col("values").mean()).alias("sum_of_means"), + ) + .sort("group") + ) + + res = daft_df.to_pydict() + assert res == {"group": [1, 2, 3], "id": [6, 9, 12], "values_mean": [5, 6, 7], "sum_of_means": [7, 9, 11]} + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +def test_agg_with_literal_global(make_df, repartition_nparts, with_morsel_size): + daft_df = make_df( + { + "id": [1, 2, 3], + "values": [4, 5, 6], + }, + repartition=repartition_nparts, + ) + + daft_df = daft_df.agg( + col("id").sum(), + col("values").mean().alias("values_mean"), + (col("id").sum() + 1).alias("sum_plus_1"), + (col("id") + 1).sum().alias("1_plus_sum"), + ) + + res = daft_df.to_pydict() + assert res == {"id": [6], "values_mean": [5], "sum_plus_1": [7], "1_plus_sum": [9]} + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +def test_agg_with_literal_groupby(make_df, repartition_nparts, with_morsel_size): + daft_df = make_df( + { + "group": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "id": [1, 2, 3, 2, 3, 4, 3, 4, 5], + "values": [4, 5, 6, 5, 6, 7, 6, 7, 8], + }, + repartition=repartition_nparts, + ) + + daft_df = ( + daft_df.groupby("group") + .agg( + col("id").sum(), + col("values").mean().alias("values_mean"), + (col("id").sum() + 1).alias("sum_plus_1"), + (col("id") + 
1).sum().alias("1_plus_sum"), + ) + .sort("group") + ) + + res = daft_df.to_pydict() + assert res == { + "group": [1, 2, 3], + "id": [6, 9, 12], + "values_mean": [5, 6, 7], + "sum_plus_1": [7, 10, 13], + "1_plus_sum": [9, 12, 15], + }