Skip to content

Commit

Permalink
Update grouped-mean impl to use stats
Browse files (browse the repository at this point in the history)
  • Loading branch information
Raunak Bhagat committed Oct 8, 2024
1 parent 0c976a4 commit 65e443a
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions src/daft-core/src/array/ops/mean.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,10 +4,7 @@ use arrow2::array::PrimitiveArray;
use common_error::DaftResult;

use crate::{
array::ops::{
as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable, GroupIndices,
},
count_mode::CountMode,
array::ops::{DaftMeanAggable, GroupIndices},
datatypes::*,
utils::stats,
};
Expand All @@ -23,18 +20,8 @@ impl DaftMeanAggable for DataArray<Float64Type> {
}

// NOTE(review): this listing is a rendered diff whose +/- markers were lost, so the
// pre-commit and post-commit bodies of `grouped_mean` appear back to back below. The hunk
// header (`-23,18 +20,8`) and the "4 additions & 17 deletions" summary imply that the 13
// lines from `let sum_values` through the first `Ok(...)` were removed, and the 3 lines
// starting at `let grouped_means` were added in their place.
/// Computes the mean for each group in `groups` and returns the results as an array that
/// carries this array's field name.
fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output {
// --- Removed by this commit: mean derived by hand as (grouped sum) / (grouped count). ---
// Per-group sums and per-group counts are computed separately, then combined pairwise.
let sum_values = self.grouped_sum(groups)?;
// presumably `CountMode::Valid` counts only non-null entries — confirm against `CountMode`.
let count_values = self.grouped_count(groups, CountMode::Valid)?;
// The zip below silently truncates to the shorter side, so a length mismatch is asserted
// up front instead.
assert_eq!(sum_values.len(), count_values.len());
let mean_per_group = sum_values
.as_arrow()
.values_iter()
.zip(count_values.as_arrow().values_iter())
.map(|(s, c)| match (s, c) {
// A count of zero yields `None` for that group rather than dividing by zero.
(_, 0) => None,
(s, c) => Some(s / (*c as f64)),
});
let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group));
Ok(Self::from((self.field.name.as_ref(), mean_array)))
// --- Added by this commit: the same result is taken from the shared `stats` helper. ---
// `stats::grouped_stats` is fallible and yields pairs; only the `mean` field of the first
// element of each pair is kept. NOTE(review): how it treats empty or all-null groups is
// not visible here — confirm it matches the `None`-on-zero-count behavior removed above.
let grouped_means = stats::grouped_stats(self, groups)?.map(|(stats, _)| stats.mean);
let data = Box::new(PrimitiveArray::from_iter(grouped_means));
Ok(Self::from((self.field.name.as_ref(), data)))
}
}

0 comments on commit 65e443a

Please sign in to comment.