diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index d3a940f376..3bcf0f0cb9 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -49,6 +49,7 @@ mod sketch_percentile; mod sort; pub(crate) mod sparse_tensor; mod sqrt; +mod stddev; mod struct_; mod sum; mod take; @@ -189,6 +190,12 @@ pub trait DaftMeanAggable { fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output; } +pub trait DaftStddevAggable { + type Output; + fn stddev(&self) -> Self::Output; + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output; +} + pub trait DaftCompareAggable { type Output; fn min(&self) -> Self::Output; diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs new file mode 100644 index 0000000000..cf73fe6ff5 --- /dev/null +++ b/src/daft-core/src/array/ops/stddev.rs @@ -0,0 +1,16 @@ +use common_error::DaftResult; + +use super::{DaftStddevAggable, GroupIndices}; +use crate::{array::DataArray, datatypes::Float64Type}; + +impl DaftStddevAggable for DataArray { + type Output = DaftResult; + + fn stddev(&self) -> Self::Output { + todo!("stddev") + } + + fn grouped_stddev(&self, _: &GroupIndices) -> Self::Output { + todo!("stddev") + } +} diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 541fe5c556..79cfaa484a 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -4,7 +4,7 @@ use logical::Decimal128Array; use crate::{ array::{ - ops::{DaftHllMergeAggable, GroupIndices}, + ops::{DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, GroupIndices}, ListArray, }, count_mode::CountMode, @@ -149,24 +149,25 @@ impl Series { } pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*}; - // Upcast all numeric types to float64 and use f64 mean kernel. - match self.data_type() { - dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; - match groups { - Some(groups) => { - Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series()) - } - None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()), - } - } - other => Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for type {}", - other - ))), - } + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))? + .into_series(); + Ok(series) + } + + pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult { + // Upcast all numeric types to float64 and use f64 stddev kernel. + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.stddev(), |groups| casted.grouped_stddev(groups))? + .into_series(); + Ok(series) } pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 80cd923dd9..e3e956cdaf 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -373,24 +373,24 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::Count(ref child, mode) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::Count(transformed_child, mode), - |_| e.clone(), + |_| e, ) } AggExpr::Sum(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Sum, |_| e.clone()) + .map_yes_no(AggExpr::Sum, |_| e) } AggExpr::ApproxPercentile(ApproxPercentileParams { ref child, ref percentiles, - ref force_list_output, + force_list_output, }) => replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no( |transformed_child| { AggExpr::ApproxPercentile(ApproxPercentileParams { child: transformed_child, percentiles: percentiles.clone(), - force_list_output: *force_list_output, + force_list_output, }) }, |_| e.clone(), @@ -402,45 +402,44 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::ApproxSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::ApproxSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } AggExpr::MergeSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::MergeSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } - AggExpr::Stddev(ref _child) => { - todo!("stddev") - // replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - // .map_yes_no(AggExpr::Mean, |_| e.clone()) + AggExpr::Stddev(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::Stddev, |_| e) } AggExpr::Mean(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Mean, |_| e.clone()) + .map_yes_no(AggExpr::Mean, |_| e) } AggExpr::Min(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Min, |_| e.clone()) + .map_yes_no(AggExpr::Min, |_| e) } AggExpr::Max(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Max, |_| e.clone()) + .map_yes_no(AggExpr::Max, |_| e) } AggExpr::AnyValue(ref child, ignore_nulls) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls), - |_| e.clone(), + |_| e, ) } AggExpr::List(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::List, |_| e.clone()) + .map_yes_no(AggExpr::List, |_| e) } AggExpr::Concat(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Concat, |_| e.clone()) + .map_yes_no(AggExpr::Concat, |_| e) } AggExpr::MapGroups { func, inputs } => { let transforms = inputs diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 65cf8f808e..2461aa6287 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -367,6 +367,18 @@ impl DataType { } } + #[inline] + pub fn assert_is_numeric(&self) -> DaftResult<()> { + if self.is_numeric() { + Ok(()) + } else { + Err(DaftError::TypeError(format!( + "Numeric mean is not implemented for type {}", + self, + ))) + } + } + #[inline] pub fn is_fixed_size_numeric(&self) -> bool { match self { diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 148b02f32f..eff28c6a26 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -480,7 +480,7 @@ impl Table { } } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), - AggExpr::Stddev(_expr) => todo!("stddev"), + AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => {