-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Enable group by keys in aggregation expressions #3399
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
|
||
use std::{ | ||
cmp::Ordering, | ||
collections::{BinaryHeap, HashMap}, | ||
collections::{BinaryHeap, HashMap, HashSet}, | ||
sync::Arc, | ||
}; | ||
|
||
|
@@ -208,7 +208,7 @@ | |
} | ||
|
||
/// Checks if an expression used in an aggregation is well formed. | ||
/// Expressions for aggregations must be in the form (optional) non-agg expr <- agg exprs or literals <- non-agg exprs | ||
/// Expressions for aggregations must be in the form (optional) non-agg expr <- [(agg exprs <- non-agg exprs) or literals or group by keys] | ||
/// | ||
/// # Examples | ||
/// | ||
|
@@ -217,19 +217,24 @@ | |
/// - sum(col("a")) | ||
/// - sum(col("a")) > 0 | ||
/// - sum(col("a")) - sum(col("b")) > sum(col("c")) | ||
/// - sum(col("a")) + col("b") when "b" is a group by key | ||
/// | ||
/// Not allowed: | ||
/// - col("a") | ||
/// - col("a") when "a" is not a group by key | ||
/// - not an aggregation | ||
/// - sum(col("a")) + col("b") | ||
/// - not all branches are aggregations | ||
fn has_single_agg_layer(expr: &ExprRef) -> bool { | ||
match expr.as_ref() { | ||
Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg), | ||
Expr::Column(_) => false, | ||
Expr::Literal(_) => true, | ||
_ => expr.children().iter().all(has_single_agg_layer), | ||
} | ||
/// - sum(col("a")) + col("b") when "b" is not a group by key | ||
/// - not all branches are aggregations, literals, or group by keys | ||
fn has_single_agg_layer(expr: &ExprRef, groupby: &HashSet<ExprRef>) -> bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we move |
||
groupby.contains(expr) | ||
|| match expr.as_ref() { | ||
Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg), | ||
Expr::Column(_) => false, | ||
Expr::Literal(_) => true, | ||
_ => expr | ||
.children() | ||
.iter() | ||
.all(|e| has_single_agg_layer(e, groupby)), | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmmm I'm not sure this function really tests also potential bug (although unsure it could even occur): groupby contains expr with multiple aggs -> returns true (as in single agg) even though not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point but I'm not really sure what to name it. I think the logic of the function itself is fairly straightforward though. I'll give it a thought As for the bug, group bys should never have aggregations in them and will fail in another place in the code if they do. |
||
} | ||
|
||
fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef { | ||
|
@@ -257,10 +262,10 @@ | |
Ok(expr) | ||
} | ||
|
||
fn validate_expr_in_agg(expr: ExprRef) -> DaftResult<ExprRef> { | ||
fn validate_expr_in_agg(expr: ExprRef, groupby: &HashSet<ExprRef>) -> DaftResult<ExprRef> { | ||
let converted_expr = convert_udfs_to_map_groups(&expr); | ||
|
||
if !has_single_agg_layer(&converted_expr) { | ||
if !has_single_agg_layer(&converted_expr, groupby) { | ||
return Err(DaftError::ValueError(format!( | ||
"Expressions in aggregations must be composed of non-nested aggregation expressions, got {expr}" | ||
))); | ||
|
@@ -278,6 +283,8 @@ | |
allow_stateful_udf: bool, | ||
#[builder(default)] | ||
in_agg_context: bool, | ||
#[builder(default)] | ||
groupby: HashSet<ExprRef>, | ||
universalmind303 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
impl ExprResolver { | ||
|
@@ -289,7 +296,7 @@ | |
} | ||
|
||
let validated_expr = if self.in_agg_context { | ||
validate_expr_in_agg(expr) | ||
validate_expr_in_agg(expr, &self.groupby) | ||
} else { | ||
validate_expr(expr) | ||
}?; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit
adding new line makes look better in rustdocs
https://rust-lang.github.io/rust-clippy/master/index.html#too_long_first_doc_paragraph