Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Implement standard deviation #3005

Merged
merged 31 commits into from
Oct 8, 2024
Merged

[FEAT] Implement standard deviation #3005

merged 31 commits into from
Oct 8, 2024

Conversation

raunakab
Copy link
Contributor

@raunakab raunakab commented Oct 6, 2024

Overview

  • Add a standard deviation function
    • similar in implementation to how AggExpr::count and AggExpr::Mean work

Notes

Implementations differ slightly for non- vs multi- partitioned based dataframes:

  1. The non-partitioned implementation uses the simple, naive approach, derived from definition of stddev (i.e., stddev(X) = sqrt(sum((x_i - mean(X))^2) / N)).
  2. The multi-partitioned implementation calculates stddev(X) = sqrt(E(X^2) - E(X)^2).

- move all test mods into their own separate files
    - if `blah.rs` had a submodule `tests`, then `blah/mod.rs` would contain the original code and `blah/tests.rs` would contain the tests
- removed all enum imports
- ran `cargo clippy --fix ...`
@github-actions github-actions bot added the enhancement New feature or request label Oct 6, 2024
@raunakab raunakab requested review from colin-ho and jaychia October 6, 2024 18:37
Copy link

codspeed-hq bot commented Oct 6, 2024

CodSpeed Performance Report

Merging #3005 will not alter performance

Comparing feat/stddev (1874f43) with main (3f37a69)

Summary

✅ 17 untouched benchmarks

Copy link

codecov bot commented Oct 6, 2024

Codecov Report

Attention: Patch coverage is 86.98980% with 102 lines in your changes missing coverage. Please review.

Project coverage is 78.49%. Comparing base (3f37a69) to head (1874f43).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
src/daft-dsl/src/expr/mod.rs 79.31% 30 Missing ⚠️
src/daft-dsl/src/lit.rs 67.46% 27 Missing ⚠️
src/daft-plan/src/logical_ops/project.rs 0.00% 15 Missing ⚠️
src/daft-dsl/src/resolve_expr/mod.rs 60.86% 9 Missing ⚠️
src/daft-schema/src/dtype.rs 50.00% 4 Missing ⚠️
src/daft-table/src/lib.rs 20.00% 4 Missing ⚠️
src/daft-dsl/src/arithmetic/tests.rs 70.00% 3 Missing ⚠️
src/daft-dsl/src/expr/tests.rs 95.58% 3 Missing ⚠️
daft/dataframe/dataframe.py 60.00% 2 Missing ⚠️
daft/series.py 33.33% 2 Missing ⚠️
... and 2 more
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #3005      +/-   ##
==========================================
+ Coverage   78.43%   78.49%   +0.05%     
==========================================
  Files         603      609       +6     
  Lines       71504    71693     +189     
==========================================
+ Hits        56086    56273     +187     
- Misses      15418    15420       +2     
Files with missing lines Coverage Δ
daft/expressions/expressions.py 93.78% <100.00%> (+0.02%) ⬆️
src/daft-core/src/array/ops/mean.rs 100.00% <100.00%> (ø)
src/daft-core/src/array/ops/stddev.rs 100.00% <100.00%> (ø)
src/daft-core/src/datatypes/agg_ops.rs 100.00% <100.00%> (ø)
src/daft-core/src/datatypes/mod.rs 40.00% <ø> (ø)
src/daft-core/src/series/ops/agg.rs 75.59% <100.00%> (+3.72%) ⬆️
src/daft-dsl/src/arithmetic/mod.rs 100.00% <ø> (ø)
src/daft-dsl/src/functions/map/mod.rs 92.30% <100.00%> (ø)
src/daft-dsl/src/functions/mod.rs 84.21% <100.00%> (ø)
src/daft-dsl/src/functions/partitioning/mod.rs 100.00% <100.00%> (ø)
... and 23 more

... and 1 file with indirect coverage changes

@raunakab raunakab marked this pull request as ready for review October 8, 2024 15:33
@raunakab raunakab requested review from andrewgazelka and desmondcheongzx and removed request for colin-ho October 8, 2024 16:14
@raunakab raunakab requested a review from jaychia October 8, 2024 16:32
Copy link
Contributor

@colin-ho colin-ho left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @universalmind303, does this work with sql? can i do 'SELECT stddev(value) FROM df GROUP BY group`

also, i may have missed it but could you clarify your top comment about how the implementation differs slightly for local and distributed?

src/daft-plan/src/physical_planner/translate.rs Outdated Show resolved Hide resolved
src/daft-dsl/src/functions/python/mod.rs Show resolved Hide resolved
src/daft-core/src/datatypes/agg_ops.rs Outdated Show resolved Hide resolved
src/daft-dsl/src/expr/mod.rs Outdated Show resolved Hide resolved
src/daft-core/src/array/ops/mean.rs Outdated Show resolved Hide resolved
src/daft-core/src/utils/stats.rs Outdated Show resolved Hide resolved
src/daft-core/src/array/ops/mean.rs Outdated Show resolved Hide resolved
src/daft-core/src/utils/stats.rs Show resolved Hide resolved
daft/dataframe/dataframe.py Show resolved Hide resolved
Raunak Bhagat added 4 commits October 8, 2024 12:22
…rtion on re-insertion of id

- it is possible to have an existing key already in the map; thus shouldn't panic
- keeping the count as a u64 would require casting to f64 in the loop, which leads to poor performance
    - instead store it as an f64 eagerly
Copy link
Contributor

@desmondcheongzx desmondcheongzx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two additional comments:

  1. I think we didn't update the docs?
  2. This is not welford's :P

tests/dataframe/test_stddev.py Outdated Show resolved Hide resolved
@raunakab
Copy link
Contributor Author

raunakab commented Oct 8, 2024

@colin-ho

The implementations slightly differ in their computations.

In the non-partitioned one, I just perform the straight shot stddev. This is essentially calculating the mean, and then calculating sum((x_i - mean)^2 ) / N, and then finally sqrt-ing that.

In the multi-partitioned one, doing the above approach requires a 3 stage agg (or some weird cardinalities being passed along). Therefore, I instead leverage the fact that the stddev formula can be expanded into stddev(X) = sqrt(E(X^2) - E(X)^2). Thus, in that situation, the first stage only requires me to compute the local sq.sum, the sum, and the count. The second stage requires the global version of all of that, and the final stage is a simple projection to calculate the final result using the previous aggs.

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Oct 8, 2024
@raunakab
Copy link
Contributor Author

raunakab commented Oct 8, 2024

@desmondcheongzx Thanks for the docs reminder! Updated docs in latest commit.

Copy link
Contributor

@desmondcheongzx desmondcheongzx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks!

// where X is the sub_expr.
//
// First stage, we compute `sum(X^2)`, `sum(X)` and `count(X)`.
// Second stage, we `global_sqsum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Second stage, we `global_sqsum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage.
// Second stage, we `global_sum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage.

//
// First stage, we compute `sum(X^2)`, `sum(X)` and `count(X)`.
// Second stage, we `global_sqsum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage.
// In the final projection, we then compute `sqrt((global_sqsum / global_count) - (global_sum / global_count) ^ 2)`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// In the final projection, we then compute `sqrt((global_sqsum / global_count) - (global_sum / global_count) ^ 2)`.
// In the final projection, we then compute `sqrt((global_sum / global_count) - (global_sum / global_count) ^ 2)`.

@raunakab raunakab merged commit 64b8699 into main Oct 8, 2024
46 checks passed
@raunakab raunakab deleted the feat/stddev branch October 8, 2024 22:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants