
fix: correctly pass cap_scaled to piecewise_logistic #142

Merged: 1 commit merged into main from fix-logistic-predictions on Oct 25, 2024

Conversation

sd2k
Collaborator

@sd2k sd2k commented Oct 25, 2024

Also add a benchmark & test for logistic predictions.

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced a new benchmark for the Prophet model called "prophet-logistic".
    • Added a binary for downloading Stan models.
  • Improvements

    • Enhanced configurability of the Prophet model with specific initialization options.
    • Improved error handling in the prediction process and refined future date generation logic.
  • Bug Fixes

    • Corrected the use of capacity data in the prediction method.
  • Chores

    • Updated test command configurations to include/exclude specific binaries.

@sd2k sd2k enabled auto-merge (squash) October 25, 2024 12:24
Contributor

coderabbitai bot commented Oct 25, 2024

Walkthrough

The pull request introduces several modifications to the augurs-prophet package. Key changes include the renaming of an existing benchmark and the addition of a new benchmark in the Cargo.toml file, along with the introduction of a binary for downloading Stan models. The benchmarking code has been updated to enhance the configurability of the Prophet model, and new benchmark functions have been added. Additionally, the predict method in the Prophet struct has been modified to improve data handling and error management.

Changes

File Path — Change Summary
crates/augurs-prophet/Cargo.toml — Renamed benchmark from "real-life" to "prophet-linear", added the new benchmark "prophet-logistic", and specified the binary "download-stan-model" with its path.
crates/augurs-prophet/benches/prophet-linear.rs — Updated the predict function to initialize the Prophet model with specific ProphetOptions instead of default parameters.
crates/augurs-prophet/benches/prophet-logistic.rs — Added new fit and predict benchmark functions for the Prophet model, including static datasets for training and prediction.
crates/augurs-prophet/src/prophet.rs — Modified the predict method to use &df.cap_scaled and updated make_future_dataframe for better date handling and error management.
justfile — Updated the test-all target command to change exclusion patterns for cargo nextest run.

🐇 In the meadow, we hop and play,
New benchmarks bloom, bright as day.
Prophet's model, now refined,
With options set, predictions aligned.
From "real-life" to "linear" we cheer,
In our code garden, progress is clear! 🌼


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (3)
crates/augurs-prophet/src/prophet.rs (1)

Line range hint 142-179: Consider adding a test for logistic predictions with scaled capacity.

While the test coverage is comprehensive, consider adding a specific test case that verifies the correct usage of cap_scaled in logistic predictions. This would help prevent future regressions.

Example test case:

#[test]
fn test_logistic_prediction_with_scaled_capacity() {
    let mut data = daily_univariate_ts().head(468);
    let max = data.y.iter().copied().nanmax(true);
    data = data.with_cap(vec![max; 468]).unwrap();
    
    let opts = ProphetOptions {
        growth: GrowthType::Logistic,
        ..ProphetOptions::default()
    };
    let mut prophet = Prophet::new(opts, MockOptimizer::new());
    prophet.fit(data.clone(), Default::default()).unwrap();
    let predictions = prophet.predict(None).unwrap();
    
    // Assert predictions respect capacity bounds
    for (pred, cap) in predictions.yhat.point.iter().zip(predictions.cap.iter()) {
        assert!(*pred <= *cap, "Prediction exceeds capacity");
    }
}
crates/augurs-prophet/benches/prophet-linear.rs (2)

Line range hint 11-35: Consider documenting the Prophet configuration choices

While the implementation is correct, adding comments explaining why specific values were chosen for yearly_seasonality, interval_width, and uncertainty_samples would improve maintainability.


Line range hint 93-1729: Document test data characteristics

The test data arrays would benefit from comments describing:

  • The time range and frequency of the timestamps
  • The characteristics of the training values (e.g., distribution, patterns)
  • Why these specific expected values were chosen
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 53ae827 and ffe4ab7.

📒 Files selected for processing (5)
  • crates/augurs-prophet/Cargo.toml (1 hunks)
  • crates/augurs-prophet/benches/prophet-linear.rs (1 hunks)
  • crates/augurs-prophet/benches/prophet-logistic.rs (1 hunks)
  • crates/augurs-prophet/src/prophet.rs (1 hunks)
  • justfile (1 hunks)
🔇 Additional comments (11)
justfile (1)

38-38: LGTM! Test exclusion pattern correctly updated.

The updated pattern properly excludes Prophet benchmarks that require STAN installation while allowing other tests to run. This aligns with the benchmark renaming changes and is well-documented in the comments.

crates/augurs-prophet/Cargo.toml (3)

54-56: LGTM: Benchmark rename provides better clarity.

The rename from "real-life" to "prophet-linear" better describes the benchmark's purpose and distinguishes it from the new logistic benchmark.


58-61: LGTM: New logistic benchmark aligns with PR objectives.

The addition of the "prophet-logistic" benchmark is consistent with the PR's goal of improving logistic predictions. The configuration correctly specifies the required features ("cmdstan", "compile-cmdstan").


Line range hint 63-65: Verify the download-stan-model binary implementation.

The binary configuration looks correct with appropriate feature gate. However, let's verify the implementation exists at the specified path.

✅ Verification successful

Let me check the actual implementation to ensure it's properly handling the download functionality.


Binary implementation is properly implemented and verified

The implementation at crates/augurs-prophet/src/bin/main.rs is well-structured and contains all necessary functionality for downloading and extracting the Stan model:

  • Downloads Prophet wheel based on platform architecture
  • Extracts prophet_model.bin and libtbb-dc01d64d.so.2 files
  • Handles platform-specific paths and permissions
  • Includes proper error handling and cleanup
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the binary implementation exists and contains download functionality

# Test 1: Check if the binary source file exists
if ! [ -f "crates/augurs-prophet/src/bin/main.rs" ]; then
    echo "Binary source file not found at specified path"
    exit 1
fi

# Test 2: Check if the file contains download-related code
rg -l "download.*stan.*model" "crates/augurs-prophet/src/bin/main.rs"

Length of output: 158


Script:

#!/bin/bash
# Let's examine the actual implementation
cat crates/augurs-prophet/src/bin/main.rs

Length of output: 5879

crates/augurs-prophet/src/prophet.rs (2)

153-153: LGTM: Using scaled capacity for trend prediction.

The change correctly uses cap_scaled instead of cap for the trend prediction, which is necessary for accurate logistic predictions as the model works with scaled values internally.
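To make the distinction concrete, here is a minimal, self-contained sketch of why the logistic trend must receive the scaled capacity. This is illustrative only and not the augurs-prophet source: the function name, constants, and the simple `cap / y_scale` scaling are assumptions; the crate's actual `piecewise_logistic` additionally handles changepoints and offsets.

```rust
// Illustrative sketch only: augurs-prophet's actual piecewise_logistic is more
// involved (changepoints, offsets); this isolates the scaling issue.
fn logistic_trend(cap: f64, k: f64, m: f64, t: f64) -> f64 {
    // Standard logistic curve saturating at `cap`.
    cap / (1.0 + (-k * (t - m)).exp())
}

fn main() {
    let y_scale = 10.0; // Prophet divides y by a scale factor before fitting
    let cap = 8.0; // capacity in original units
    let cap_scaled = cap / y_scale; // capacity on the model's internal scale

    // Correct: trend computed on the scaled axis, then rescaled back.
    let pred = logistic_trend(cap_scaled, 1.0, 0.0, 2.0) * y_scale;
    assert!(pred <= cap);

    // Incorrect: passing the unscaled cap inflates the trend by y_scale,
    // producing predictions that blow through the capacity bound.
    let wrong = logistic_trend(cap, 1.0, 0.0, 2.0) * y_scale;
    assert!(wrong > cap);
}
```

Running this, the correctly scaled prediction stays at or below the capacity, while the unscaled variant exceeds it by roughly the scale factor, which is the class of bug this PR fixes.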


Line range hint 142-179: LGTM: Robust implementation of the predict method.

The implementation includes:

  • Proper error handling for unfitted models
  • Comprehensive pattern matching to ensure all required components are available
  • Correct handling of optional prediction data
crates/augurs-prophet/benches/prophet-linear.rs (2)

Line range hint 1-9: LGTM: Imports are appropriate for benchmarking

The imports cover all necessary components for benchmarking Prophet's linear predictions, including criterion framework and Prophet-related types.


Line range hint 11-89: Verify coverage of cap_scaled parameter fix

The benchmarks test linear predictions but don't explicitly verify the cap_scaled parameter fix mentioned in the PR objectives. Consider adding a benchmark case that tests logistic predictions with the cap_scaled parameter.

crates/augurs-prophet/benches/prophet-logistic.rs (3)

1-9: LGTM: Imports are appropriate for benchmarking.

The imports include all necessary components for benchmarking Prophet with logistic growth, including criterion for benchmarking and testing utilities.


43-85: LGTM: Comprehensive benchmark for logistic predictions with proper validation.

The benchmark:

  • Maintains consistent configuration with the fit benchmark
  • Properly sets caps and floors for prediction data
  • Validates predictions against expected values outside the timed section

10-41: LGTM: Benchmark correctly tests logistic growth fitting.

The benchmark properly initializes Prophet with logistic growth and sets up training data with caps and floors. The use of a fixed seed ensures reproducible results.

Given the PR objective to fix cap_scaled, let's verify the cap handling:

✅ Verification successful

Based on the search results, I can now verify the cap handling in the benchmark. The code shows:

  1. The benchmark correctly sets up caps with with_cap(vec![3.0; TRAINING_Y.len()]).
  2. The implementation in prophet/prep.rs properly handles cap scaling for logistic growth by:
    • Checking if caps are provided when growth type is logistic
    • Scaling caps appropriately using (cap - floor) / scales.y_scale
    • Validating that scaled caps are positive

Cap handling in the benchmark is correctly implemented.

The benchmark properly initializes logistic growth with caps, and the underlying implementation correctly handles cap scaling. The PR's objective to fix cap_scaled is well-tested by this benchmark since it exercises the full cap handling logic through the fitting process.
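The scaling rule described above — `(cap - floor) / scales.y_scale`, with validation that the result is positive — can be sketched as a small standalone function. This is an assumption-based illustration, not the actual code in prep.rs; the function name and error type are invented for the example.

```rust
// Illustrative only: scaled cap = (cap - floor) / y_scale, which must be
// positive for logistic growth to be well-defined.
fn scale_caps(caps: &[f64], floors: &[f64], y_scale: f64) -> Result<Vec<f64>, String> {
    caps.iter()
        .zip(floors)
        .map(|(cap, floor)| {
            let scaled = (cap - floor) / y_scale;
            if scaled <= 0.0 {
                Err(format!("cap ({cap}) must exceed floor ({floor})"))
            } else {
                Ok(scaled)
            }
        })
        .collect()
}

fn main() {
    // Mirrors the benchmark's with_cap(vec![3.0; ...]) with a zero floor.
    let scaled = scale_caps(&[3.0, 3.0], &[0.0, 0.0], 1.5).unwrap();
    assert_eq!(scaled, vec![2.0, 2.0]);

    // A cap at or below the floor is rejected.
    assert!(scale_caps(&[1.0], &[2.0], 1.5).is_err());
}
```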

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check cap handling in Prophet model initialization and fitting
# Look for cap-related code in the codebase to ensure consistent handling

# Search for cap-related function calls and parameters
rg -A 3 "with_cap|cap_scaled|piecewise_logistic.*cap" --type rust

Length of output: 18675

Comment on lines 37 to +44
 fn predict(c: &mut Criterion) {
-    let mut model = Prophet::new(Default::default(), CmdstanOptimizer::new_embedded());
+    let opts = ProphetOptions {
+        yearly_seasonality: SeasonalityOption::Manual(false),
+        interval_width: 0.8.try_into().unwrap(),
+        uncertainty_samples: 500,
+        ..Default::default()
+    };
+    let mut model = Prophet::new(opts, CmdstanOptimizer::new_embedded());
🛠️ Refactor suggestion

Extract Prophet configuration to reduce duplication

The Prophet configuration is duplicated between fit and predict functions. Consider extracting it to a helper function:

+fn create_prophet() -> Prophet<CmdstanOptimizer> {
+    let opts = ProphetOptions {
+        yearly_seasonality: SeasonalityOption::Manual(false),
+        interval_width: 0.8.try_into().unwrap(),
+        uncertainty_samples: 500,
+        ..Default::default()
+    };
+    Prophet::new(opts, CmdstanOptimizer::new_embedded())
+}

 fn predict(c: &mut Criterion) {
-    let opts = ProphetOptions {
-        yearly_seasonality: SeasonalityOption::Manual(false),
-        interval_width: 0.8.try_into().unwrap(),
-        uncertainty_samples: 500,
-        ..Default::default()
-    };
-    let mut model = Prophet::new(opts, CmdstanOptimizer::new_embedded());
+    let mut model = create_prophet();

Committable suggestion was skipped due to low confidence.

Comment on lines +87 to +2134
0.41158861163950583,
0.4123324308268751,
// … remaining static EXPECTED values elided for brevity …
0.45145701997992993,
];

criterion_group!(benches, fit, predict);
criterion_main!(benches);
🛠️ Refactor suggestion

Consider moving test data to separate files.

The static test data is well-structured but makes the benchmark file quite large. Consider:

  1. Moving the data to separate files (e.g., JSON or CSV)
  2. Loading test data at runtime for better maintainability

Example implementation:

use std::fs::File;
use serde::Deserialize;

#[derive(Deserialize)]
struct TestData {
    training_ds: Vec<TimestampSeconds>,
    training_y: Vec<f64>,
    prediction_ds: Vec<TimestampSeconds>,
    expected: Vec<f64>,
}

fn load_test_data() -> TestData {
    let file = File::open("test_data/logistic_benchmark.json").unwrap();
    serde_json::from_reader(file).unwrap()
}

@sd2k sd2k merged commit 2db5f63 into main Oct 25, 2024
22 checks passed
@sd2k sd2k deleted the fix-logistic-predictions branch October 25, 2024 12:35
@sd2k sd2k mentioned this pull request Oct 25, 2024