-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chore: add benchmark for Prophet #140
Conversation
Not sure if this will pass in CI?
WalkthroughThe pull request introduces updates to the Changes
Possibly related PRs
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (9)
crates/augurs/tests/integration.rs (9)
1-2
: Enhance file-level documentation.Consider adding more comprehensive documentation that includes:
- Purpose and scope of these integration tests
- Prerequisites for running the tests
- Description of the feature flags and their dependencies
-//! Integration tests for the augurs wrapper crate. +//! Integration tests for the augurs wrapper crate. +//! +//! This file contains integration tests for various features of the augurs crate: +//! - Changepoint detection +//! - Clustering +//! - Dynamic Time Warping (DTW) +//! - Exponential Smoothing (ETS) +//! - Forecasting +//! - MSTL (Multiple Seasonal-Trend decomposition using LOESS) +//! - Outlier detection +//! - Seasonal detection +//! +//! Each test is gated behind a feature flag and requires the corresponding feature +//! to be enabled during testing.
Line range hint
4-24
: Document test data and expected results.The test uses synthetic data but lacks documentation explaining the data pattern and why specific indices are expected as changepoints.
#[cfg(feature = "changepoint")] #[test] fn test_changepoint() { use augurs::changepoint::{ArgpcpDetector, Detector}; + // Create synthetic data with two segments: + // - First segment (indices 0-33): constant value of 1.0 + // - Second segment (indices 34-67): constant value of 2.0 + // Expected changepoints: 0 (start) and 33 (transition point) let data = vec![ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ]; let changepoints = ArgpcpDetector::builder().build().detect_changepoints(&data); - // 1 changepoint, but the start is considered a changepoint too. + // Verify both the start (0) and the transition point (33) are detected assert_eq!(changepoints, vec![0, 33]); }
Line range hint
26-63
: Improve test structure and documentation for clustering tests.The test uses different parameter combinations but lacks clear documentation about the test cases and expected outcomes.
Consider restructuring the test into separate test cases with clear documentation:
#[cfg(feature = "clustering")] -#[test] -fn test_clustering() { +mod clustering_tests { use augurs::{clustering::DbscanClusterer, DistanceMatrix}; - let distance_matrix = vec![ + + /// Creates a sample distance matrix for testing: + /// - Points 0 and 1 are close (distance 1.0) + /// - Point 2 is moderately distant (distance 2.0-3.0) + /// - Point 3 is far from all others (distance 3.0-4.0) + fn create_test_matrix() -> DistanceMatrix { + let matrix = vec![ vec![0.0, 1.0, 2.0, 3.0], vec![1.0, 0.0, 3.0, 3.0], vec![2.0, 3.0, 0.0, 4.0], vec![3.0, 3.0, 4.0, 0.0], - ]; - let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap(); + ]; + DistanceMatrix::try_from_square(matrix).unwrap() + } + + #[test] + fn test_strict_clustering() { + // Test with strict parameters (eps=0.5, min_points=2) + // Expect no clusters due to strict distance threshold + let distance_matrix = create_test_matrix(); + let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix); + assert_eq!(clusters, vec![-1, -1, -1, -1], "All points should be noise"); + } + + #[test] + fn test_moderate_clustering() { + // Test with moderate parameters (eps=1.0, min_points=2) + // Expect points 0 and 1 to form a cluster + let distance_matrix = create_test_matrix(); + let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix); + assert_eq!(clusters, vec![0, 0, -1, -1], "Points 0,1 should form a cluster"); + } - let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![-1, -1, -1, -1]); - - let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, -1, -1]); - - let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![-1, -1, -1, -1]); - - let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); - - let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); - - let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, 0]); + // Add remaining test cases with similar documentation... }
Line range hint
65-95
: Add error handling and improve documentation for DTW tests.The tests cover good cases but could benefit from better documentation and error handling.
#[cfg(feature = "dtw")] #[test] fn test_dtw() { use augurs::dtw::Dtw; use augurs_testing::assert_approx_eq; + // Test DTW distance between two sequences with constant difference of 3.0 let result = Dtw::euclidean().distance(&[0.0, 1.0, 2.0], &[3.0, 4.0, 5.0]); assert_approx_eq!(result, 5.0990195135927845); + + // Test with empty sequences + assert!(Dtw::euclidean().distance(&[], &[]).is_nan()); } #[cfg(feature = "dtw")] #[test] fn test_dtw_distance_matrix() { use augurs::dtw::Dtw; use augurs_testing::assert_all_close; let dtw = Dtw::euclidean(); + // Test with equal length sequences let series: &[&[f64]] = &[&[0.0, 1.0, 2.0], &[3.0, 4.0, 5.0], &[6.0, 7.0, 8.0]]; let dists = dtw.distance_matrix(series); assert_eq!(dists.shape(), (3, 3)); assert_all_close(&dists[0], &[0.0, 5.0990195135927845, 10.392304845413264]); - // Test with different length series. + // Test with varying length sequences to verify DTW's ability to handle unequal lengths let dtw = Dtw::euclidean(); let series: &[&[f64]] = &[&[0.0, 1.0, 2.0], &[3.0], &[6.0, 7.0]]; let dists = dtw.distance_matrix(series); assert_eq!(dists.shape(), (3, 3)); assert_all_close(&dists[0], &[0.0, 3.7416573867739413, 9.273618495495704]); + + // Test with empty sequences + let series: &[&[f64]] = &[&[], &[]]; + let dists = dtw.distance_matrix(series); + assert_eq!(dists.shape(), (2, 2)); }
Line range hint
97-121
: Enhance error handling and test coverage for ETS model.The test verifies basic functionality but could benefit from better error handling and edge cases.
#[cfg(feature = "ets")] #[test] fn test_ets() { use augurs::{ ets::{ model::{ErrorComponent, SeasonalComponent, TrendComponent}, AutoETS, }, prelude::*, }; use augurs_testing::{assert_within_pct, data::AIR_PASSENGERS}; + // Test with valid data let auto = AutoETS::non_seasonal(); - let fit = auto.fit(AIR_PASSENGERS).expect("fit failed"); + let fit = auto.fit(AIR_PASSENGERS).unwrap_or_else(|e| { + panic!("Failed to fit model: {}", e); + }); + + // Verify model components assert_eq!( fit.model().model_type().error, ErrorComponent::Multiplicative ); assert_eq!(fit.model().model_type().trend, TrendComponent::Additive); assert_eq!(fit.model().model_type().season, SeasonalComponent::None); assert_within_pct!(fit.model().log_likelihood(), -831.4883541595792, 0.01); assert_within_pct!(fit.model().aic(), 1672.9767083191584, 0.01); + + // Test with edge cases + assert!(AutoETS::non_seasonal().fit(&[]).is_err(), "Empty data should fail"); + assert!( + AutoETS::non_seasonal().fit(&[f64::NAN]).is_err(), + "NaN data should fail" + ); }
Line range hint
123-166
: Remove debug statement and improve test documentation.The test includes a debug print statement and could benefit from better documentation of the transformation pipeline.
#[cfg(feature = "forecaster")] #[test] fn test_forecaster() { use augurs::{ forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform}, mstl::{MSTLModel, NaiveTrend}, }; use augurs_testing::{assert_all_close, data::AIR_PASSENGERS}; use itertools::{Itertools, MinMaxResult}; + // Calculate data bounds for scaling let MinMaxResult::MinMax(min, max) = AIR_PASSENGERS .iter() .copied() .minmax_by(|a, b| a.partial_cmp(b).unwrap()) else { unreachable!() }; + + // Create transformation pipeline: + // 1. Linear interpolation for missing values + // 2. Min-max scaling to [0,1] range + // 3. Logit transform for bounded forecasting let transforms = vec![ Transform::linear_interpolator(), Transform::min_max_scaler(MinMaxScaleParams::new(min - 1e-3, max + 1e-3)), Transform::logit(), ]; + + // Create and fit the forecasting model let model = MSTLModel::new(vec![2], NaiveTrend::new()); let mut forecaster = Forecaster::new(model).with_transforms(transforms); forecaster.fit(AIR_PASSENGERS).unwrap(); + + // Generate and verify forecasts let forecasts = forecaster.predict(4, None).unwrap(); - dbg!(&forecasts.point); assert_all_close( &forecasts.point, &[ 559.0587706145459, 432.00000550710956, 559.0587706145459, 432.00000550710956, ], ); + + // Test error handling + assert!(forecaster.predict(0, None).is_err(), "Zero horizon should fail"); }
Line range hint
168-238
: Improve test structure and add validation for confidence intervals.The test covers comprehensive functionality but could benefit from better structure and additional validations.
#[cfg(feature = "mstl")] -#[test] -fn test_mstl() { +mod mstl_tests { use augurs::{ mstl::{stlrs, MSTLModel, NaiveTrend}, prelude::*, }; use augurs_testing::{assert_all_close, data::VIC_ELEC}; - let mut stl_params = stlrs::params(); - stl_params - .seasonal_degree(0) - .seasonal_jump(1) - .trend_degree(1) - .trend_jump(1) - .low_pass_degree(1) - .inner_loops(2) - .outer_loops(0); - let mut mstl_params = stlrs::MstlParams::new(); - mstl_params.stl_params(stl_params); - let periods = vec![24, 24 * 7]; - let trend_model = NaiveTrend::new(); - let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); - let fit = mstl.fit(&VIC_ELEC).unwrap(); + fn create_test_model() -> MSTLModel { + let mut stl_params = stlrs::params(); + stl_params + .seasonal_degree(0) + .seasonal_jump(1) + .trend_degree(1) + .trend_jump(1) + .low_pass_degree(1) + .inner_loops(2) + .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); + let periods = vec![24, 24 * 7]; // Daily and weekly seasonality + let trend_model = NaiveTrend::new(); + MSTLModel::new(periods, trend_model).mstl_params(mstl_params) + } + #[test] + fn test_in_sample_predictions() { + let mstl = create_test_model(); + let fit = mstl.fit(&VIC_ELEC).unwrap(); + let in_sample = fit.predict_in_sample(0.95).unwrap(); + + // Verify predictions length + assert_eq!(in_sample.point.len(), VIC_ELEC.len()); + + // Verify first 12 values against R implementation + let expected_in_sample = vec![ + f64::NAN, + 7952.216, + 7269.439, + 6878.110, + 6606.999, + 6402.581, + 6659.523, + 7457.488, + 8111.359, + 8693.762, + 9255.807, + 9870.213, + ]; + assert_all_close(&in_sample.point[..12], &expected_in_sample); + } + #[test] + fn test_out_of_sample_predictions() { + let mstl = create_test_model(); + let fit = mstl.fit(&VIC_ELEC).unwrap(); + let out_of_sample = fit.predict(10, 0.95).unwrap(); + + // Verify predictions + let expected_out_of_sample = vec![ + 8920.670, 8874.234, 8215.508, 7782.726, 7697.259, + 8216.241, 9664.907, 10914.452, 11536.929, 11664.737, + ]; + assert_all_close(&out_of_sample.point, &expected_out_of_sample); + + // Verify confidence intervals + let ForecastIntervals { lower, upper, .. } = out_of_sample.intervals.unwrap(); + assert_eq!(lower.len(), 10); + assert_eq!(upper.len(), 10); + + let expected_lower = vec![ + 8700.984, 8563.551, 7835.001, 7343.354, 7206.026, + 7678.122, 9083.672, 10293.087, 10877.871, 10970.029, + ]; + let expected_upper = vec![ + 9140.356, 9184.917, 8596.016, 8222.098, 8188.491, + 8754.359, 10246.141, 11535.818, 12195.987, 12359.445, + ]; + assert_all_close(&lower, &expected_lower); + assert_all_close(&upper, &expected_upper); + + // Verify intervals are properly ordered + for (l, u) in lower.iter().zip(upper.iter()) { + assert!(l < u, "Lower bound should be less than upper bound"); + } + } }
Line range hint
240-297
: Improve test data documentation and add edge cases.The outlier detection tests could benefit from better documentation and additional test cases.
#[cfg(feature = "outlier")] +mod outlier_tests { + use augurs::outlier::{DbscanDetector, MADDetector, OutlierDetector}; + + /// Creates test data with known outliers: + /// - First two series are similar (normal) + /// - Third series has outliers in its last two points + fn create_test_data() -> Vec<Vec<f64>> { + vec![ + vec![1.0, 2.0, 1.5, 2.3], + vec![1.9, 2.2, 1.2, 2.4], + vec![1.5, 2.1, 6.4, 8.5], // Contains outliers + ] + } + #[test] fn test_outlier_dbscan() { - use augurs::outlier::{DbscanDetector, OutlierDetector}; - let data: &[&[f64]] = &[ - &[1.0, 2.0, 1.5, 2.3], - &[1.9, 2.2, 1.2, 2.4], - &[1.5, 2.1, 6.4, 8.5], - ]; + let data = create_test_data(); + let data_refs: Vec<&[f64]> = data.iter().map(|v| v.as_slice()).collect(); + let detector = DbscanDetector::with_sensitivity(0.5).expect("sensitivity is between 0.0 and 1.0"); - let processed = detector.preprocess(data).unwrap(); + let processed = detector.preprocess(&data_refs).unwrap(); let outliers = detector.detect(&processed).unwrap(); + // Verify outlier detection assert_eq!(outliers.outlying_series.len(), 1); assert!(outliers.outlying_series.contains(&2)); assert!(outliers.series_results[2].is_outlier); assert_eq!(outliers.series_results[2].scores, vec![0.0, 0.0, 1.0, 1.0]); assert!(outliers.cluster_band.is_some()); + + // Test edge cases + let empty: Vec<&[f64]> = vec![]; + assert!(detector.preprocess(&empty).is_err(), "Empty data should fail"); } #[test] fn test_outlier_mad() { - use augurs::outlier::{MADDetector, OutlierDetector}; - let data: &[&[f64]] = &[ - &[1.0, 2.0, 1.5, 2.3], - &[1.9, 2.2, 1.2, 2.4], - &[1.5, 2.1, 6.4, 8.5], - ]; + let data = create_test_data(); + let data_refs: Vec<&[f64]> = data.iter().map(|v| v.as_slice()).collect(); + let detector = MADDetector::with_sensitivity(0.5).unwrap(); - let processed = detector.preprocess(data).unwrap(); + let processed = detector.preprocess(&data_refs).unwrap(); let outliers = detector.detect(&processed).unwrap(); + // Verify outlier detection assert_eq!(outliers.outlying_series.len(), 1); assert!(outliers.outlying_series.contains(&2)); assert!(outliers.series_results[2].is_outlier); assert_eq!( outliers.series_results[2].scores, vec![ 0.6835259767082061, 0.057793242408848366, 5.028012089569781, 7.4553282707414 ] ); assert!(outliers.cluster_band.is_some()); + + // Test with invalid sensitivity + assert!(MADDetector::with_sensitivity(-1.0).is_err()); + assert!(MADDetector::with_sensitivity(1.5).is_err()); } }
Line range hint
299-332
: Enhance seasonal detection test with better documentation and assertions.The test uses synthetic data but lacks documentation about the data pattern and could include more comprehensive assertions.
#[cfg(feature = "seasons")] #[test] fn test_seasonal() { use augurs::seasons::{Detector, PeriodogramDetector}; + // Create synthetic data with a period of 4: + // - Pattern repeats every 4 points + // - Contains 8 complete cycles #[rustfmt::skip] - let y = &[ - 0.1, 0.3, 0.8, 0.5, - 0.1, 0.31, 0.79, 0.48, - 0.09, 0.29, 0.81, 0.49, - 0.11, 0.28, 0.78, 0.53, - 0.1, 0.3, 0.8, 0.5, - 0.1, 0.31, 0.79, 0.48, - 0.09, 0.29, 0.81, 0.49, - 0.11, 0.28, 0.78, 0.53, - ]; + let y = &[ + 0.1, 0.3, 0.8, 0.5, // Cycle 1 + 0.1, 0.31, 0.79, 0.48, // Cycle 2 + 0.09, 0.29, 0.81, 0.49, // Cycle 3 + 0.11, 0.28, 0.78, 0.53, // Cycle 4 + 0.1, 0.3, 0.8, 0.5, // Cycle 5 + 0.1, 0.31, 0.79, 0.48, // Cycle 6 + 0.09, 0.29, 0.81, 0.49, // Cycle 7 + 0.11, 0.28, 0.78, 0.53, // Cycle 8 + ]; + + // Detect seasonal periods let periods = PeriodogramDetector::default().detect(y); - assert_eq!(periods[0], 4); + + // Verify detected periods + assert!(!periods.is_empty(), "Should detect at least one period"); + assert_eq!(periods[0], 4, "Primary period should be 4"); + + // Test edge cases + let empty: Vec<f64> = vec![]; + assert!(PeriodogramDetector::default().detect(&empty).is_empty()); + + // Test with non-seasonal data + let random = vec![1.0, 2.0, 3.0, 4.0]; + assert!(PeriodogramDetector::default().detect(&random).is_empty()); }
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- crates/augurs-prophet/Cargo.toml (1 hunks)
- crates/augurs/tests/integration.rs (1 hunks)
🔇 Additional comments (1)
crates/augurs-prophet/Cargo.toml (1)
32-32
: LGTM! The addition of tracing-subscriber is well-placed.The addition of
tracing-subscriber
as a development dependency is appropriate for supporting the new integration tests, and correctly uses workspace-level version management.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (2)
crates/augurs-prophet/Cargo.toml (2)
31-31
: Consider limiting tracing-subscriber featuresWhile the dev-dependencies are appropriate, consider specifying only the required features for tracing-subscriber instead of using
default-features = true
. This can help reduce compilation time during development.-tracing-subscriber = { workspace = true, default-features = true } +tracing-subscriber = { workspace = true, default-features = false, features = ["fmt"] }Also applies to: 33-33
53-56
: Consider splitting benchmarks into a separate PRWhile the benchmarks are valuable, they seem tangential to the main PR objective of adding integration tests. Consider:
- Splitting the benchmark addition into a separate PR to maintain focused changes
- This would simplify CI configuration and reduce the risk of CI failures
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
- crates/augurs-prophet/Cargo.toml (2 hunks)
- crates/augurs-prophet/benches/real-life.rs (1 hunks)
- crates/augurs-prophet/src/prophet.rs (1 hunks)
- crates/augurs-prophet/src/prophet/predict.rs (5 hunks)
- crates/augurs-prophet/src/prophet/prep.rs (1 hunks)
🔇 Additional comments (9)
crates/augurs-prophet/Cargo.toml (1)
50-51
: LGTM: Appropriate library configurationDisabling the default benchmark harness is correct since we're using criterion for custom benchmarks.
crates/augurs-prophet/src/prophet.rs (1)
21-21
: Consider performance implications of cloning.While adding
Clone
is necessary for integration testing, be mindful that cloning a trained Prophet model with large datasets could be memory-intensive due to deep copying of all internal buffers (training data, parameters, etc.).Let's check the size of the struct's fields:
Consider implementing a more memory-efficient approach if cloning becomes a bottleneck:
- Use reference counting (
Arc
) for sharing large immutable data- Implement a custom
Clone
that only copies necessary fields- Add benchmarks to measure cloning overhead
crates/augurs-prophet/src/prophet/prep.rs (1)
31-31
: LGTM: Adding Clone trait is appropriate.The addition of the
Clone
trait to theModes
struct is well-justified as both its fields (HashSet<ComponentName>
) already implementClone
. This change aligns with similar modifications in the codebase and supports the integration testing objectives.crates/augurs-prophet/benches/real-life.rs (3)
1-8
: LGTM: Appropriate imports for benchmarking ProphetThe imports cover all necessary components for benchmarking Prophet's performance, including criterion for benchmarking, Prophet types, and testing utilities.
10-35
: LGTM: Well-structured benchmark for model fittingThe benchmark is well-implemented with:
- Appropriate Prophet options for testing
- Batched benchmarking for accurate measurements
- Fixed seed (100) for reproducible results
- Proper cloning of model and training data between iterations
2110-2111
: LGTM: Proper criterion benchmark setupThe benchmark registration using
criterion_group!
andcriterion_main!
follows best practices for criterion benchmarks.crates/augurs-prophet/src/prophet/predict.rs (3)
412-417
: Efficient use of mutable buffers to reduce allocationsIntroducing mutable buffers
yhat
andtrend
to reuse in each iteration reduces memory allocations and can improve performance in thesample_posterior_predictive
method.
456-460
: Clear buffers before reuse insample_model
Clearing
yhat_tmp
andtrend_tmp
at the start ofsample_model
ensures that old data does not persist between iterations. This practice is essential for maintaining data integrity when reusing buffers.
484-487
: Verify alignment and lengths inyhat
computationEnsure that the iterators
trend_tmp
,xb_a
,xb_m
, andnoise
are of the same length and correctly aligned. Misalignment could lead to incorrect calculations or runtime panics due to out-of-bounds access.Consider adding assertions to confirm that all vectors have the same length:
assert_eq!(trend_tmp.len(), xb_a.len()); assert_eq!(xb_a.len(), xb_m.len()); assert_eq!(xb_m.len(), noise.len());This verification ensures that the
izip!
macro will not panic and that each element corresponds correctly during the computation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (5)
crates/augurs-ets/benches/air_passengers_iai.rs (2)
Line range hint
13-15
: Consider documenting the AutoETS parameters.While the implementation is correct, it would be helpful to document what the parameters
1
and"ZZN"
represent in theAutoETS::new(1, "ZZN")
call. This would make the benchmark more maintainable and easier to understand.fn auto_fit() { + // 1: seasonal period + // "ZZN": ZZN specification (error/trend/seasonal components) AutoETS::new(1, "ZZN").unwrap().fit(black_box(AP)).unwrap(); }
Line range hint
28-45
: Consider alternative approaches for prediction benchmarking.While
iai
doesn't support benchmark setup, you could consider using Criterion.rs for the prediction benchmarks, which does support setup through itsBencher
API. This would allow you to measure prediction performance while keeping theiai
benchmarks for the fitting operations.Example implementation with Criterion:
use criterion::{criterion_group, criterion_main, Criterion}; fn predict_benchmark(c: &mut Criterion) { let model = Unfit::new(ModelType { error: ErrorComponent::Additive, trend: TrendComponent::Additive, season: None, }) .damped(true) .fit(AP) .unwrap(); c.bench_function("predict", |b| { b.iter(|| model.predict(24, 0.95)) }); }crates/augurs-ets/benches/air_passengers.rs (1)
Line range hint
66-70
: Consider optimizing the profiler configuration.The current profiler configuration uses a high sampling rate (10000) which might be excessive for these benchmarks. Consider:
- Reducing the sampling rate for quicker benchmark runs
- Adding explicit warmup and measurement time configurations for more stable results
criterion_group! { name = benches; - config = Criterion::default().with_profiler(PProfProfiler::new(10000, Output::Protobuf)); + config = Criterion::default() + .with_profiler(PProfProfiler::new(1000, Output::Protobuf)) + .warm_up_time(std::time::Duration::from_secs(1)) + .measurement_time(std::time::Duration::from_secs(5)); targets = auto_fit, fit, forecast, }crates/augurs-dtw/benches/dtw.rs (2)
Line range hint
8-22
: Consider improving error handling and using a CSV parser.The current implementation has several potential issues:
- Uses
unwrap()
which could panic on malformed data- Manual CSV parsing is error-prone
- Pre-allocates vectors without knowing valid data count
Consider these improvements:
-fn examples() -> Vec<Vec<f64>> { +fn examples() -> Result<Vec<Vec<f64>>, Box<dyn std::error::Error>> { let raw = include_str!("../data/series.csv"); - let n_columns = raw.lines().next().unwrap().split(',').count(); - let n_rows = raw.lines().count(); - let mut examples = vec![Vec::with_capacity(n_rows); n_columns]; - for line in raw.lines() { - for (i, value) in line.split(',').enumerate() { - let value: f64 = value.parse().unwrap(); - if !value.is_nan() { - examples[i].push(value); - } - } - } - examples + let mut rdr = csv::Reader::from_reader(raw.as_bytes()); + let mut examples: Vec<Vec<f64>> = Vec::new(); + + for result in rdr.records() { + let record = result?; + for (i, value) in record.iter().enumerate() { + if i >= examples.len() { + examples.push(Vec::new()); + } + if let Ok(value) = value.parse::<f64>() { + if !value.is_nan() { + examples[i].push(value); + } + } + } + } + Ok(examples) }This would require adding the
csv
crate to your dependencies:[dev-dependencies] csv = "1.2"
Line range hint
23-43
: Add documentation for benchmark parameters.While the benchmark is well-structured, it would benefit from documentation explaining:
- The significance of the chosen window sizes
- The expected impact on performance
- The characteristics of the input data being used
Add documentation like this:
fn distance_euclidean(c: &mut Criterion) { let mut group = c.benchmark_group("distance_euclidean"); let examples = examples(); let (s, t) = (&examples[0], &examples[1]); + // Window sizes chosen to demonstrate performance characteristics: + // - None: unrestricted DTW + // - Small windows (2, 5): tight constraints, fastest performance + // - Medium windows (10, 20): balanced accuracy/performance + // - Large window (50): similar to unrestricted for this data let windows = [None, Some(2), Some(5), Some(10), Some(20), Some(50)];
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (7)
- crates/augurs-clustering/benches/dbscan.rs (1 hunks)
- crates/augurs-dtw/benches/dtw.rs (1 hunks)
- crates/augurs-ets/benches/air_passengers.rs (1 hunks)
- crates/augurs-ets/benches/air_passengers_iai.rs (1 hunks)
- crates/augurs-mstl/benches/vic_elec.rs (1 hunks)
- crates/augurs-mstl/benches/vic_elec_iai.rs (1 hunks)
- crates/augurs-seasons/benches/periodogram.rs (1 hunks)
✅ Files skipped from review due to trivial changes (4)
- crates/augurs-clustering/benches/dbscan.rs
- crates/augurs-mstl/benches/vic_elec.rs
- crates/augurs-mstl/benches/vic_elec_iai.rs
- crates/augurs-seasons/benches/periodogram.rs
🔇 Additional comments (7)
crates/augurs-ets/benches/air_passengers_iai.rs (2)
Line range hint
1-11
: LGTM! Imports and attributes are well-organized.The use of
#![allow(missing_docs)]
is appropriate for a benchmark file, and all necessary imports are present.
Line range hint
17-26
: LGTM! Model configuration is well-structured.The explicit model configuration with damped trend provides good coverage for benchmarking different ETS model variants.
crates/augurs-ets/benches/air_passengers.rs (2)
1-2
: LGTM! Suppressing doc warnings is acceptable for benchmark files.The
#![allow(missing_docs)]
attribute is appropriately used here since benchmark files typically don't require extensive documentation.
Line range hint
13-63
: Consider enhancing benchmark coverage and methodology.The benchmarks could be improved in several ways:
- The forecast benchmark uses a fixed horizon of 24. Consider parameterizing this to measure performance across different forecast horizons.
- The
SmallInput
batch size might not be optimal for measuring the fitting performance accurately.- Memory allocations during model fitting aren't being measured separately.
Let's verify the dataset size to determine if
SmallInput
is appropriate:Consider these improvements:
fn forecast(c: &mut Criterion) { let model = Unfit::new(ModelType { error: ErrorComponent::Additive, trend: TrendComponent::Additive, season: None, }) .damped(true) .fit(AP) .unwrap(); let mut group = c.benchmark_group("forecast"); - group.bench_function("air_passengers", |b| { - b.iter(|| { - model.predict(24, 0.95).unwrap(); - }) - }); + for horizon in [12, 24, 48] { + group.bench_function(format!("air_passengers_h{horizon}"), |b| { + b.iter(|| { + model.predict(horizon, 0.95).unwrap(); + }) + }); + } }✅ Verification successful
Based on the output, I can now generate the final response since I have enough information about the dataset size. The AIR_PASSENGERS dataset contains approximately 144 data points (12 years of monthly data from the visible portion), which is relatively small.
The
SmallInput
batch size is appropriate, but forecast benchmarks need enhancementThe current benchmark implementation is mostly sound:
SmallInput
batch size is appropriate given the dataset size (~144 points)- However, the forecast benchmark could be improved by:
- Testing multiple forecast horizons (e.g., 12, 24, 48 periods)
- Including confidence interval computation overhead in measurements
- Adding separate benchmarks for point forecasts vs interval forecasts
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Find and analyze the Air Passengers dataset size rg -A 5 "AIR_PASSENGERS" --type rustLength of output: 5190
crates/augurs-dtw/benches/dtw.rs (3)
Line range hint
1-7
: LGTM: Appropriate imports and setup for benchmarking.The file correctly sets up the necessary imports for criterion benchmarking and DTW functionality.
Line range hint
72-73
: LGTM: Proper criterion benchmark setup.The benchmark configuration is correctly set up using criterion's standard macros.
Line range hint
44-71
: Verify the choice of window sizes for matrix benchmarks.The matrix benchmark uses a more limited set of window sizes (2, 10) compared to the single distance benchmark. This might not provide a complete performance profile.
Let's check if there are any constraints that influenced this decision:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
justfile (1)
38-38
: Consider documenting how to run Prophet benchmarks locally.While it's appropriate to exclude Prophet benchmarks from CI due to STAN dependencies, it would be helpful to add documentation (either in the README or as a comment) explaining how developers can run these benchmarks locally when they have STAN installed.
Example comment to add above the
test-all
target:+# To run Prophet benchmarks locally (requires STAN installation): +# cargo bench --package augurs-prophet --bench real-life +
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- justfile (1 hunks)
🔇 Additional comments (1)
justfile (1)
29-30
: LGTM! Clear explanation of benchmark exclusions.The comment clearly explains why Prophet benchmarks are excluded from the test-all target.
Summary by CodeRabbit
New Features
augurs
crate, including changepoint detection, clustering, dynamic time warping, exponential smoothing, forecasting, MSTL, outlier detection, and seasonal detection.Prophet
model, measuring performance for fitting and predicting.Chores
Cargo.toml
file to includetracing-subscriber
andcriterion
as development dependencies.Prophet
struct to support cloning, enhancing usability.Modes
struct to support cloning as well.test-all
target of thejustfile
to exclude additional benchmarks.