From 9dc5ce596f16a81e16feb0281ce05517e58a0ebf Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 11 Nov 2024 11:45:01 +0000 Subject: [PATCH 1/4] feat: add Forecaster wrapper for Prophet Since the `Forecaster` traits don't pass timestamps etc, we need to provide the training data at the time the forecaster is created and pass _that_ down to the Prophet model, after replacing the `y` column with whatever the forecaster gives us. This should work with transforms too. It does require cloning the data, unfortunately. The alternative approach which I tried first was to use an associated type for the data, but that is pretty infectious and ends up requiring the input data to be mutable so that we can replace the `y` column, which ends up looking pretty ugly. --- crates/augurs-prophet/Cargo.toml | 1 + crates/augurs-prophet/src/forecaster.rs | 197 ++++++++++++++++++++++++ crates/augurs-prophet/src/lib.rs | 1 + crates/augurs-prophet/src/optimizer.rs | 13 +- crates/augurs-prophet/src/prophet.rs | 13 +- 5 files changed, 222 insertions(+), 3 deletions(-) create mode 100644 crates/augurs-prophet/src/forecaster.rs diff --git a/crates/augurs-prophet/Cargo.toml b/crates/augurs-prophet/Cargo.toml index f8159d84..f8a128c9 100644 --- a/crates/augurs-prophet/Cargo.toml +++ b/crates/augurs-prophet/Cargo.toml @@ -28,6 +28,7 @@ include = [ [dependencies] anyhow.workspace = true +augurs-core.workspace = true bytemuck = { workspace = true, features = ["derive"], optional = true } itertools.workspace = true num-traits.workspace = true diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs new file mode 100644 index 00000000..ed27bdf7 --- /dev/null +++ b/crates/augurs-prophet/src/forecaster.rs @@ -0,0 +1,197 @@ +//! [`Fit`] and [`Predict`] implementations for the Prophet algorithm. +use std::{cell::RefCell, num::NonZeroU32, sync::Arc}; + +use augurs_core::{Fit, ModelError, Predict}; + +use crate::{ + optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, ProphetOptions, + TrainingData, +}; + +impl ModelError for Error {} + +/// A forecaster that uses the Prophet algorithm. +/// +/// This is a wrapper around the [`Prophet`] struct that provides +/// a simpler API for fitting and predicting. Notably it implements +/// the [`Fit`] trait from `augurs_core`, so it can be +/// used with the `augurs` framework (e.g. with the `Forecaster` struct +/// in the `augurs::forecaster` module). +#[derive(Debug)] +pub struct ProphetForecaster { + data: TrainingData, + opts: ProphetOptions, + optimizer: Arc, + optimize_opts: OptimizeOpts, +} + +impl ProphetForecaster { + /// Create a new Prophet forecaster. + /// + /// # Parameters + /// + /// - `opts`: The options to use for fitting the model. + /// - `optimizer`: The optimizer to use for fitting the model. + /// - `optimize_opts`: The options to use for optimizing the model. + pub fn new( + data: TrainingData, + mut opts: ProphetOptions, + optimizer: Arc, + optimize_opts: OptimizeOpts, + ) -> Self { + if opts.uncertainty_samples == 0 { + opts.uncertainty_samples = 1000; + } + Self { + data, + opts, + optimizer, + optimize_opts, + } + } +} + +impl Fit for ProphetForecaster { + type Fitted = FittedProphetForecaster; + type Error = Error; + + fn fit(&self, y: &[f64]) -> Result { + let mut training_data = self.data.clone(); + training_data.y = y.to_vec(); + let mut model = Prophet::new(self.opts.clone(), self.optimizer.clone()); + model.fit(training_data, self.optimize_opts.clone())?; + Ok(FittedProphetForecaster { + model: RefCell::new(model), + training_n: y.len(), + }) + } +} + +/// A fitted Prophet forecaster. +#[derive(Debug)] +pub struct FittedProphetForecaster { + model: RefCell>>, + training_n: usize, +} + +impl Predict for FittedProphetForecaster { + type Error = Error; + + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + if let Some(level) = level { + self.model + .borrow_mut() + .set_interval_width(level.try_into()?); + } + let predictions = self.model.borrow().predict(None)?; + forecast.point = predictions.yhat.point; + if let Some(intervals) = forecast.intervals.as_mut() { + intervals.lower = predictions + .yhat + .lower + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + intervals.upper = predictions + .yhat + .upper + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + } + Ok(()) + } + + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + if horizon == 0 { + return Ok(()); + } + if let Some(level) = level { + self.model + .borrow_mut() + .set_interval_width(level.try_into()?); + } + let predictions = { + let model = self.model.borrow(); + let prediction_data = model.make_future_dataframe( + NonZeroU32::try_from(horizon as u32).expect("horizon should be > 0"), + IncludeHistory::No, + )?; + model.predict(prediction_data)? + }; + forecast.point = predictions.yhat.point; + if let Some(intervals) = forecast.intervals.as_mut() { + intervals.lower = predictions + .yhat + .lower + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + intervals.upper = predictions + .yhat + .upper + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + } + Ok(()) + } + + fn training_data_size(&self) -> usize { + self.training_n + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use augurs_core::{Fit, Predict}; + use augurs_testing::assert_all_close; + + use crate::{ + testdata::{daily_univariate_ts, train_test_splitn}, + wasmstan::WasmstanOptimizer, + IncludeHistory, Prophet, + }; + + use super::ProphetForecaster; + + #[test] + fn forecaster() { + let test_days = 30; + let (train, _) = train_test_splitn(daily_univariate_ts(), test_days); + + let forecaster = ProphetForecaster::new( + train.clone(), + Default::default(), + Arc::new(WasmstanOptimizer::new()), + Default::default(), + ); + let fitted = forecaster.fit(&train.y).unwrap(); + let forecast_predictions = fitted.predict(30, 0.95).unwrap(); + + let mut prophet = Prophet::new(Default::default(), WasmstanOptimizer::new()); + prophet.fit(train, Default::default()).unwrap(); + let prediction_data = prophet + .make_future_dataframe(30.try_into().unwrap(), IncludeHistory::No) + .unwrap(); + let predictions = prophet.predict(prediction_data).unwrap(); + + // We should get the same results back when using the Forecaster impl. + assert_eq!( + predictions.yhat.point.len(), + forecast_predictions.point.len() + ); + assert_all_close(&predictions.yhat.point, &forecast_predictions.point); + } +} diff --git a/crates/augurs-prophet/src/lib.rs b/crates/augurs-prophet/src/lib.rs index 457bb2ac..1c085172 100644 --- a/crates/augurs-prophet/src/lib.rs +++ b/crates/augurs-prophet/src/lib.rs @@ -2,6 +2,7 @@ mod data; mod error; mod features; +pub mod forecaster; // Export the optimizer module so that users can implement their own // optimizers. pub mod optimizer; diff --git a/crates/augurs-prophet/src/optimizer.rs b/crates/augurs-prophet/src/optimizer.rs index 17268be1..99899a0c 100644 --- a/crates/augurs-prophet/src/optimizer.rs +++ b/crates/augurs-prophet/src/optimizer.rs @@ -21,7 +21,7 @@ // WASM Components? // TODO: write a pure Rust optimizer for the default case. -use std::fmt; +use std::{fmt, sync::Arc}; use crate::positive_float::PositiveFloat; @@ -299,6 +299,17 @@ pub trait Optimizer: std::fmt::Debug { ) -> Result; } +impl Optimizer for Arc { + fn optimize( + &self, + init: &InitialParams, + data: &Data, + opts: &OptimizeOpts, + ) -> Result { + (**self).optimize(init, data, opts) + } +} + #[cfg(test)] pub(crate) mod mock_optimizer { use std::cell::RefCell; diff --git a/crates/augurs-prophet/src/prophet.rs b/crates/augurs-prophet/src/prophet.rs index 9cb5ff64..954f49a2 100644 --- a/crates/augurs-prophet/src/prophet.rs +++ b/crates/augurs-prophet/src/prophet.rs @@ -13,8 +13,8 @@ use prep::{ComponentColumns, Modes, Preprocessed, Scales}; use crate::{ optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer}, - Error, EstimationMode, FeaturePrediction, IncludeHistory, PredictionData, Predictions, - Regressor, Seasonality, TimestampSeconds, TrainingData, + Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData, + Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData, }; /// The Prophet time series forecasting model. @@ -233,6 +233,15 @@ impl Prophet { Ok(PredictionData::new(ds)) } + /// Set the width of the uncertainty intervals. + /// + /// The interval width does not affect training, only predictions, + /// so this can be called after fitting the model to obtain predictions + /// with different levels of uncertainty. + pub fn set_interval_width(&mut self, interval_width: IntervalWidth) { + self.opts.interval_width = interval_width; + } + fn infer_freq(history_dates: &[TimestampSeconds]) -> Result { const INFER_N: usize = 5; let get_tried = || { From 9f76c817d803eb197616e1bd26f9ac69adf5198c Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Tue, 10 Dec 2024 10:51:21 +0000 Subject: [PATCH 2/4] Simplify API for Prophet forecaster; add example --- crates/augurs-prophet/src/forecaster.rs | 36 ++++++------- crates/augurs-prophet/src/prophet.rs | 44 +++++++++++++++ .../examples/prophet_forecaster.rs | 53 +++++++++++++++++++ 3 files changed, 112 insertions(+), 21 deletions(-) create mode 100644 examples/forecasting/examples/prophet_forecaster.rs diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs index ed27bdf7..81906dfc 100644 --- a/crates/augurs-prophet/src/forecaster.rs +++ b/crates/augurs-prophet/src/forecaster.rs @@ -3,10 +3,7 @@ use std::{cell::RefCell, num::NonZeroU32, sync::Arc}; use augurs_core::{Fit, ModelError, Predict}; -use crate::{ - optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, ProphetOptions, - TrainingData, -}; +use crate::{optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, TrainingData}; impl ModelError for Error {} @@ -20,8 +17,7 @@ impl ModelError for Error {} #[derive(Debug)] pub struct ProphetForecaster { data: TrainingData, - opts: ProphetOptions, - optimizer: Arc, + model: Prophet>, optimize_opts: OptimizeOpts, } @@ -33,19 +29,18 @@ impl ProphetForecaster { /// - `opts`: The options to use for fitting the model. /// - `optimizer`: The optimizer to use for fitting the model. /// - `optimize_opts`: The options to use for optimizing the model. - pub fn new( + pub fn new( + mut model: Prophet, data: TrainingData, - mut opts: ProphetOptions, - optimizer: Arc, optimize_opts: OptimizeOpts, ) -> Self { + let opts = model.opts_mut(); if opts.uncertainty_samples == 0 { opts.uncertainty_samples = 1000; } Self { data, - opts, - optimizer, + model: model.into_dyn_optimizer(), optimize_opts, } } @@ -56,12 +51,16 @@ impl Fit for ProphetForecaster { type Error = Error; fn fit(&self, y: &[f64]) -> Result { + // Use the training data from `self`... let mut training_data = self.data.clone(); + // ...but replace the `y` column with whatever we're passed + // (which may be a transformed version of `y`, if the user is + // using `augurs_forecaster`). training_data.y = y.to_vec(); - let mut model = Prophet::new(self.opts.clone(), self.optimizer.clone()); - model.fit(training_data, self.optimize_opts.clone())?; + let mut fitted_model = self.model.clone(); + fitted_model.fit(training_data, self.optimize_opts.clone())?; Ok(FittedProphetForecaster { - model: RefCell::new(model), + model: RefCell::new(fitted_model), training_n: y.len(), }) } @@ -153,7 +152,6 @@ impl Predict for FittedProphetForecaster { #[cfg(test)] mod test { - use std::sync::Arc; use augurs_core::{Fit, Predict}; use augurs_testing::assert_all_close; @@ -171,12 +169,8 @@ mod test { let test_days = 30; let (train, _) = train_test_splitn(daily_univariate_ts(), test_days); - let forecaster = ProphetForecaster::new( - train.clone(), - Default::default(), - Arc::new(WasmstanOptimizer::new()), - Default::default(), - ); + let model = Prophet::new(Default::default(), WasmstanOptimizer::new()); + let forecaster = ProphetForecaster::new(model, train.clone(), Default::default()); let fitted = forecaster.fit(&train.y).unwrap(); let forecast_predictions = fitted.predict(30, 0.95).unwrap(); diff --git a/crates/augurs-prophet/src/prophet.rs b/crates/augurs-prophet/src/prophet.rs index 954f49a2..62f928ed 100644 --- a/crates/augurs-prophet/src/prophet.rs +++ b/crates/augurs-prophet/src/prophet.rs @@ -5,6 +5,7 @@ pub(crate) mod prep; use std::{ collections::{HashMap, HashSet}, num::NonZeroU32, + sync::Arc, }; use itertools::{izip, Itertools}; @@ -12,6 +13,7 @@ use options::ProphetOptions; use prep::{ComponentColumns, Modes, Preprocessed, Scales}; use crate::{ + forecaster::ProphetForecaster, optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer}, Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData, Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData, @@ -233,6 +235,16 @@ impl Prophet { Ok(PredictionData::new(ds)) } + /// Get a reference to the Prophet options. + pub fn opts(&self) -> &ProphetOptions { + &self.opts + } + + /// Get a mutable reference to the Prophet options. + pub fn opts_mut(&mut self) -> &mut ProphetOptions { + &mut self.opts + } + /// Set the width of the uncertainty intervals. /// /// The interval width does not affect training, only predictions, @@ -277,6 +289,38 @@ impl Prophet { } } +impl Prophet { + pub(crate) fn into_dyn_optimizer(self) -> Prophet> { + Prophet { + optimizer: Arc::new(self.optimizer), + opts: self.opts, + regressors: self.regressors, + optimized: self.optimized, + changepoints: self.changepoints, + changepoints_t: self.changepoints_t, + init: self.init, + scales: self.scales, + processed: self.processed, + seasonalities: self.seasonalities, + component_modes: self.component_modes, + train_holiday_names: self.train_holiday_names, + train_component_columns: self.train_component_columns, + } + } + + /// Create a new `ProphetForecaster` from this Prophet model. + /// + /// This requires the data and optimize options to be provided and sets up + /// a `ProphetForecaster` ready to be used with the `augurs_forecaster` crate. + pub fn into_forecaster( + self, + data: TrainingData, + optimize_opts: OptimizeOpts, + ) -> ProphetForecaster { + ProphetForecaster::new(self, data, optimize_opts) + } +} + impl Prophet { /// Fit the Prophet model to some training data. pub fn fit(&mut self, data: TrainingData, mut opts: OptimizeOpts) -> Result<(), Error> { diff --git a/examples/forecasting/examples/prophet_forecaster.rs b/examples/forecasting/examples/prophet_forecaster.rs new file mode 100644 index 00000000..5d0b826b --- /dev/null +++ b/examples/forecasting/examples/prophet_forecaster.rs @@ -0,0 +1,53 @@ +//! Example of using the Prophet model with the wasmstan optimizer. + +use augurs::{ + forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform}, + prophet::{wasmstan::WasmstanOptimizer, Prophet, TrainingData}, +}; + +fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + tracing::info!("Running Prophet example"); + + let ds = vec![ + 1704067200, 1704871384, 1705675569, 1706479753, 1707283938, 1708088123, 1708892307, + 1709696492, 1710500676, 1711304861, 1712109046, 1712913230, 1713717415, + ]; + let y = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, + ]; + let data = TrainingData::new(ds, y.clone())?; + + // Set up the transforms. + // These are just illustrative examples; you can use whatever transforms + // you want. + let transforms = vec![Transform::min_max_scaler(MinMaxScaleParams::from_data( + y.iter().copied(), + ))]; + + // Set up the model. Create the Prophet model as normal, then convert it to a + // `ProphetForecaster`. + let prophet = Prophet::new(Default::default(), WasmstanOptimizer::new()); + let prophet_forecaster = prophet.into_forecaster(data.clone(), Default::default()); + + // Finally create a Forecaster using those transforms. + let mut forecaster = Forecaster::new(prophet_forecaster).with_transforms(transforms); + + // Fit the forecaster. This will transform the training data by + // running the transforms in order, then fit the Prophet model. + forecaster.fit(&y).expect("model should fit"); + + // Generate some in-sample predictions with 95% prediction intervals. + // The forecaster will handle back-transforming them onto our original scale. + let predictions = forecaster.predict_in_sample(0.95)?; + assert_eq!(predictions.point.len(), y.len()); + assert!(predictions.intervals.is_some()); + println!("In-sample predictions: {:?}", predictions); + + // Generate 10 out-of-sample predictions with 95% prediction intervals. + let predictions = forecaster.predict(10, 0.95)?; + assert_eq!(predictions.point.len(), 10); + assert!(predictions.intervals.is_some()); + println!("Out-of-sample predictions: {:?}", predictions); + Ok(()) +} From 4cabe527aa7c9da6d2142223a95f2fd1395b6bf2 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Tue, 10 Dec 2024 11:25:49 +0000 Subject: [PATCH 3/4] Improve edge case handling in Prophet forecaster --- crates/augurs-prophet/src/forecaster.rs | 17 ++++++++--------- crates/augurs-prophet/src/optimizer.rs | 3 +++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs index 81906dfc..a1de3b81 100644 --- a/crates/augurs-prophet/src/forecaster.rs +++ b/crates/augurs-prophet/src/forecaster.rs @@ -94,13 +94,13 @@ impl Predict for FittedProphetForecaster { .lower // This `expect` is OK because we've set uncertainty_samples > 0 in the // `ProphetForecaster` constructor. - .expect("uncertainty_samples should be > 0"); + .expect("uncertainty_samples should be > 0, this is a bug"); intervals.upper = predictions .yhat .upper // This `expect` is OK because we've set uncertainty_samples > 0 in the // `ProphetForecaster` constructor. - .expect("uncertainty_samples should be > 0"); + .expect("uncertainty_samples should be > 0, this is a bug"); } Ok(()) } @@ -111,9 +111,11 @@ impl Predict for FittedProphetForecaster { level: Option, forecast: &mut augurs_core::Forecast, ) -> Result<(), Self::Error> { - if horizon == 0 { - return Ok(()); - } + let horizon = match NonZeroU32::try_from(horizon as u32) { + Ok(h) => h, + // If horizon is 0, short circuit without even trying to predict. + Err(_) => return Ok(()), + }; if let Some(level) = level { self.model .borrow_mut() @@ -121,10 +123,7 @@ impl Predict for FittedProphetForecaster { } let predictions = { let model = self.model.borrow(); - let prediction_data = model.make_future_dataframe( - NonZeroU32::try_from(horizon as u32).expect("horizon should be > 0"), - IncludeHistory::No, - )?; + let prediction_data = model.make_future_dataframe(horizon, IncludeHistory::No)?; model.predict(prediction_data)? }; forecast.point = predictions.yhat.point; diff --git a/crates/augurs-prophet/src/optimizer.rs b/crates/augurs-prophet/src/optimizer.rs index 99899a0c..06c582bd 100644 --- a/crates/augurs-prophet/src/optimizer.rs +++ b/crates/augurs-prophet/src/optimizer.rs @@ -299,6 +299,9 @@ pub trait Optimizer: std::fmt::Debug { ) -> Result; } +/// An implementation of `Optimize` which simply delegates to the +/// `Arc`'s inner type. This enables thread-safe sharing of optimizers +/// while maintaining the ability to use dynamic dispatch. impl Optimizer for Arc { fn optimize( &self, From 78f1b3381fed76011c8a6d08f7357cb24518c8e7 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Tue, 10 Dec 2024 11:32:07 +0000 Subject: [PATCH 4/4] Add note on default uncertainty samples in ProphetForecaster --- crates/augurs-prophet/src/forecaster.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs index a1de3b81..bd176bae 100644 --- a/crates/augurs-prophet/src/forecaster.rs +++ b/crates/augurs-prophet/src/forecaster.rs @@ -27,6 +27,8 @@ impl ProphetForecaster { /// # Parameters /// /// - `opts`: The options to use for fitting the model. + /// Note that `uncertainty_samples` will be set to 1000 if it is 0, + /// to facilitate generating prediction intervals. /// - `optimizer`: The optimizer to use for fitting the model. /// - `optimize_opts`: The options to use for optimizing the model. pub fn new(