diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index 09307d3dbc..d5764c4954 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -4,10 +4,7 @@ use arrow2::array::PrimitiveArray; use common_error::DaftResult; use crate::{ - array::ops::{ - as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable, GroupIndices, - }, - count_mode::CountMode, + array::ops::{DaftMeanAggable, GroupIndices}, datatypes::*, utils::stats, }; @@ -23,18 +20,8 @@ impl DaftMeanAggable for DataArray { } fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output { - let sum_values = self.grouped_sum(groups)?; - let count_values = self.grouped_count(groups, CountMode::Valid)?; - assert_eq!(sum_values.len(), count_values.len()); - let mean_per_group = sum_values - .as_arrow() - .values_iter() - .zip(count_values.as_arrow().values_iter()) - .map(|(s, c)| match (s, c) { - (_, 0) => None, - (s, c) => Some(s / (*c as f64)), - }); - let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group)); - Ok(Self::from((self.field.name.as_ref(), mean_array))) + let grouped_means = stats::grouped_stats(self, groups)?.map(|(stats, _)| stats.mean); + let data = Box::new(PrimitiveArray::from_iter(grouped_means)); + Ok(Self::from((self.field.name.as_ref(), data))) } }