-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: correctly pass cap_scaled
to piecewise_logistic
#142
Conversation
Also add a benchmark & test for logistic predictions.
WalkthroughThe pull request introduces several modifications to the Changes
Possibly related PRs
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (3)
crates/augurs-prophet/src/prophet.rs (1)
Line range hint
142-179
: Consider adding a test for logistic predictions with scaled capacity.While the test coverage is comprehensive, consider adding a specific test case that verifies the correct usage of
cap_scaled
in logistic predictions. This would help prevent future regressions.Example test case:
#[test] fn test_logistic_prediction_with_scaled_capacity() { let mut data = daily_univariate_ts().head(468); let max = data.y.iter().copied().nanmax(true); data = data.with_cap(vec![max; 468]).unwrap(); let opts = ProphetOptions { growth: GrowthType::Logistic, ..ProphetOptions::default() }; let mut prophet = Prophet::new(opts, MockOptimizer::new()); prophet.fit(data.clone(), Default::default()).unwrap(); let predictions = prophet.predict(None).unwrap(); // Assert predictions respect capacity bounds for (pred, cap) in predictions.yhat.point.iter().zip(predictions.cap.iter()) { assert!(*pred <= *cap, "Prediction exceeds capacity"); } }crates/augurs-prophet/benches/prophet-linear.rs (2)
Line range hint
11-35
: Consider documenting the Prophet configuration choicesWhile the implementation is correct, adding comments explaining why specific values were chosen for
yearly_seasonality
,interval_width
, anduncertainty_samples
would improve maintainability.
Line range hint
93-1729
: Document test data characteristicsThe test data arrays would benefit from comments describing:
- The time range and frequency of the timestamps
- The characteristics of the training values (e.g., distribution, patterns)
- Why these specific expected values were chosen
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
- crates/augurs-prophet/Cargo.toml (1 hunks)
- crates/augurs-prophet/benches/prophet-linear.rs (1 hunks)
- crates/augurs-prophet/benches/prophet-logistic.rs (1 hunks)
- crates/augurs-prophet/src/prophet.rs (1 hunks)
- justfile (1 hunks)
🔇 Additional comments (11)
justfile (1)
38-38
: LGTM! Test exclusion pattern correctly updated.The updated pattern properly excludes Prophet benchmarks that require STAN installation while allowing other tests to run. This aligns with the benchmark renaming changes and is well-documented in the comments.
crates/augurs-prophet/Cargo.toml (3)
54-56
: LGTM: Benchmark rename provides better clarity.The rename from "real-life" to "prophet-linear" better describes the benchmark's purpose and distinguishes it from the new logistic benchmark.
58-61
: LGTM: New logistic benchmark aligns with PR objectives.The addition of the "prophet-logistic" benchmark is consistent with the PR's goal of improving logistic predictions. The configuration correctly specifies the required features ("cmdstan", "compile-cmdstan").
Line range hint
63-65
: Verify the download-stan-model binary implementation.The binary configuration looks correct with appropriate feature gate. However, let's verify the implementation exists at the specified path.
✅ Verification successful
Let me check the actual implementation to ensure it's properly handling the download functionality.
Binary implementation is properly implemented and verified
The implementation at
crates/augurs-prophet/src/bin/main.rs
is well-structured and contains all necessary functionality for downloading and extracting the Stan model:
- Downloads Prophet wheel based on platform architecture
- Extracts
prophet_model.bin
andlibtbb-dc01d64d.so.2
files- Handles platform-specific paths and permissions
- Includes proper error handling and cleanup
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify the binary implementation exists and contains download functionality # Test 1: Check if the binary source file exists if ! [ -f "crates/augurs-prophet/src/bin/main.rs" ]; then echo "Binary source file not found at specified path" exit 1 fi # Test 2: Check if the file contains download-related code rg -l "download.*stan.*model" "crates/augurs-prophet/src/bin/main.rs"Length of output: 158
Script:
#!/bin/bash # Let's examine the actual implementation cat crates/augurs-prophet/src/bin/main.rsLength of output: 5879
crates/augurs-prophet/src/prophet.rs (2)
153-153
: LGTM: Using scaled capacity for trend prediction.The change correctly uses
cap_scaled
instead ofcap
for the trend prediction, which is necessary for accurate logistic predictions as the model works with scaled values internally.
Line range hint
142-179
: LGTM: Robust implementation of the predict method.The implementation includes:
- Proper error handling for unfitted models
- Comprehensive pattern matching to ensure all required components are available
- Correct handling of optional prediction data
crates/augurs-prophet/benches/prophet-linear.rs (2)
Line range hint
1-9
: LGTM: Imports are appropriate for benchmarkingThe imports cover all necessary components for benchmarking Prophet's linear predictions, including criterion framework and Prophet-related types.
Line range hint
11-89
: Verify coverage of cap_scaled parameter fixThe benchmarks test linear predictions but don't explicitly verify the
cap_scaled
parameter fix mentioned in the PR objectives. Consider adding a benchmark case that tests logistic predictions with thecap_scaled
parameter.crates/augurs-prophet/benches/prophet-logistic.rs (3)
1-9
: LGTM: Imports are appropriate for benchmarking.The imports include all necessary components for benchmarking Prophet with logistic growth, including criterion for benchmarking and testing utilities.
43-85
: LGTM: Comprehensive benchmark for logistic predictions with proper validation.The benchmark:
- Maintains consistent configuration with the fit benchmark
- Properly sets caps and floors for prediction data
- Validates predictions against expected values outside the timed section
10-41
: LGTM: Benchmark correctly tests logistic growth fitting.The benchmark properly initializes Prophet with logistic growth and sets up training data with caps and floors. The use of a fixed seed ensures reproducible results.
Given the PR objective to fix
cap_scaled
, let's verify the cap handling:✅ Verification successful
Based on the search results, I can now verify the cap handling in the benchmark. The code shows:
- The benchmark correctly sets up caps with
with_cap(vec![3.0; TRAINING_Y.len()])
.- The implementation in
prophet/prep.rs
properly handles cap scaling for logistic growth by:
- Checking if caps are provided when growth type is logistic
- Scaling caps appropriately using
(cap - floor) / scales.y_scale
- Validating that scaled caps are positive
Cap handling in the benchmark is correctly implemented.
The benchmark properly initializes logistic growth with caps, and the underlying implementation correctly handles cap scaling. The PR's objective to fix
cap_scaled
is well-tested by this benchmark since it exercises the full cap handling logic through the fitting process.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check cap handling in Prophet model initialization and fitting # Look for cap-related code in the codebase to ensure consistent handling # Search for cap-related function calls and parameters rg -A 3 "with_cap|cap_scaled|piecewise_logistic.*cap" --type rustLength of output: 18675
fn predict(c: &mut Criterion) { | ||
let mut model = Prophet::new(Default::default(), CmdstanOptimizer::new_embedded()); | ||
let opts = ProphetOptions { | ||
yearly_seasonality: SeasonalityOption::Manual(false), | ||
interval_width: 0.8.try_into().unwrap(), | ||
uncertainty_samples: 500, | ||
..Default::default() | ||
}; | ||
let mut model = Prophet::new(opts, CmdstanOptimizer::new_embedded()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Extract Prophet configuration to reduce duplication
The Prophet configuration is duplicated between fit
and predict
functions. Consider extracting it to a helper function:
+fn create_prophet() -> Prophet<CmdstanOptimizer> {
+ let opts = ProphetOptions {
+ yearly_seasonality: SeasonalityOption::Manual(false),
+ interval_width: 0.8.try_into().unwrap(),
+ uncertainty_samples: 500,
+ ..Default::default()
+ };
+ Prophet::new(opts, CmdstanOptimizer::new_embedded())
+}
fn predict(c: &mut Criterion) {
- let opts = ProphetOptions {
- yearly_seasonality: SeasonalityOption::Manual(false),
- interval_width: 0.8.try_into().unwrap(),
- uncertainty_samples: 500,
- ..Default::default()
- };
- let mut model = Prophet::new(opts, CmdstanOptimizer::new_embedded());
+ let mut model = create_prophet();
Committable suggestion was skipped due to low confidence.
0.41158861163950583, | ||
0.4123324308268751, | ||
0.4129282199822072, | ||
0.41331454202336537, | ||
0.41343401384744843, | ||
0.41323524874374984, | ||
0.4126746503736042, | ||
0.4117180016525681, | ||
0.41034179762671186, | ||
0.40853427877283965, | ||
0.40629612986376584, | ||
0.40364081936431834, | ||
0.4005945649284569, | ||
0.3971959217276611, | ||
0.3934950016157257, | ||
0.38955234226846747, | ||
0.3854374559871378, | ||
0.3812270977048996, | ||
0.3770033002746346, | ||
0.3728512324229337, | ||
0.3688569403584598, | ||
0.3651050377992335, | ||
0.36167641116084315, | ||
0.3586460065410707, | ||
0.35608076310484416, | ||
0.35403775345971444, | ||
0.3525625858886795, | ||
0.35168811580573234, | ||
0.3514335049845033, | ||
0.3518036570608792, | ||
0.3527890469088338, | ||
0.35436595006051586, | ||
0.35649706667263836, | ||
0.35913252302124, | ||
0.35913252302124, | ||
0.3622112225040347, | ||
0.3622112225040347, | ||
0.36566250786077314, | ||
0.36566250786077314, | ||
0.3694080872598714, | ||
0.3694080872598714, | ||
0.3733641691429732, | ||
0.3733641691429732, | ||
0.3774437446207155, | ||
0.3774437446207155, | ||
0.3815589519416558, | ||
0.3856234551324679, | ||
0.38955476856013327, | ||
0.39327646077341105, | ||
0.3967201746410105, | ||
0.39982740615215323, | ||
0.40255099148752205, | ||
0.40485626053192486, | ||
0.4067218248489406, | ||
0.4081399789386908, | ||
0.40911670495494357, | ||
0.4096712827052326, | ||
0.40983551833515813, | ||
0.40965261624937105, | ||
0.4091757291971898, | ||
0.40846623079989636, | ||
0.40759176277938036, | ||
0.40662411557772543, | ||
0.4056370057127241, | ||
0.40470381597760274, | ||
0.4038953653927151, | ||
0.4032777746107686, | ||
0.4029104893180027, | ||
0.402844519190242, | ||
0.4031209432380617, | ||
0.4037697242046743, | ||
0.4048088652377497, | ||
0.4062439316661158, | ||
0.4080679497001019, | ||
0.41026168252964984, | ||
0.4127942729847681, | ||
0.41562423103409984, | ||
0.418700734138437, | ||
0.42196519926670895, | ||
0.42535307748265727, | ||
0.42879581552727936, | ||
0.4322229241735841, | ||
0.43556409022433523, | ||
0.4387512680668091, | ||
0.44172068768907424, | ||
0.44441471899091817, | ||
0.44678353687887584, | ||
0.4487865380231536, | ||
0.4503934680040993, | ||
0.4515852265693976, | ||
0.4523543287683867, | ||
0.45270501030533156, | ||
0.4526529763853887, | ||
0.4522248042327043, | ||
0.45145701997992993, | ||
0.45145701997992993, | ||
]; | ||
|
||
criterion_group!(benches, fit, predict); | ||
criterion_main!(benches); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider moving test data to separate files.
The static test data is well-structured but makes the benchmark file quite large. Consider:
- Moving the data to separate files (e.g., JSON or CSV)
- Loading test data at runtime for better maintainability
Example implementation:
use std::fs::File;
use serde::Deserialize;
#[derive(Deserialize)]
struct TestData {
training_ds: Vec<TimestampSeconds>,
training_y: Vec<f64>,
prediction_ds: Vec<TimestampSeconds>,
expected: Vec<f64>,
}
fn load_test_data() -> TestData {
let file = File::open("test_data/logistic_benchmark.json").unwrap();
serde_json::from_reader(file).unwrap()
}
Also add a benchmark & test for logistic predictions.
Summary by CodeRabbit
Release Notes
New Features
Improvements
Bug Fixes
Chores