diff --git a/crates/augurs-prophet/Cargo.toml b/crates/augurs-prophet/Cargo.toml index f8159d84..f8a128c9 100644 --- a/crates/augurs-prophet/Cargo.toml +++ b/crates/augurs-prophet/Cargo.toml @@ -28,6 +28,7 @@ include = [ [dependencies] anyhow.workspace = true +augurs-core.workspace = true bytemuck = { workspace = true, features = ["derive"], optional = true } itertools.workspace = true num-traits.workspace = true diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs new file mode 100644 index 00000000..bd176bae --- /dev/null +++ b/crates/augurs-prophet/src/forecaster.rs @@ -0,0 +1,192 @@ +//! [`Fit`] and [`Predict`] implementations for the Prophet algorithm. +use std::{cell::RefCell, num::NonZeroU32, sync::Arc}; + +use augurs_core::{Fit, ModelError, Predict}; + +use crate::{optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, TrainingData}; + +impl ModelError for Error {} + +/// A forecaster that uses the Prophet algorithm. +/// +/// This is a wrapper around the [`Prophet`] struct that provides +/// a simpler API for fitting and predicting. Notably it implements +/// the [`Fit`] trait from `augurs_core`, so it can be +/// used with the `augurs` framework (e.g. with the `Forecaster` struct +/// in the `augurs::forecaster` module). +#[derive(Debug)] +pub struct ProphetForecaster { + data: TrainingData, + model: Prophet>, + optimize_opts: OptimizeOpts, +} + +impl ProphetForecaster { + /// Create a new Prophet forecaster. + /// + /// # Parameters + /// + /// - `opts`: The options to use for fitting the model. + /// Note that `uncertainty_samples` will be set to 1000 if it is 0, + /// to facilitate generating prediction intervals. + /// - `optimizer`: The optimizer to use for fitting the model. + /// - `optimize_opts`: The options to use for optimizing the model. + pub fn new( + mut model: Prophet, + data: TrainingData, + optimize_opts: OptimizeOpts, + ) -> Self { + let opts = model.opts_mut(); + if opts.uncertainty_samples == 0 { + opts.uncertainty_samples = 1000; + } + Self { + data, + model: model.into_dyn_optimizer(), + optimize_opts, + } + } +} + +impl Fit for ProphetForecaster { + type Fitted = FittedProphetForecaster; + type Error = Error; + + fn fit(&self, y: &[f64]) -> Result { + // Use the training data from `self`... + let mut training_data = self.data.clone(); + // ...but replace the `y` column with whatever we're passed + // (which may be a transformed version of `y`, if the user is + // using `augurs_forecaster`). + training_data.y = y.to_vec(); + let mut fitted_model = self.model.clone(); + fitted_model.fit(training_data, self.optimize_opts.clone())?; + Ok(FittedProphetForecaster { + model: RefCell::new(fitted_model), + training_n: y.len(), + }) + } +} + +/// A fitted Prophet forecaster. +#[derive(Debug)] +pub struct FittedProphetForecaster { + model: RefCell>>, + training_n: usize, +} + +impl Predict for FittedProphetForecaster { + type Error = Error; + + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + if let Some(level) = level { + self.model + .borrow_mut() + .set_interval_width(level.try_into()?); + } + let predictions = self.model.borrow().predict(None)?; + forecast.point = predictions.yhat.point; + if let Some(intervals) = forecast.intervals.as_mut() { + intervals.lower = predictions + .yhat + .lower + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0, this is a bug"); + intervals.upper = predictions + .yhat + .upper + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0, this is a bug"); + } + Ok(()) + } + + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + let horizon = match NonZeroU32::try_from(horizon as u32) { + Ok(h) => h, + // If horizon is 0, short circuit without even trying to predict. + Err(_) => return Ok(()), + }; + if let Some(level) = level { + self.model + .borrow_mut() + .set_interval_width(level.try_into()?); + } + let predictions = { + let model = self.model.borrow(); + let prediction_data = model.make_future_dataframe(horizon, IncludeHistory::No)?; + model.predict(prediction_data)? + }; + forecast.point = predictions.yhat.point; + if let Some(intervals) = forecast.intervals.as_mut() { + intervals.lower = predictions + .yhat + .lower + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + intervals.upper = predictions + .yhat + .upper + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + } + Ok(()) + } + + fn training_data_size(&self) -> usize { + self.training_n + } +} + +#[cfg(test)] +mod test { + + use augurs_core::{Fit, Predict}; + use augurs_testing::assert_all_close; + + use crate::{ + testdata::{daily_univariate_ts, train_test_splitn}, + wasmstan::WasmstanOptimizer, + IncludeHistory, Prophet, + }; + + use super::ProphetForecaster; + + #[test] + fn forecaster() { + let test_days = 30; + let (train, _) = train_test_splitn(daily_univariate_ts(), test_days); + + let model = Prophet::new(Default::default(), WasmstanOptimizer::new()); + let forecaster = ProphetForecaster::new(model, train.clone(), Default::default()); + let fitted = forecaster.fit(&train.y).unwrap(); + let forecast_predictions = fitted.predict(30, 0.95).unwrap(); + + let mut prophet = Prophet::new(Default::default(), WasmstanOptimizer::new()); + prophet.fit(train, Default::default()).unwrap(); + let prediction_data = prophet + .make_future_dataframe(30.try_into().unwrap(), IncludeHistory::No) + .unwrap(); + let predictions = prophet.predict(prediction_data).unwrap(); + + // We should get the same results back when using the Forecaster impl. + assert_eq!( + predictions.yhat.point.len(), + forecast_predictions.point.len() + ); + assert_all_close(&predictions.yhat.point, &forecast_predictions.point); + } +} diff --git a/crates/augurs-prophet/src/lib.rs b/crates/augurs-prophet/src/lib.rs index 457bb2ac..1c085172 100644 --- a/crates/augurs-prophet/src/lib.rs +++ b/crates/augurs-prophet/src/lib.rs @@ -2,6 +2,7 @@ mod data; mod error; mod features; +pub mod forecaster; // Export the optimizer module so that users can implement their own // optimizers. pub mod optimizer; diff --git a/crates/augurs-prophet/src/optimizer.rs b/crates/augurs-prophet/src/optimizer.rs index 17268be1..06c582bd 100644 --- a/crates/augurs-prophet/src/optimizer.rs +++ b/crates/augurs-prophet/src/optimizer.rs @@ -21,7 +21,7 @@ // WASM Components? // TODO: write a pure Rust optimizer for the default case. -use std::fmt; +use std::{fmt, sync::Arc}; use crate::positive_float::PositiveFloat; @@ -299,6 +299,20 @@ pub trait Optimizer: std::fmt::Debug { ) -> Result; } +/// An implementation of `Optimize` which simply delegates to the +/// `Arc`'s inner type. This enables thread-safe sharing of optimizers +/// while maintaining the ability to use dynamic dispatch. +impl Optimizer for Arc { + fn optimize( + &self, + init: &InitialParams, + data: &Data, + opts: &OptimizeOpts, + ) -> Result { + (**self).optimize(init, data, opts) + } +} + #[cfg(test)] pub(crate) mod mock_optimizer { use std::cell::RefCell; diff --git a/crates/augurs-prophet/src/prophet.rs b/crates/augurs-prophet/src/prophet.rs index 9cb5ff64..62f928ed 100644 --- a/crates/augurs-prophet/src/prophet.rs +++ b/crates/augurs-prophet/src/prophet.rs @@ -5,6 +5,7 @@ pub(crate) mod prep; use std::{ collections::{HashMap, HashSet}, num::NonZeroU32, + sync::Arc, }; use itertools::{izip, Itertools}; @@ -12,9 +13,10 @@ use options::ProphetOptions; use prep::{ComponentColumns, Modes, Preprocessed, Scales}; use crate::{ + forecaster::ProphetForecaster, optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer}, - Error, EstimationMode, FeaturePrediction, IncludeHistory, PredictionData, Predictions, - Regressor, Seasonality, TimestampSeconds, TrainingData, + Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData, + Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData, }; /// The Prophet time series forecasting model. @@ -233,6 +235,25 @@ impl Prophet { Ok(PredictionData::new(ds)) } + /// Get a reference to the Prophet options. + pub fn opts(&self) -> &ProphetOptions { + &self.opts + } + + /// Get a mutable reference to the Prophet options. + pub fn opts_mut(&mut self) -> &mut ProphetOptions { + &mut self.opts + } + + /// Set the width of the uncertainty intervals. + /// + /// The interval width does not affect training, only predictions, + /// so this can be called after fitting the model to obtain predictions + /// with different levels of uncertainty. + pub fn set_interval_width(&mut self, interval_width: IntervalWidth) { + self.opts.interval_width = interval_width; + } + fn infer_freq(history_dates: &[TimestampSeconds]) -> Result { const INFER_N: usize = 5; let get_tried = || { @@ -268,6 +289,38 @@ impl Prophet { } } +impl Prophet { + pub(crate) fn into_dyn_optimizer(self) -> Prophet> { + Prophet { + optimizer: Arc::new(self.optimizer), + opts: self.opts, + regressors: self.regressors, + optimized: self.optimized, + changepoints: self.changepoints, + changepoints_t: self.changepoints_t, + init: self.init, + scales: self.scales, + processed: self.processed, + seasonalities: self.seasonalities, + component_modes: self.component_modes, + train_holiday_names: self.train_holiday_names, + train_component_columns: self.train_component_columns, + } + } + + /// Create a new `ProphetForecaster` from this Prophet model. + /// + /// This requires the data and optimize options to be provided and sets up + /// a `ProphetForecaster` ready to be used with the `augurs_forecaster` crate. + pub fn into_forecaster( + self, + data: TrainingData, + optimize_opts: OptimizeOpts, + ) -> ProphetForecaster { + ProphetForecaster::new(self, data, optimize_opts) + } +} + impl Prophet { /// Fit the Prophet model to some training data. pub fn fit(&mut self, data: TrainingData, mut opts: OptimizeOpts) -> Result<(), Error> { diff --git a/examples/forecasting/examples/prophet_forecaster.rs b/examples/forecasting/examples/prophet_forecaster.rs new file mode 100644 index 00000000..5d0b826b --- /dev/null +++ b/examples/forecasting/examples/prophet_forecaster.rs @@ -0,0 +1,53 @@ +//! Example of using the Prophet model with the wasmstan optimizer. + +use augurs::{ + forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform}, + prophet::{wasmstan::WasmstanOptimizer, Prophet, TrainingData}, +}; + +fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + tracing::info!("Running Prophet example"); + + let ds = vec![ + 1704067200, 1704871384, 1705675569, 1706479753, 1707283938, 1708088123, 1708892307, + 1709696492, 1710500676, 1711304861, 1712109046, 1712913230, 1713717415, + ]; + let y = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, + ]; + let data = TrainingData::new(ds, y.clone())?; + + // Set up the transforms. + // These are just illustrative examples; you can use whatever transforms + // you want. + let transforms = vec![Transform::min_max_scaler(MinMaxScaleParams::from_data( + y.iter().copied(), + ))]; + + // Set up the model. Create the Prophet model as normal, then convert it to a + // `ProphetForecaster`. + let prophet = Prophet::new(Default::default(), WasmstanOptimizer::new()); + let prophet_forecaster = prophet.into_forecaster(data.clone(), Default::default()); + + // Finally create a Forecaster using those transforms. + let mut forecaster = Forecaster::new(prophet_forecaster).with_transforms(transforms); + + // Fit the forecaster. This will transform the training data by + // running the transforms in order, then fit the Prophet model. + forecaster.fit(&y).expect("model should fit"); + + // Generate some in-sample predictions with 95% prediction intervals. + // The forecaster will handle back-transforming them onto our original scale. + let predictions = forecaster.predict_in_sample(0.95)?; + assert_eq!(predictions.point.len(), y.len()); + assert!(predictions.intervals.is_some()); + println!("In-sample predictions: {:?}", predictions); + + // Generate 10 out-of-sample predictions with 95% prediction intervals. + let predictions = forecaster.predict(10, 0.95)?; + assert_eq!(predictions.point.len(), 10); + assert!(predictions.intervals.is_some()); + println!("Out-of-sample predictions: {:?}", predictions); + Ok(()) +}