From b493f28492390739c250d9a851724650993bb484 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 21 Feb 2024 15:05:41 -0800 Subject: [PATCH] add all functions to implement --- daft/daft.pyi | 1 + daft/dataframe/dataframe.py | 11 +++++ daft/expressions/expressions.py | 4 ++ daft/logical/builder.py | 2 + src/daft-core/src/array/ops/any_value.rs | 43 +++++++++++++++++++ src/daft-core/src/array/ops/mod.rs | 7 +++ .../src/series/array_impl/data_array.rs | 10 +++++ .../src/series/array_impl/logical_array.rs | 10 +++++ .../src/series/array_impl/nested_array.rs | 4 ++ src/daft-core/src/series/ops/agg.rs | 4 ++ src/daft-core/src/series/series_like.rs | 1 + src/daft-dsl/src/expr.rs | 14 +++++- src/daft-dsl/src/python.rs | 4 ++ src/daft-dsl/src/treenode.rs | 2 + src/daft-plan/src/logical_ops/project.rs | 4 ++ src/daft-plan/src/planner.rs | 3 ++ src/daft-table/src/lib.rs | 1 + 17 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 src/daft-core/src/array/ops/any_value.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index 46fc5d9448..50ca8440b6 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -869,6 +869,7 @@ class PyExpr: def mean(self) -> PyExpr: ... def min(self) -> PyExpr: ... def max(self) -> PyExpr: ... + def any_value(self) -> PyExpr: ... def agg_list(self) -> PyExpr: ... def agg_concat(self) -> PyExpr: ... def explode(self) -> PyExpr: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 32f8bb726e..7d503663ed 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1546,6 +1546,17 @@ def max(self, *cols: ColumnInputType) -> "DataFrame": return self.df._agg([(c, "max") for c in cols], group_by=self.group_by) + def any_value(self, *cols: ColumnInputType) -> "DataFrame": + """Returns a single non-deterministic value from each group in this GroupedDataFrame + + Args: + *cols (Union[str, Expression]): columns to get + + Returns: + DataFrame: DataFrame with any values. + """ + return self.df._agg([(c, "any_value") for c in cols], group_by=self.group_by) + def count(self) -> "DataFrame": """Performs grouped count on this GroupedDataFrame. diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2901ce690d..c5755531a1 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -340,6 +340,10 @@ def _max(self) -> Expression: expr = self._expr.max() return Expression._from_pyexpr(expr) + def _any_value(self) -> Expression: + expr = self._expr.any_value() + return Expression._from_pyexpr(expr) + def _agg_list(self) -> Expression: expr = self._expr.agg_list() return Expression._from_pyexpr(expr) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index f5ad0451aa..1550c46114 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -171,6 +171,8 @@ def agg( exprs.append(expr._max()) elif op == "mean": exprs.append(expr._mean()) + elif op == "any_value": + exprs.append(expr._any_value()) elif op == "list": exprs.append(expr._agg_list()) elif op == "concat": diff --git a/src/daft-core/src/array/ops/any_value.rs b/src/daft-core/src/array/ops/any_value.rs new file mode 100644 index 0000000000..34619f636a --- /dev/null +++ b/src/daft-core/src/array/ops/any_value.rs @@ -0,0 +1,43 @@ +use common_error::DaftResult; + +use crate::{ + array::{DataArray, FixedSizeListArray, ListArray, StructArray}, + datatypes::DaftPhysicalType, +}; + +use super::DaftAnyValueAggable; + +impl DaftAnyValueAggable for DataArray +where + T: DaftPhysicalType, +{ + type Output = DaftResult>; + + fn any_value(&self) -> Self::Output { + todo!() + } + + fn grouped_any_value(&self, _groups: &super::GroupIndices) -> Self::Output { + todo!() + } +} + +macro_rules! impl_daft_any_value_nested_array { + ($arr:ident) => { + impl DaftAnyValueAggable for $arr { + type Output = DaftResult<$arr>; + + fn any_value(&self) -> Self::Output { + todo!() + } + + fn grouped_any_value(&self, _groups: &super::GroupIndices) -> Self::Output { + todo!() + } + } + }; +} + +impl_daft_any_value_nested_array!(FixedSizeListArray); +impl_daft_any_value_nested_array!(ListArray); +impl_daft_any_value_nested_array!(StructArray); diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index d9d66cfec6..17f9621317 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -1,4 +1,5 @@ mod abs; +mod any_value; mod apply; mod arange; mod arithmetic; @@ -134,6 +135,12 @@ pub trait DaftCompareAggable { fn grouped_max(&self, groups: &GroupIndices) -> Self::Output; } +pub trait DaftAnyValueAggable { + type Output; + fn any_value(&self) -> Self::Output; + fn grouped_any_value(&self, groups: &GroupIndices) -> Self::Output; +} + pub trait DaftListAggable { type Output; fn list(&self) -> Self::Output; diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index 00c8bd63ad..7f27e0b2eb 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -162,6 +162,16 @@ macro_rules! impl_series_like_for_data_array { } } + fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult { + use crate::array::ops::DaftAnyValueAggable; + match groups { + Some(groups) => { + Ok(DaftAnyValueAggable::grouped_any_value(&self.0, groups)?.into_series()) + } + None => Ok(DaftAnyValueAggable::any_value(&self.0)?.into_series()), + } + } + fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { match groups { Some(groups) => Ok(self.0.grouped_list(groups)?.into_series()), diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index 51a2ecd39d..06ba91fcae 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -157,6 +157,16 @@ macro_rules! impl_series_like_for_logical_array { }; Ok($da::new(self.0.field.clone(), data_array).into_series()) } + fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult { + use crate::array::ops::DaftAnyValueAggable; + let data_array = match groups { + Some(groups) => { + DaftAnyValueAggable::grouped_any_value(&self.0.physical, groups)? + } + None => DaftAnyValueAggable::any_value(&self.0.physical)?, + }; + Ok($da::new(self.0.field.clone(), data_array).into_series()) + } fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::array::{ops::DaftListAggable, ListArray}; let data_array = match groups { diff --git a/src/daft-core/src/series/array_impl/nested_array.rs b/src/daft-core/src/series/array_impl/nested_array.rs index 499a509bd3..2062690337 100644 --- a/src/daft-core/src/series/array_impl/nested_array.rs +++ b/src/daft-core/src/series/array_impl/nested_array.rs @@ -60,6 +60,10 @@ macro_rules! impl_series_like_for_nested_arrays { ))) } + fn any_value(&self, _groups: Option<&GroupIndices>) -> DaftResult { + todo!(); + } + fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::array::ops::DaftListAggable; diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index c55295cb5e..aa71581f35 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -97,6 +97,10 @@ impl Series { self.inner.max(groups) } + pub fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult { + self.inner.any_value(groups) + } + pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { self.inner.agg_list(groups) } diff --git a/src/daft-core/src/series/series_like.rs b/src/daft-core/src/series/series_like.rs index 32cf055993..6c767cdb51 100644 --- a/src/daft-core/src/series/series_like.rs +++ b/src/daft-core/src/series/series_like.rs @@ -16,6 +16,7 @@ pub trait SeriesLike: Send + Sync + Any + std::fmt::Debug { fn validity(&self) -> Option<&arrow2::bitmap::Bitmap>; fn min(&self, groups: Option<&GroupIndices>) -> DaftResult; fn max(&self, groups: Option<&GroupIndices>) -> DaftResult; + fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult; fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult; fn broadcast(&self, num: usize) -> DaftResult; fn cast(&self, datatype: &DataType) -> DaftResult; diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index a29a2b2871..14aae4f6cd 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -58,6 +58,7 @@ pub enum AggExpr { Mean(ExprRef), Min(ExprRef), Max(ExprRef), + AnyValue(ExprRef), List(ExprRef), Concat(ExprRef), MapGroups { @@ -87,6 +88,7 @@ impl AggExpr { | Mean(expr) | Min(expr) | Max(expr) + | AnyValue(expr) | List(expr) | Concat(expr) => expr.name(), MapGroups { func: _, inputs } => inputs.first().unwrap().name(), @@ -116,6 +118,10 @@ impl AggExpr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_max()")) } + AnyValue(expr) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!("{child_id}.local_any_value()")) + } List(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_list()")) @@ -136,6 +142,7 @@ impl AggExpr { | Mean(expr) | Min(expr) | Max(expr) + | AnyValue(expr) | List(expr) | Concat(expr) => vec![expr.clone()], MapGroups { func: _, inputs } => inputs.iter().map(|e| e.clone().into()).collect(), @@ -196,7 +203,7 @@ impl AggExpr { }, )) } - Min(expr) | Max(expr) => { + Min(expr) | Max(expr) | AnyValue(expr) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), field.dtype)) } @@ -279,6 +286,10 @@ impl Expr { Expr::Agg(AggExpr::Max(self.clone().into())) } + pub fn any_value(&self) -> Self { + Expr::Agg(AggExpr::AnyValue(self.clone().into())) + } + pub fn agg_list(&self) -> Self { Expr::Agg(AggExpr::List(self.clone().into())) } @@ -604,6 +615,7 @@ impl Display for AggExpr { Mean(expr) => write!(f, "mean({expr})"), Min(expr) => write!(f, "min({expr})"), Max(expr) => write!(f, "max({expr})"), + AnyValue(expr) => write!(f, "any_value({expr})"), List(expr) => write!(f, "list({expr})"), Concat(expr) => write!(f, "list({expr})"), MapGroups { func, inputs } => function_display(f, func, inputs), diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index fdd8c09401..6d6ddc5266 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -173,6 +173,10 @@ impl PyExpr { Ok(self.expr.max().into()) } + pub fn any_value(&self) -> PyResult { + Ok(self.expr.any_value().into()) + } + pub fn agg_list(&self) -> PyResult { Ok(self.expr.agg_list().into()) } diff --git a/src/daft-dsl/src/treenode.rs b/src/daft-dsl/src/treenode.rs index 81452bcdae..437472836c 100644 --- a/src/daft-dsl/src/treenode.rs +++ b/src/daft-dsl/src/treenode.rs @@ -21,6 +21,7 @@ impl TreeNode for Expr { | Mean(expr) | Min(expr) | Max(expr) + | AnyValue(expr) | List(expr) | Concat(expr) => vec![expr.as_ref()], MapGroups { func: _, inputs } => inputs.iter().collect::>(), @@ -65,6 +66,7 @@ impl TreeNode for Expr { Mean(expr) => transform(expr.as_ref().clone())?.mean(), Min(expr) => transform(expr.as_ref().clone())?.min(), Max(expr) => transform(expr.as_ref().clone())?.max(), + AnyValue(expr) => transform(expr.as_ref().clone())?.any_value(), List(expr) => transform(expr.as_ref().clone())?.agg_list(), Concat(expr) => transform(expr.as_ref().clone())?.agg_concat(), MapGroups { func, inputs } => Expr::Agg(MapGroups { diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index db87148d11..052c6139bf 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -370,6 +370,10 @@ fn replace_column_with_semantic_id_aggexpr( replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Max, |_| e.clone()) } + AggExpr::AnyValue(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::AnyValue, |_| e.clone()) + } AggExpr::List(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::List, |_| e.clone()) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index e5d999b2a3..09bf6b9494 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -441,6 +441,9 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe .into())); final_exprs.push(Column(max_of_max_id.clone()).alias(output_name)); } + AnyValue(_e) => { + todo!() + } List(e) => { let list_id = agg_expr.semantic_id(&schema).id; let concat_of_list_id = Concat(Column(list_id.clone()).into()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 0dcd16544c..c468080613 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -325,6 +325,7 @@ impl Table { Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups), Min(expr) => Series::min(&self.eval_expression(expr)?, groups), Max(expr) => Series::max(&self.eval_expression(expr)?, groups), + AnyValue(expr) => Series::any_value(&self.eval_expression(expr)?, groups), List(expr) => Series::agg_list(&self.eval_expression(expr)?, groups), Concat(expr) => Series::agg_concat(&self.eval_expression(expr)?, groups), MapGroups { .. } => Err(DaftError::ValueError(