From 989d6f9ad285a902261268a2e588150f3ea71862 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Wed, 20 Mar 2024 12:33:27 -0700 Subject: [PATCH] [CHORE] Add global aggregation docs and error on improper aggregation usage (#2025) Added API docs for global expressions. Also, the builder now checks if there are any aggregation expressions outside of the top level of an agg call, and errors if there are. This effectively disables nested aggs for now --- daft/dataframe/dataframe.py | 12 ++++ daft/expressions/expressions.py | 16 +++++ docs/source/api_docs/expressions.rst | 19 ++++++ src/daft-plan/src/builder.rs | 98 +++++++++++++++++++++++++++- 4 files changed, 144 insertions(+), 1 deletion(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index e9a8af32d8..b9f5ecc66a 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1133,6 +1133,18 @@ def max(self, *cols: ColumnInputType) -> "DataFrame": """ return self._apply_agg_fn(Expression.max, cols) + @DataframePublicAPI + def any_value(self, *cols: ColumnInputType) -> "DataFrame": + """Returns an arbitrary value on this DataFrame. + Values for each column are not guaranteed to be from the same row. + + Args: + *cols (Union[str, Expression]): columns to get an arbitrary value from + Returns: + DataFrame: DataFrame with any values. + """ + return self._apply_agg_fn(Expression.any_value, cols) + @DataframePublicAPI def count(self, *cols: ColumnInputType) -> "DataFrame": """Performs a global count on the DataFrame diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2489af0d22..e3dfd0724a 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -343,34 +343,50 @@ def floor(self) -> Expression: return Expression._from_pyexpr(expr) def count(self, mode: CountMode = CountMode.Valid) -> Expression: + """Counts the number of values in the expression. + + Args: + mode: whether to count all values, non-null (valid) values, or null values. Defaults to CountMode.Valid. + """ expr = self._expr.count(mode) return Expression._from_pyexpr(expr) def sum(self) -> Expression: + """Calculates the sum of the values in the expression""" expr = self._expr.sum() return Expression._from_pyexpr(expr) def mean(self) -> Expression: + """Calculates the mean of the values in the expression""" expr = self._expr.mean() return Expression._from_pyexpr(expr) def min(self) -> Expression: + """Calculates the minimum value in the expression""" expr = self._expr.min() return Expression._from_pyexpr(expr) def max(self) -> Expression: + """Calculates the maximum value in the expression""" expr = self._expr.max() return Expression._from_pyexpr(expr) def any_value(self, ignore_nulls=False) -> Expression: + """Returns any value in the expression + + Args: + ignore_nulls: whether to ignore null values when selecting the value. Defaults to False. + """ expr = self._expr.any_value(ignore_nulls) return Expression._from_pyexpr(expr) def agg_list(self) -> Expression: + """Aggregates the values in the expression into a list""" expr = self._expr.agg_list() return Expression._from_pyexpr(expr) def agg_concat(self) -> Expression: + """Aggregates the values in the expression into a single string by concatenating them""" expr = self._expr.agg_concat() return Expression._from_pyexpr(expr) diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index afd07e5fc4..028dfc3dce 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -69,6 +69,25 @@ Logical Expression.__ge__ Expression.is_in +.. _api=aggregation-expression: + +Aggregation +########### + +The following can be used with DataFrame.agg or GroupedDataFrame.agg + +.. autosummary:: + :toctree: doc_gen/expression_methods + + Expression.count + Expression.sum + Expression.mean + Expression.min + Expression.max + Expression.any_value + Expression.agg_list + Expression.agg_concat + .. _expression-accessor-properties: .. _api-string-expression-operations: diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index e10eb057da..7f3cdf5cc0 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -47,6 +47,24 @@ impl LogicalPlanBuilder { } } +fn check_for_agg(expr: &Expr) -> bool { + use Expr::*; + + match expr { + Agg(_) => true, + Column(_) | Literal(_) => false, + Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => check_for_agg(e), + BinaryOp { left, right, .. } => check_for_agg(left) || check_for_agg(right), + Function { inputs, .. } => inputs.iter().any(check_for_agg), + IsIn(l, r) => check_for_agg(l) || check_for_agg(r), + IfElse { + if_true, + if_false, + predicate, + } => check_for_agg(if_true) || check_for_agg(if_false) || check_for_agg(predicate), + } +} + fn extract_agg_expr(expr: &Expr) -> DaftResult { use Expr::*; @@ -86,6 +104,26 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult { } } +fn extract_and_check_agg_expr(expr: &Expr) -> DaftResult { + use daft_dsl::AggExpr::*; + + let agg_expr = extract_agg_expr(expr)?; + let has_nested_agg = match &agg_expr { + Count(e, _) | Sum(e) | Mean(e) | Min(e) | Max(e) | AnyValue(e, _) | List(e) | Concat(e) => { + check_for_agg(e) + } + MapGroups { inputs, .. } => inputs.iter().any(check_for_agg), + }; + + if has_nested_agg { + Err(DaftError::ValueError(format!( + "Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))) + } else { + Ok(agg_expr) + } +} + impl LogicalPlanBuilder { #[cfg(feature = "python")] pub fn in_memory_scan( @@ -142,12 +180,26 @@ impl LogicalPlanBuilder { projection: Vec, resource_request: ResourceRequest, ) -> DaftResult { + for expr in &projection { + if check_for_agg(expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in projection: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + } + let logical_plan: LogicalPlan = logical_ops::Project::try_new(self.plan.clone(), projection, resource_request)?.into(); Ok(logical_plan.into()) } pub fn filter(&self, predicate: Expr) -> DaftResult { + if check_for_agg(&predicate) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in filter: {predicate}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + let logical_plan: LogicalPlan = logical_ops::Filter::try_new(self.plan.clone(), predicate)?.into(); Ok(logical_plan.into()) @@ -160,12 +212,28 @@ impl LogicalPlanBuilder { } pub fn explode(&self, to_explode: Vec) -> DaftResult { + for expr in &to_explode { + if check_for_agg(expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in explode: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + } + let logical_plan: LogicalPlan = logical_ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); Ok(logical_plan.into()) } pub fn sort(&self, sort_by: Vec, descending: Vec) -> DaftResult { + for expr in &sort_by { + if check_for_agg(expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in sort: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + } + let logical_plan: LogicalPlan = logical_ops::Sort::try_new(self.plan.clone(), sort_by, descending)?.into(); Ok(logical_plan.into()) @@ -176,6 +244,14 @@ impl LogicalPlanBuilder { num_partitions: Option, partition_by: Vec, ) -> DaftResult { + for expr in &partition_by { + if check_for_agg(expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in hash repartition: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + } + let logical_plan: LogicalPlan = logical_ops::Repartition::try_new( self.plan.clone(), RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)), @@ -221,7 +297,7 @@ impl LogicalPlanBuilder { pub fn aggregate(&self, agg_exprs: Vec, groupby_exprs: Vec) -> DaftResult { let agg_exprs = agg_exprs .iter() - .map(extract_agg_expr) + .map(extract_and_check_agg_expr) .collect::>>()?; let logical_plan: LogicalPlan = @@ -237,6 +313,16 @@ impl LogicalPlanBuilder { join_type: JoinType, join_strategy: Option, ) -> DaftResult { + for side in [&left_on, &right_on] { + for expr in side { + if check_for_agg(expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in join: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + } + } + let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), right.plan.clone(), @@ -269,6 +355,16 @@ impl LogicalPlanBuilder { compression: Option, io_config: Option, ) -> DaftResult { + if let Some(partition_cols) = &partition_cols { + for expr in partition_cols { + if check_for_agg(expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are not currently supported in table write: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + } + } + let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new( root_dir.into(), file_format,