Skip to content

Commit

Permalink
[CHORE] Add global aggregation docs and error on improper aggregation…
Browse files Browse the repository at this point in the history
… usage (#2025)

Added API docs for global expressions.
Also, the builder now checks if there are any aggregation expressions
outside of the top level of an agg call, and errors if there are. This
effectively disables nested aggs for now
  • Loading branch information
kevinzwang authored Mar 20, 2024
1 parent c2db062 commit 989d6f9
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 1 deletion.
12 changes: 12 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,18 @@ def max(self, *cols: ColumnInputType) -> "DataFrame":
"""
return self._apply_agg_fn(Expression.max, cols)

@DataframePublicAPI
def any_value(self, *cols: ColumnInputType) -> "DataFrame":
"""Returns an arbitrary value on this DataFrame.
Values for each column are not guaranteed to be from the same row.
Args:
*cols (Union[str, Expression]): columns to get an arbitrary value from
Returns:
DataFrame: DataFrame with any values.
"""
return self._apply_agg_fn(Expression.any_value, cols)

@DataframePublicAPI
def count(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs a global count on the DataFrame
Expand Down
16 changes: 16 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,34 +343,50 @@ def floor(self) -> Expression:
return Expression._from_pyexpr(expr)

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of values in the expression.
Args:
mode: whether to count all values, non-null (valid) values, or null values. Defaults to CountMode.Valid.
"""
expr = self._expr.count(mode)
return Expression._from_pyexpr(expr)

def sum(self) -> Expression:
"""Calculates the sum of the values in the expression"""
expr = self._expr.sum()
return Expression._from_pyexpr(expr)

def mean(self) -> Expression:
"""Calculates the mean of the values in the expression"""
expr = self._expr.mean()
return Expression._from_pyexpr(expr)

def min(self) -> Expression:
"""Calculates the minimum value in the expression"""
expr = self._expr.min()
return Expression._from_pyexpr(expr)

def max(self) -> Expression:
"""Calculates the maximum value in the expression"""
expr = self._expr.max()
return Expression._from_pyexpr(expr)

def any_value(self, ignore_nulls=False) -> Expression:
"""Returns any value in the expression
Args:
ignore_nulls: whether to ignore null values when selecting the value. Defaults to False.
"""
expr = self._expr.any_value(ignore_nulls)
return Expression._from_pyexpr(expr)

def agg_list(self) -> Expression:
"""Aggregates the values in the expression into a list"""
expr = self._expr.agg_list()
return Expression._from_pyexpr(expr)

def agg_concat(self) -> Expression:
"""Aggregates the values in the expression into a single string by concatenating them"""
expr = self._expr.agg_concat()
return Expression._from_pyexpr(expr)

Expand Down
19 changes: 19 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,25 @@ Logical
Expression.__ge__
Expression.is_in

.. _api=aggregation-expression:

Aggregation
###########

The following can be used with DataFrame.agg or GroupedDataFrame.agg

.. autosummary::
:toctree: doc_gen/expression_methods

Expression.count
Expression.sum
Expression.mean
Expression.min
Expression.max
Expression.any_value
Expression.agg_list
Expression.agg_concat

.. _expression-accessor-properties:
.. _api-string-expression-operations:

Expand Down
98 changes: 97 additions & 1 deletion src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ impl LogicalPlanBuilder {
}
}

fn check_for_agg(expr: &Expr) -> bool {
use Expr::*;

match expr {
Agg(_) => true,
Column(_) | Literal(_) => false,
Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => check_for_agg(e),
BinaryOp { left, right, .. } => check_for_agg(left) || check_for_agg(right),
Function { inputs, .. } => inputs.iter().any(check_for_agg),
IsIn(l, r) => check_for_agg(l) || check_for_agg(r),
IfElse {
if_true,
if_false,
predicate,
} => check_for_agg(if_true) || check_for_agg(if_false) || check_for_agg(predicate),
}
}

fn extract_agg_expr(expr: &Expr) -> DaftResult<daft_dsl::AggExpr> {
use Expr::*;

Expand Down Expand Up @@ -86,6 +104,26 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult<daft_dsl::AggExpr> {
}
}

fn extract_and_check_agg_expr(expr: &Expr) -> DaftResult<daft_dsl::AggExpr> {
use daft_dsl::AggExpr::*;

let agg_expr = extract_agg_expr(expr)?;
let has_nested_agg = match &agg_expr {
Count(e, _) | Sum(e) | Mean(e) | Min(e) | Max(e) | AnyValue(e, _) | List(e) | Concat(e) => {
check_for_agg(e)
}
MapGroups { inputs, .. } => inputs.iter().any(check_for_agg),
};

if has_nested_agg {
Err(DaftError::ValueError(format!(
"Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)))
} else {
Ok(agg_expr)
}
}

impl LogicalPlanBuilder {
#[cfg(feature = "python")]
pub fn in_memory_scan(
Expand Down Expand Up @@ -142,12 +180,26 @@ impl LogicalPlanBuilder {
projection: Vec<Expr>,
resource_request: ResourceRequest,
) -> DaftResult<Self> {
for expr in &projection {
if check_for_agg(expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in projection: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
}

let logical_plan: LogicalPlan =
logical_ops::Project::try_new(self.plan.clone(), projection, resource_request)?.into();
Ok(logical_plan.into())
}

pub fn filter(&self, predicate: Expr) -> DaftResult<Self> {
if check_for_agg(&predicate) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in filter: {predicate}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}

let logical_plan: LogicalPlan =
logical_ops::Filter::try_new(self.plan.clone(), predicate)?.into();
Ok(logical_plan.into())
Expand All @@ -160,12 +212,28 @@ impl LogicalPlanBuilder {
}

pub fn explode(&self, to_explode: Vec<Expr>) -> DaftResult<Self> {
for expr in &to_explode {
if check_for_agg(expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in explode: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
}

let logical_plan: LogicalPlan =
logical_ops::Explode::try_new(self.plan.clone(), to_explode)?.into();
Ok(logical_plan.into())
}

pub fn sort(&self, sort_by: Vec<Expr>, descending: Vec<bool>) -> DaftResult<Self> {
for expr in &sort_by {
if check_for_agg(expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in sort: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
}

let logical_plan: LogicalPlan =
logical_ops::Sort::try_new(self.plan.clone(), sort_by, descending)?.into();
Ok(logical_plan.into())
Expand All @@ -176,6 +244,14 @@ impl LogicalPlanBuilder {
num_partitions: Option<usize>,
partition_by: Vec<Expr>,
) -> DaftResult<Self> {
for expr in &partition_by {
if check_for_agg(expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in hash repartition: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
}

let logical_plan: LogicalPlan = logical_ops::Repartition::try_new(
self.plan.clone(),
RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)),
Expand Down Expand Up @@ -221,7 +297,7 @@ impl LogicalPlanBuilder {
pub fn aggregate(&self, agg_exprs: Vec<Expr>, groupby_exprs: Vec<Expr>) -> DaftResult<Self> {
let agg_exprs = agg_exprs
.iter()
.map(extract_agg_expr)
.map(extract_and_check_agg_expr)
.collect::<DaftResult<Vec<daft_dsl::AggExpr>>>()?;

let logical_plan: LogicalPlan =
Expand All @@ -237,6 +313,16 @@ impl LogicalPlanBuilder {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> DaftResult<Self> {
for side in [&left_on, &right_on] {
for expr in side {
if check_for_agg(expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in join: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
}
}

let logical_plan: LogicalPlan = logical_ops::Join::try_new(
self.plan.clone(),
right.plan.clone(),
Expand Down Expand Up @@ -269,6 +355,16 @@ impl LogicalPlanBuilder {
compression: Option<String>,
io_config: Option<IOConfig>,
) -> DaftResult<Self> {
if let Some(partition_cols) = &partition_cols {
for expr in partition_cols {
if check_for_agg(expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are not currently supported in table write: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
}
}

let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new(
root_dir.into(),
file_format,
Expand Down

0 comments on commit 989d6f9

Please sign in to comment.