Skip to content

Commit

Permalink
[FEAT] Support for aggregation expressions that use multiple AggExprs (
Browse files Browse the repository at this point in the history
…#3296)

This enables expressions such as `sum("a") + sum("b")` or `mean("a") /
100` in aggregations. This PR enables Q8 and Q14 of TPC-H and is also
necessary for Q17 and Q20 (which additionally require subquery support, which is still missing).
  • Loading branch information
kevinzwang authored Nov 15, 2024
1 parent 25c3b26 commit e18b719
Show file tree
Hide file tree
Showing 28 changed files with 700 additions and 226 deletions.
23 changes: 22 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ tokio = {version = "1.37.0", features = [
tokio-stream = {version = "0.1.14", features = ["fs", "io-util", "time"]}
tokio-util = "0.7.11"
tracing = "0.1"
typed-builder = "0.20.0"
typetag = "0.2.18"
url = "2.4.0"
xxhash-rust = "0.8.12"
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ itertools = {workspace = true}
log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true}
typed-builder = {workspace = true}
typetag = {workspace = true}

[features]
Expand Down
5 changes: 1 addition & 4 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ pub use expr::{
pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
use pyo3::prelude::*;
pub use resolve_expr::{
check_column_name_validity, resolve_aggexprs, resolve_exprs, resolve_single_aggexpr,
resolve_single_expr,
};
pub use resolve_expr::{check_column_name_validity, ExprResolver};

#[cfg(feature = "python")]
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
Expand Down
262 changes: 102 additions & 160 deletions src/daft-dsl/src/resolve_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ use std::{
use common_error::{DaftError, DaftResult};
use common_treenode::{Transformed, TransformedResult, TreeNode};
use daft_core::prelude::*;
use typed_builder::TypedBuilder;

use crate::{col, expr::has_agg, has_stateful_udf, AggExpr, ApproxPercentileParams, Expr, ExprRef};
use crate::{
col, expr::has_agg, functions::FunctionExpr, has_stateful_udf, AggExpr, Expr, ExprRef,
};

// Calculates all the possible struct get expressions in a schema.
// For each sugared string, calculates all possible corresponding expressions, in order of priority.
Expand Down Expand Up @@ -204,192 +207,131 @@ fn expand_wildcards(
}
}

fn extract_agg_expr(expr: &Expr) -> DaftResult<AggExpr> {
match expr {
Expr::Agg(agg_expr) => Ok(agg_expr.clone()),
Expr::Function { func, inputs } => Ok(AggExpr::MapGroups {
func: func.clone(),
inputs: inputs.clone(),
}),
Expr::Alias(e, name) => extract_agg_expr(e).map(|agg_expr| {
// reorder expressions so that alias goes before agg
match agg_expr {
AggExpr::Count(e, count_mode) => {
AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode)
}
AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()),
AggExpr::ApproxPercentile(ApproxPercentileParams {
child: e,
percentiles,
force_list_output,
}) => AggExpr::ApproxPercentile(ApproxPercentileParams {
child: Expr::Alias(e, name.clone()).into(),
percentiles,
force_list_output,
}),
AggExpr::ApproxCountDistinct(e) => {
AggExpr::ApproxCountDistinct(Expr::Alias(e, name.clone()).into())
}
AggExpr::ApproxSketch(e, sketch_type) => {
AggExpr::ApproxSketch(Expr::Alias(e, name.clone()).into(), sketch_type)
}
AggExpr::MergeSketch(e, sketch_type) => {
AggExpr::MergeSketch(Expr::Alias(e, name.clone()).into(), sketch_type)
}
AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()),
AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()),
AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()),
AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()),
AggExpr::AnyValue(e, ignore_nulls) => {
AggExpr::AnyValue(Expr::Alias(e, name.clone()).into(), ignore_nulls)
}
AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()),
AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()),
AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups {
func,
inputs: inputs
.into_iter()
.map(|input| input.alias(name.clone()))
.collect(),
},
}
}),
// TODO(Kevin): Support a mix of aggregation and non-aggregation expressions
// as long as the final value always has a cardinality of 1.
_ => Err(DaftError::ValueError(format!(
"Expected aggregation expression, but got: {expr}"
))),
/// Checks if an expression used in an aggregation is well formed.
/// Expressions for aggregations must be in the form (optional) non-agg expr <- agg exprs or literals <- non-agg exprs
///
/// # Examples
///
/// Allowed:
/// - lit("x")
/// - sum(col("a"))
/// - sum(col("a")) > 0
/// - sum(col("a")) - sum(col("b")) > sum(col("c"))
///
/// Not allowed:
/// - col("a")
/// - not an aggregation
/// - sum(col("a")) + col("b")
/// - not all branches are aggregations
fn has_single_agg_layer(expr: &ExprRef) -> bool {
    match expr.as_ref() {
        // A bare column reference is not an aggregation, so it is rejected.
        Expr::Column(_) => false,
        // Literals are always accepted.
        Expr::Literal(_) => true,
        // An aggregation is accepted only when none of its children contains
        // another aggregation (i.e. aggregations are not nested).
        Expr::Agg(agg_expr) => {
            let agg_children = agg_expr.children();
            !agg_children.iter().any(has_agg)
        }
        // Any other node is accepted only if every one of its children is accepted.
        _ => {
            let children = expr.children();
            children.iter().all(has_single_agg_layer)
        }
    }
}

/// Resolves and validates the expression with a schema, returning the new expression and its field.
/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs when they are not allowed,
/// and resolves struct accessors and wildcards.
/// May return multiple expressions if the expr contains a wildcard.
///
/// TODO: Use a builder pattern for this functionality
fn resolve_expr(
expr: ExprRef,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<Vec<ExprRef>> {
// TODO(Kevin): Support aggregation expressions everywhere
/// Rewrites every `Expr::Function` whose function is a `FunctionExpr::Python`
/// into an `Expr::Agg(AggExpr::MapGroups { .. })` carrying the same `func` and `inputs`.
///
/// All other nodes are left untouched. The input expression is cloned before being
/// transformed; the rewritten tree is returned.
fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef {
    let rewritten = expr.clone().transform(|node| {
        if let Expr::Function { func, inputs } = node.as_ref() {
            if matches!(func, FunctionExpr::Python(_)) {
                let map_groups = AggExpr::MapGroups {
                    func: func.clone(),
                    inputs: inputs.clone(),
                };
                return Ok(Transformed::yes(Arc::new(Expr::Agg(map_groups))));
            }
        }
        Ok(Transformed::no(node))
    });

    // The closure above only ever returns `Ok`.
    // NOTE(review): assumes `transform` itself adds no error paths — confirm.
    rewritten.unwrap().data
}

fn validate_expr(expr: ExprRef) -> DaftResult<ExprRef> {
if has_agg(&expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
)));
}

if !allow_stateful_udf && has_stateful_udf(&expr) {
Ok(expr)
}

fn validate_expr_in_agg(expr: ExprRef) -> DaftResult<ExprRef> {
let converted_expr = convert_udfs_to_map_groups(&expr);

if !has_single_agg_layer(&converted_expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
"Expressions in aggregations must be composed of non-nested aggregation expressions, got {expr}"
)));
}

let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect()
Ok(converted_expr)
}

// Resolve a single expression, erroring if any kind of expansion happens.
pub fn resolve_single_expr(
expr: ExprRef,
schema: &Schema,
/// Used for resolving and validating expressions.
/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs
/// where they are not allowed, and resolves struct accessors and wildcards.
#[derive(Default, TypedBuilder)]
pub struct ExprResolver {
#[builder(default)]
allow_stateful_udf: bool,
) -> DaftResult<(ExprRef, Field)> {
let resolved_exprs = resolve_expr(expr.clone(), schema, allow_stateful_udf)?;
match resolved_exprs.as_slice() {
[resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)),
_ => Err(DaftError::ValueError(format!(
"Error resolving expression {}: expanded into {} expressions when 1 was expected",
expr,
resolved_exprs.len()
))),
}
#[builder(default)]
in_agg_context: bool,
}

pub fn resolve_exprs(
exprs: Vec<ExprRef>,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<(Vec<ExprRef>, Vec<Field>)> {
// can't flat map because we need to deal with errors
let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> = exprs
.into_iter()
.map(|e| resolve_expr(e, schema, allow_stateful_udf))
.collect();
let resolved_exprs: Vec<ExprRef> = resolved_exprs?.into_iter().flatten().collect();
let resolved_fields: DaftResult<Vec<Field>> =
resolved_exprs.iter().map(|e| e.to_field(schema)).collect();
Ok((resolved_exprs, resolved_fields?))
}
impl ExprResolver {
fn resolve_helper(&self, expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> {
if !self.allow_stateful_udf && has_stateful_udf(&expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
)));
}

/// Resolves and validates the expression with a schema, returning the extracted aggregation expression and its field.
/// Specifically, makes sure the expression does not contain nested aggregations or stateful UDFs,
/// and resolves struct accessors and wildcards.
/// May return multiple expressions if the expr contains a wildcard.
///
/// TODO: Use a builder pattern for this functionality
fn resolve_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<AggExpr>> {
let has_nested_agg = extract_agg_expr(&expr)?.children().iter().any(has_agg);
let validated_expr = if self.in_agg_context {
validate_expr_in_agg(expr)
} else {
validate_expr(expr)
}?;

if has_nested_agg {
return Err(DaftError::ValueError(format!(
"Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(validated_expr, schema, &struct_expr_map)?
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect()
}

if has_stateful_udf(&expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
)));
/// Resolve multiple expressions. Due to wildcards, output vec may contain more expressions than input.
pub fn resolve(
    &self,
    exprs: Vec<ExprRef>,
    schema: &Schema,
) -> DaftResult<(Vec<ExprRef>, Vec<Field>)> {
    // Resolve inputs in order, flattening each input's expansion into one list.
    // The first failure aborts the whole call.
    let mut resolved_exprs: Vec<ExprRef> = Vec::new();
    for expr in exprs {
        resolved_exprs.extend(self.resolve_helper(expr, schema)?);
    }

    // Derive one field per resolved expression against the same schema.
    let mut resolved_fields: Vec<Field> = Vec::with_capacity(resolved_exprs.len());
    for resolved in &resolved_exprs {
        resolved_fields.push(resolved.to_field(schema)?);
    }

    Ok((resolved_exprs, resolved_fields))
}

let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?
.into_iter()
.map(|expr| {
let agg_expr = extract_agg_expr(&expr)?;

let resolved_children = agg_expr
.children()
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect::<DaftResult<Vec<_>>>()?;
Ok(agg_expr.with_new_children(resolved_children))
})
.collect()
}

pub fn resolve_single_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<(AggExpr, Field)> {
let resolved_exprs = resolve_aggexpr(expr.clone(), schema)?;
match resolved_exprs.as_slice() {
[resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)),
_ => Err(DaftError::ValueError(format!(
"Error resolving expression {}: expanded into {} expressions when 1 was expected",
expr,
resolved_exprs.len()
))),
/// Resolve a single expression, ensuring that the output is also a single expression.
pub fn resolve_single(&self, expr: ExprRef, schema: &Schema) -> DaftResult<(ExprRef, Field)> {
    // Keep the caller's expression around so it can be shown in the error message.
    let resolved_exprs = self.resolve_helper(expr.clone(), schema)?;

    // Exactly one resolved expression is the only acceptable outcome.
    if let [only] = resolved_exprs.as_slice() {
        let field = only.to_field(schema)?;
        return Ok((only.clone(), field));
    }

    Err(DaftError::ValueError(format!(
        "Error resolving expression {}: expanded into {} expressions when 1 was expected",
        expr,
        resolved_exprs.len()
    )))
}
}

/// Resolves each expression in `exprs` with `resolve_aggexpr` and returns the flattened
/// aggregation expressions together with one field per resolved expression.
///
/// The output may hold more entries than `exprs`, since one input may resolve to several
/// aggregation expressions.
///
/// # Errors
///
/// Returns the first error produced while resolving an expression or computing its field.
pub fn resolve_aggexprs(
    exprs: Vec<ExprRef>,
    schema: &Schema,
) -> DaftResult<(Vec<AggExpr>, Vec<Field>)> {
    // Resolve inputs in order; the first failure aborts the whole call.
    let mut resolved_exprs: Vec<AggExpr> = Vec::new();
    for expr in exprs {
        resolved_exprs.extend(resolve_aggexpr(expr, schema)?);
    }

    // Derive one field per resolved aggregation against the same schema.
    let mut resolved_fields: Vec<Field> = Vec::with_capacity(resolved_exprs.len());
    for agg in &resolved_exprs {
        resolved_fields.push(agg.to_field(schema)?);
    }

    Ok((resolved_exprs, resolved_fields))
}

pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> {
let struct_expr_map = calculate_struct_expr_map(schema);

Expand Down
Loading

0 comments on commit e18b719

Please sign in to comment.