[FEAT] Enable group by keys in aggregation expressions #3399

Merged (4 commits) on Nov 22, 2024
Changes from 3 commits
37 changes: 22 additions & 15 deletions src/daft-dsl/src/resolve_expr/mod.rs
@@ -3,7 +3,7 @@

use std::{
cmp::Ordering,
collections::{BinaryHeap, HashMap},
collections::{BinaryHeap, HashMap, HashSet},
sync::Arc,
};

@@ -208,7 +208,7 @@
}

/// Checks if an expression used in an aggregation is well formed.
Contributor:

nit

Suggested change:
/// Checks if an expression used in an aggregation is well formed.
/// Checks if an expression used in an aggregation is well-formed.

Adding a new line also makes it look better in rustdocs:
https://rust-lang.github.io/rust-clippy/master/index.html#too_long_first_doc_paragraph

/// Expressions for aggregations must be in the form (optional) non-agg expr <- agg exprs or literals <- non-agg exprs
/// Expressions for aggregations must be in the form (optional) non-agg expr <- [(agg exprs <- non-agg exprs) or literals or group by keys]
///
/// # Examples
///
@@ -217,19 +217,24 @@
/// - sum(col("a"))
/// - sum(col("a")) > 0
/// - sum(col("a")) - sum(col("b")) > sum(col("c"))
/// - sum(col("a")) + col("b") when "b" is a group by key
///
/// Not allowed:
/// - col("a")
/// - col("a") when "a" is not a group by key
/// - not an aggregation
/// - sum(col("a")) + col("b")
/// - not all branches are aggregations
fn has_single_agg_layer(expr: &ExprRef) -> bool {
match expr.as_ref() {
Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg),
Expr::Column(_) => false,
Expr::Literal(_) => true,
_ => expr.children().iter().all(has_single_agg_layer),
}
/// - sum(col("a")) + col("b") when "b" is not a group by key
/// - not all branches are aggregations, literals, or group by keys
fn has_single_agg_layer(expr: &ExprRef, groupby: &HashSet<ExprRef>) -> bool {
Contributor:

Could we move has_single_agg_layer and validate_expr_in_agg to be members of ExprResolver? It looks like they're only ever used as part of the expr resolving. (A sketch of this refactor appears after the end of this function's diff below.)

groupby.contains(expr)
|| match expr.as_ref() {
Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg),
Expr::Column(_) => false,

Codecov (codecov/patch) check warning on line 231 in src/daft-dsl/src/resolve_expr/mod.rs: added line #L231 was not covered by tests.
Expr::Literal(_) => true,
_ => expr
.children()
.iter()
.all(|e| has_single_agg_layer(e, groupby)),
}
Contributor:

Hmm, I'm not sure this function really tests "has a single agg layer" anymore. Maybe rename it, rethink the logic, or split it into two separate functions?

Also a potential bug (although I'm unsure it could even occur): if groupby contains an expr with multiple aggs, this returns true (as in a single agg) even though it isn't.

Member (author):

Good point, but I'm not really sure what to name it. I think the logic of the function itself is fairly straightforward, though. I'll give it a thought.

As for the bug: group bys should never have aggregations in them, and they will fail elsewhere in the code if they do. (A sketch of such a guard follows this file's diff below.)

}
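
The reviewer above suggests moving has_single_agg_layer and validate_expr_in_agg onto ExprResolver. That refactor is not part of this PR; a minimal sketch of the shape, assuming the items already in this module (Expr, ExprRef, has_agg) and the groupby: HashSet<ExprRef> field added further down in this diff, could look like this:

```rust
impl ExprResolver {
    /// Same logic as the free function above, but the group-by keys are read
    /// from the resolver itself instead of being threaded through as an argument.
    /// (Whether to also rename it, as raised in the thread above, is left open.)
    fn has_single_agg_layer(&self, expr: &ExprRef) -> bool {
        self.groupby.contains(expr)
            || match expr.as_ref() {
                Expr::Agg(agg_expr) => !agg_expr.children().iter().any(has_agg),
                Expr::Column(_) => false,
                Expr::Literal(_) => true,
                _ => expr
                    .children()
                    .iter()
                    .all(|e| self.has_single_agg_layer(e)),
            }
    }
}
```

validate_expr_in_agg would become a &self method in the same way, dropping its explicit groupby parameter.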

fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef {
@@ -257,10 +262,10 @@
Ok(expr)
}

fn validate_expr_in_agg(expr: ExprRef) -> DaftResult<ExprRef> {
fn validate_expr_in_agg(expr: ExprRef, groupby: &HashSet<ExprRef>) -> DaftResult<ExprRef> {
let converted_expr = convert_udfs_to_map_groups(&expr);

if !has_single_agg_layer(&converted_expr) {
if !has_single_agg_layer(&converted_expr, groupby) {
return Err(DaftError::ValueError(format!(
"Expressions in aggregations must be composed of non-nested aggregation expressions, got {expr}"
)));
@@ -278,6 +283,8 @@
allow_stateful_udf: bool,
#[builder(default)]
in_agg_context: bool,
#[builder(default)]
groupby: HashSet<ExprRef>,
universalmind303 marked this conversation as resolved.
}

impl ExprResolver {
@@ -289,7 +296,7 @@
}

let validated_expr = if self.in_agg_context {
validate_expr_in_agg(expr)
validate_expr_in_agg(expr, &self.groupby)
} else {
validate_expr(expr)
}?;
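The author notes in the thread above that group-by expressions should never themselves contain aggregations and are rejected elsewhere in the codebase. That existing check is not shown in this diff; purely as an illustration, such a guard could look like the hypothetical helper below (validate_groupby_keys is an invented name; the types and helpers are those already used in this module):

```rust
/// Hypothetical guard, for illustration only: reject group-by keys that contain
/// aggregations, so an expression like sum(col("a")) can never satisfy the
/// `groupby.contains(expr)` short-circuit in has_single_agg_layer.
fn validate_groupby_keys(groupby: &HashSet<ExprRef>) -> DaftResult<()> {
    if let Some(bad) = groupby.iter().find(|e| has_agg(e)) {
        return Err(DaftError::ValueError(format!(
            "Group by expressions cannot contain aggregations, got {bad}"
        )));
    }
    Ok(())
}
```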
7 changes: 5 additions & 2 deletions src/daft-logical-plan/src/ops/agg.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};

use daft_dsl::{ExprRef, ExprResolver};
use daft_schema::schema::{Schema, SchemaRef};
@@ -37,7 +37,10 @@ impl Aggregate {
let upstream_schema = input.schema();

let groupby_resolver = ExprResolver::default();
let agg_resolver = ExprResolver::builder().in_agg_context(true).build();
let agg_resolver = ExprResolver::builder()
.in_agg_context(true)
.groupby(HashSet::from_iter(groupby.clone()))
.build();

universalmind303 marked this conversation as resolved.
let (groupby, groupby_fields) = groupby_resolver
.resolve(groupby, &upstream_schema)
7 changes: 5 additions & 2 deletions src/daft-logical-plan/src/ops/pivot.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};

use common_error::DaftError;
use daft_core::prelude::*;
@@ -35,7 +35,10 @@ impl Pivot {
let upstream_schema = input.schema();

let expr_resolver = ExprResolver::default();
let agg_resolver = ExprResolver::builder().in_agg_context(true).build();
let agg_resolver = ExprResolver::builder()
.in_agg_context(true)
.groupby(HashSet::from_iter(group_by.clone()))
.build();

let (group_by, group_by_fields) = expr_resolver
.resolve(group_by, &upstream_schema)
30 changes: 30 additions & 0 deletions tests/dataframe/test_aggregations.py
@@ -562,3 +562,33 @@ def test_agg_with_literal_groupby(make_df, repartition_nparts, with_morsel_size)
"sum_plus_1": [7, 10, 13],
"1_plus_sum": [9, 12, 15],
}


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
def test_agg_with_groupby_key_in_agg(make_df, repartition_nparts, with_morsel_size):
daft_df = make_df(
{
"group": [1, 1, 1, 2, 2, 2, 3, 3, 3],
"id": [1, 2, 3, 2, 3, 4, 3, 4, 5],
"values": [4, 5, 6, 5, 6, 7, 6, 7, 8],
},
repartition=repartition_nparts,
)

daft_df = (
daft_df.groupby("group")
.agg(
col("group").alias("group_alias"),
(col("group") + 1).alias("group_plus_1"),
(col("id").sum() + col("group")).alias("id_plus_group"),
)
.sort("group")
)

res = daft_df.to_pydict()
assert res == {
"group": [1, 2, 3],
"group_alias": [1, 2, 3],
"group_plus_1": [2, 3, 4],
"id_plus_group": [7, 11, 15],
}