Skip to content

Commit

Permalink
Implement structure for local and distributed stddev
Browse files Browse the repository at this point in the history
  • Loading branch information
Raunak Bhagat committed Oct 7, 2024
1 parent a53dfaa commit 2ae9d08
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 35 deletions.
7 changes: 7 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod sketch_percentile;
mod sort;
pub(crate) mod sparse_tensor;
mod sqrt;
mod stddev;
mod struct_;
mod sum;
mod take;
Expand Down Expand Up @@ -189,6 +190,12 @@ pub trait DaftMeanAggable {
fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output;
}

/// Aggregation interface for computing a standard deviation.
///
/// Mirrors the shape of the other `Daft*Aggable` traits in this module: one ungrouped
/// entry point and one grouped entry point sharing a single `Output` type.
///
/// NOTE(review): the trait does not say whether implementations return the population or
/// the sample standard deviation — confirm and document once a kernel exists.
pub trait DaftStddevAggable {
/// Result type returned by both `stddev` and `grouped_stddev`.
type Output;
/// Ungrouped aggregation: reduces `self` as a whole.
fn stddev(&self) -> Self::Output;
/// Grouped aggregation: reduces `self` according to the supplied `groups`.
fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output;
}

pub trait DaftCompareAggable {
type Output;
fn min(&self) -> Self::Output;
Expand Down
16 changes: 16 additions & 0 deletions src/daft-core/src/array/ops/stddev.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use common_error::DaftResult;

use super::{DaftStddevAggable, GroupIndices};
use crate::{array::DataArray, datatypes::Float64Type};

// Standard-deviation aggregation for Float64 arrays.
//
// Both methods are placeholders: they only establish the signatures and do not compute
// anything yet.
impl DaftStddevAggable for DataArray<Float64Type> {
// Both aggregations hand back a fallible array of the same type as `self`.
type Output = DaftResult<Self>;

/// Ungrouped standard deviation.
///
/// # Panics
///
/// Always panics: the body is a `todo!` placeholder.
fn stddev(&self) -> Self::Output {
// Not implemented yet; reaching this line panics with "not yet implemented: stddev".
todo!("stddev")

Check warning on line 10 in src/daft-core/src/array/ops/stddev.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/stddev.rs#L9-L10

Added lines #L9 - L10 were not covered by tests
}

/// Grouped standard deviation. The group indices are currently ignored.
///
/// # Panics
///
/// Always panics: the body is a `todo!` placeholder.
fn grouped_stddev(&self, _: &GroupIndices) -> Self::Output {
// Not implemented yet; reaching this line panics with "not yet implemented: stddev".
todo!("stddev")

Check warning on line 14 in src/daft-core/src/array/ops/stddev.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/stddev.rs#L13-L14

Added lines #L13 - L14 were not covered by tests
}
}
37 changes: 19 additions & 18 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use logical::Decimal128Array;

use crate::{
array::{
ops::{DaftHllMergeAggable, GroupIndices},
ops::{DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, GroupIndices},
ListArray,
},
count_mode::CountMode,
Expand Down Expand Up @@ -149,24 +149,25 @@ impl Series {
}

/// Computes the mean of this series, either over all values (`groups` is `None`) or
/// once per group (`groups` is `Some`).
///
/// The series must have a numeric data type; it is cast to `Float64` and the f64 mean
/// kernel is applied to the cast values.
///
/// # Errors
///
/// Returns a type error for non-numeric inputs, and propagates any error from the cast,
/// the `f64()` downcast, or the kernel itself.
// NOTE(review): this span appears to show both the removed `match`-based body and its
// replacement built on `assert_is_numeric` back to back — confirm which lines are live.
pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*};

// Upcast all numeric types to float64 and use f64 mean kernel.
match self.data_type() {
dt if dt.is_numeric() => {
let casted = self.cast(&Float64)?;
match groups {
Some(groups) => {
Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series())
}
None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()),
}
}
other => Err(DaftError::TypeError(format!(
"Numeric mean is not implemented for type {}",
other
))),
}
// Guard-clause form: reject non-numeric types first, then cast and dispatch on `groups`.
self.data_type().assert_is_numeric()?;
let casted = self.cast(&DataType::Float64)?;
let casted = casted.f64()?;
let series = groups
.map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))?
.into_series();
Ok(series)
}

/// Computes the standard deviation of this series, either over all values (`groups` is
/// `None`) or once per group (`groups` is `Some`).
///
/// The series must have a numeric data type; it is cast to `Float64` and the f64 stddev
/// kernel is applied to the cast values.
///
/// # Errors
///
/// Returns the error from `assert_is_numeric` for non-numeric inputs, and propagates any
/// error from the cast, the `f64()` downcast, or the kernel itself.
pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
    // Reject non-numeric inputs before doing any work.
    self.data_type().assert_is_numeric()?;

    // Every numeric type goes through the same Float64 kernel.
    let as_float = self.cast(&DataType::Float64)?;
    let values = as_float.f64()?;

    let aggregated = match groups {
        Some(indices) => values.grouped_stddev(indices),
        None => values.stddev(),
    }?;
    Ok(aggregated.into_series())
}

pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
Expand Down
31 changes: 15 additions & 16 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,24 +373,24 @@ fn replace_column_with_semantic_id_aggexpr(
AggExpr::Count(ref child, mode) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::Count(transformed_child, mode),
|_| e.clone(),
|_| e,

Check warning on line 376 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L376

Added line #L376 was not covered by tests
)
}
AggExpr::Sum(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Sum, |_| e.clone())
.map_yes_no(AggExpr::Sum, |_| e)

Check warning on line 381 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L381

Added line #L381 was not covered by tests
}
AggExpr::ApproxPercentile(ApproxPercentileParams {
ref child,
ref percentiles,
ref force_list_output,
force_list_output,

Check warning on line 386 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L386

Added line #L386 was not covered by tests
}) => replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(
|transformed_child| {
AggExpr::ApproxPercentile(ApproxPercentileParams {
child: transformed_child,
percentiles: percentiles.clone(),
force_list_output: *force_list_output,
force_list_output,

Check warning on line 393 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L393

Added line #L393 was not covered by tests
})
},
|_| e.clone(),
Expand All @@ -402,45 +402,44 @@ fn replace_column_with_semantic_id_aggexpr(
AggExpr::ApproxSketch(ref child, sketch_type) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::ApproxSketch(transformed_child, sketch_type),
|_| e.clone(),
|_| e,

Check warning on line 405 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L405

Added line #L405 was not covered by tests
)
}
AggExpr::MergeSketch(ref child, sketch_type) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::MergeSketch(transformed_child, sketch_type),
|_| e.clone(),
|_| e,

Check warning on line 411 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L411

Added line #L411 was not covered by tests
)
}
AggExpr::Stddev(ref _child) => {
todo!("stddev")
// replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
// .map_yes_no(AggExpr::Mean, |_| e.clone())
AggExpr::Stddev(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Stddev, |_| e)

Check warning on line 416 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L414-L416

Added lines #L414 - L416 were not covered by tests
}
AggExpr::Mean(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Mean, |_| e.clone())
.map_yes_no(AggExpr::Mean, |_| e)

Check warning on line 420 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L420

Added line #L420 was not covered by tests
}
AggExpr::Min(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Min, |_| e.clone())
.map_yes_no(AggExpr::Min, |_| e)

Check warning on line 424 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L424

Added line #L424 was not covered by tests
}
AggExpr::Max(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Max, |_| e.clone())
.map_yes_no(AggExpr::Max, |_| e)

Check warning on line 428 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L428

Added line #L428 was not covered by tests
}
AggExpr::AnyValue(ref child, ignore_nulls) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls),
|_| e.clone(),
|_| e,

Check warning on line 433 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L433

Added line #L433 was not covered by tests
)
}
AggExpr::List(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::List, |_| e.clone())
.map_yes_no(AggExpr::List, |_| e)

Check warning on line 438 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L438

Added line #L438 was not covered by tests
}
AggExpr::Concat(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Concat, |_| e.clone())
.map_yes_no(AggExpr::Concat, |_| e)

Check warning on line 442 in src/daft-plan/src/logical_ops/project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/project.rs#L442

Added line #L442 was not covered by tests
}
AggExpr::MapGroups { func, inputs } => {
let transforms = inputs
Expand Down
12 changes: 12 additions & 0 deletions src/daft-schema/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,18 @@ impl DataType {
}
}

#[inline]
pub fn assert_is_numeric(&self) -> DaftResult<()> {
if self.is_numeric() {
Ok(())
} else {
Err(DaftError::TypeError(format!(
"Numeric mean is not implemented for type {}",
self,
)))

Check warning on line 378 in src/daft-schema/src/dtype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-schema/src/dtype.rs#L375-L378

Added lines #L375 - L378 were not covered by tests
}
}

#[inline]
pub fn is_fixed_size_numeric(&self) -> bool {
match self {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ impl Table {
}
}
AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups),
AggExpr::Stddev(_expr) => todo!("stddev"),
AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups),

Check warning on line 483 in src/daft-table/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-table/src/lib.rs#L483

Added line #L483 was not covered by tests
AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups),
AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups),
&AggExpr::AnyValue(ref expr, ignore_nulls) => {
Expand Down

0 comments on commit 2ae9d08

Please sign in to comment.