Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Forecaster wrapper for Prophet #191

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/augurs-prophet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include = [

[dependencies]
anyhow.workspace = true
augurs-core.workspace = true
bytemuck = { workspace = true, features = ["derive"], optional = true }
itertools.workspace = true
num-traits.workspace = true
Expand Down
191 changes: 191 additions & 0 deletions crates/augurs-prophet/src/forecaster.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//! [`Fit`] and [`Predict`] implementations for the Prophet algorithm.
use std::{cell::RefCell, num::NonZeroU32, sync::Arc};

use augurs_core::{Fit, ModelError, Predict};

use crate::{optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, TrainingData};

impl ModelError for Error {}

/// A forecaster that uses the Prophet algorithm.
///
/// This is a wrapper around the [`Prophet`] struct that provides
/// a simpler API for fitting and predicting. Notably it implements
/// the [`Fit`] trait from `augurs_core`, so it can be
/// used with the `augurs` framework (e.g. with the `Forecaster` struct
/// in the `augurs::forecaster` module).
#[derive(Debug)]
pub struct ProphetForecaster {
data: TrainingData,
model: Prophet<Arc<dyn Optimizer>>,
optimize_opts: OptimizeOpts,
}

impl ProphetForecaster {
/// Create a new Prophet forecaster.
///
/// # Parameters
///
/// - `opts`: The options to use for fitting the model.
/// - `optimizer`: The optimizer to use for fitting the model.
/// - `optimize_opts`: The options to use for optimizing the model.
pub fn new<T: Optimizer + 'static>(
mut model: Prophet<T>,
data: TrainingData,
optimize_opts: OptimizeOpts,
) -> Self {
let opts = model.opts_mut();
if opts.uncertainty_samples == 0 {
opts.uncertainty_samples = 1000;
}
Self {
data,
model: model.into_dyn_optimizer(),
optimize_opts,
}
}
}

impl Fit for ProphetForecaster {
type Fitted = FittedProphetForecaster;
type Error = Error;

fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> {
// Use the training data from `self`...
let mut training_data = self.data.clone();
// ...but replace the `y` column with whatever we're passed
// (which may be a transformed version of `y`, if the user is
// using `augurs_forecaster`).
training_data.y = y.to_vec();
let mut fitted_model = self.model.clone();
fitted_model.fit(training_data, self.optimize_opts.clone())?;
Ok(FittedProphetForecaster {
model: RefCell::new(fitted_model),
training_n: y.len(),
})
}
}

/// A fitted Prophet forecaster.
#[derive(Debug)]
pub struct FittedProphetForecaster {
model: RefCell<Prophet<Arc<dyn Optimizer>>>,
training_n: usize,
}

impl Predict for FittedProphetForecaster {
type Error = Error;

fn predict_in_sample_inplace(
&self,
level: Option<f64>,
forecast: &mut augurs_core::Forecast,
) -> Result<(), Self::Error> {
if let Some(level) = level {
self.model
.borrow_mut()
.set_interval_width(level.try_into()?);
}
let predictions = self.model.borrow().predict(None)?;
forecast.point = predictions.yhat.point;
if let Some(intervals) = forecast.intervals.as_mut() {
intervals.lower = predictions
.yhat
.lower
// This `expect` is OK because we've set uncertainty_samples > 0 in the
// `ProphetForecaster` constructor.
.expect("uncertainty_samples should be > 0");
intervals.upper = predictions
.yhat
.upper
// This `expect` is OK because we've set uncertainty_samples > 0 in the
// `ProphetForecaster` constructor.
.expect("uncertainty_samples should be > 0");
}
Ok(())
}

fn predict_inplace(
&self,
horizon: usize,
level: Option<f64>,
forecast: &mut augurs_core::Forecast,
) -> Result<(), Self::Error> {
if horizon == 0 {
return Ok(());
}
if let Some(level) = level {
self.model
.borrow_mut()
.set_interval_width(level.try_into()?);
}
let predictions = {
let model = self.model.borrow();
let prediction_data = model.make_future_dataframe(
NonZeroU32::try_from(horizon as u32).expect("horizon should be > 0"),
IncludeHistory::No,
)?;
model.predict(prediction_data)?
};
forecast.point = predictions.yhat.point;
if let Some(intervals) = forecast.intervals.as_mut() {
intervals.lower = predictions
.yhat
.lower
// This `expect` is OK because we've set uncertainty_samples > 0 in the
// `ProphetForecaster` constructor.
.expect("uncertainty_samples should be > 0");
intervals.upper = predictions
.yhat
.upper
// This `expect` is OK because we've set uncertainty_samples > 0 in the
// `ProphetForecaster` constructor.
.expect("uncertainty_samples should be > 0");
}
Ok(())
}

fn training_data_size(&self) -> usize {
self.training_n
}
}

#[cfg(test)]
mod test {

use augurs_core::{Fit, Predict};
use augurs_testing::assert_all_close;

use crate::{
testdata::{daily_univariate_ts, train_test_splitn},
wasmstan::WasmstanOptimizer,
IncludeHistory, Prophet,
};

use super::ProphetForecaster;

#[test]
fn forecaster() {
let test_days = 30;
let (train, _) = train_test_splitn(daily_univariate_ts(), test_days);

let model = Prophet::new(Default::default(), WasmstanOptimizer::new());
let forecaster = ProphetForecaster::new(model, train.clone(), Default::default());
let fitted = forecaster.fit(&train.y).unwrap();
let forecast_predictions = fitted.predict(30, 0.95).unwrap();

let mut prophet = Prophet::new(Default::default(), WasmstanOptimizer::new());
prophet.fit(train, Default::default()).unwrap();
let prediction_data = prophet
.make_future_dataframe(30.try_into().unwrap(), IncludeHistory::No)
.unwrap();
let predictions = prophet.predict(prediction_data).unwrap();

// We should get the same results back when using the Forecaster impl.
assert_eq!(
predictions.yhat.point.len(),
forecast_predictions.point.len()
);
assert_all_close(&predictions.yhat.point, &forecast_predictions.point);
}
}
1 change: 1 addition & 0 deletions crates/augurs-prophet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod data;
mod error;
mod features;
pub mod forecaster;
// Export the optimizer module so that users can implement their own
// optimizers.
pub mod optimizer;
Expand Down
13 changes: 12 additions & 1 deletion crates/augurs-prophet/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
// WASM Components?
// TODO: write a pure Rust optimizer for the default case.

use std::fmt;
use std::{fmt, sync::Arc};

use crate::positive_float::PositiveFloat;

Expand Down Expand Up @@ -299,6 +299,17 @@ pub trait Optimizer: std::fmt::Debug {
) -> Result<OptimizedParams, Error>;
}

impl Optimizer for Arc<dyn Optimizer> {
fn optimize(
&self,
init: &InitialParams,
data: &Data,
opts: &OptimizeOpts,
) -> Result<OptimizedParams, Error> {
(**self).optimize(init, data, opts)
}
}

#[cfg(test)]
pub(crate) mod mock_optimizer {
use std::cell::RefCell;
Expand Down
57 changes: 55 additions & 2 deletions crates/augurs-prophet/src/prophet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ pub(crate) mod prep;
use std::{
collections::{HashMap, HashSet},
num::NonZeroU32,
sync::Arc,
};

use itertools::{izip, Itertools};
use options::ProphetOptions;
use prep::{ComponentColumns, Modes, Preprocessed, Scales};

use crate::{
forecaster::ProphetForecaster,
optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer},
Error, EstimationMode, FeaturePrediction, IncludeHistory, PredictionData, Predictions,
Regressor, Seasonality, TimestampSeconds, TrainingData,
Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData,
Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData,
};

/// The Prophet time series forecasting model.
Expand Down Expand Up @@ -233,6 +235,25 @@ impl<O> Prophet<O> {
Ok(PredictionData::new(ds))
}

/// Get a reference to the Prophet options.
pub fn opts(&self) -> &ProphetOptions {
&self.opts
}

/// Get a mutable reference to the Prophet options.
pub fn opts_mut(&mut self) -> &mut ProphetOptions {
&mut self.opts
}

/// Set the width of the uncertainty intervals.
///
/// The interval width does not affect training, only predictions,
/// so this can be called after fitting the model to obtain predictions
/// with different levels of uncertainty.
pub fn set_interval_width(&mut self, interval_width: IntervalWidth) {
self.opts.interval_width = interval_width;
}

fn infer_freq(history_dates: &[TimestampSeconds]) -> Result<TimestampSeconds, Error> {
const INFER_N: usize = 5;
let get_tried = || {
Expand Down Expand Up @@ -268,6 +289,38 @@ impl<O> Prophet<O> {
}
}

impl<O: Optimizer + 'static> Prophet<O> {
pub(crate) fn into_dyn_optimizer(self) -> Prophet<Arc<dyn Optimizer + 'static>> {
Prophet {
optimizer: Arc::new(self.optimizer),
opts: self.opts,
regressors: self.regressors,
optimized: self.optimized,
changepoints: self.changepoints,
changepoints_t: self.changepoints_t,
init: self.init,
scales: self.scales,
processed: self.processed,
seasonalities: self.seasonalities,
component_modes: self.component_modes,
train_holiday_names: self.train_holiday_names,
train_component_columns: self.train_component_columns,
}
}

/// Create a new `ProphetForecaster` from this Prophet model.
///
/// This requires the data and optimize options to be provided and sets up
/// a `ProphetForecaster` ready to be used with the `augurs_forecaster` crate.
pub fn into_forecaster(
self,
data: TrainingData,
optimize_opts: OptimizeOpts,
) -> ProphetForecaster {
ProphetForecaster::new(self, data, optimize_opts)
}
}

impl<O: Optimizer> Prophet<O> {
/// Fit the Prophet model to some training data.
pub fn fit(&mut self, data: TrainingData, mut opts: OptimizeOpts) -> Result<(), Error> {
Expand Down
53 changes: 53 additions & 0 deletions examples/forecasting/examples/prophet_forecaster.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//! Example of using the Prophet model with the wasmstan optimizer.

use augurs::{
forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform},
prophet::{wasmstan::WasmstanOptimizer, Prophet, TrainingData},
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
tracing::info!("Running Prophet example");

let ds = vec![
1704067200, 1704871384, 1705675569, 1706479753, 1707283938, 1708088123, 1708892307,
1709696492, 1710500676, 1711304861, 1712109046, 1712913230, 1713717415,
];
let y = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
];
let data = TrainingData::new(ds, y.clone())?;

// Set up the transforms.
// These are just illustrative examples; you can use whatever transforms
// you want.
let transforms = vec![Transform::min_max_scaler(MinMaxScaleParams::from_data(
y.iter().copied(),
))];

// Set up the model. Create the Prophet model as normal, then convert it to a
// `ProphetForecaster`.
let prophet = Prophet::new(Default::default(), WasmstanOptimizer::new());
let prophet_forecaster = prophet.into_forecaster(data.clone(), Default::default());

// Finally create a Forecaster using those transforms.
let mut forecaster = Forecaster::new(prophet_forecaster).with_transforms(transforms);

// Fit the forecaster. This will transform the training data by
// running the transforms in order, then fit the Prophet model.
forecaster.fit(&y).expect("model should fit");

// Generate some in-sample predictions with 95% prediction intervals.
// The forecaster will handle back-transforming them onto our original scale.
let predictions = forecaster.predict_in_sample(0.95)?;
assert_eq!(predictions.point.len(), y.len());
assert!(predictions.intervals.is_some());
println!("In-sample predictions: {:?}", predictions);

// Generate 10 out-of-sample predictions with 95% prediction intervals.
let predictions = forecaster.predict(10, 0.95)?;
assert_eq!(predictions.point.len(), 10);
assert!(predictions.intervals.is_some());
println!("Out-of-sample predictions: {:?}", predictions);
Ok(())
}
Loading