Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Forecaster wrapper for Prophet and modify core traits #184

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/augurs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::convert::Infallible;

pub use distance::DistanceMatrix;
pub use forecast::{Forecast, ForecastIntervals};
pub use traits::{Fit, Predict};
pub use traits::{Data, Fit, MutableData, Predict};

/// An error produced by a time series forecasting model.
pub trait ModelError: std::error::Error + Sync + Send + 'static {}
Expand Down
116 changes: 105 additions & 11 deletions crates/augurs-core/src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,121 @@
use crate::{Forecast, ModelError};

/// Trait for data that can be used as an input to [`Fit`].
///
/// This trait is implemented for a number of types including slices, arrays, and
/// vectors. It is also implemented for references to these types.
pub trait Data {
/// Return the data as a slice of `f64`.
fn as_slice(&self) -> &[f64];
}

impl<const N: usize> Data for [f64; N] {
fn as_slice(&self) -> &[f64] {
self
}
}

impl Data for &[f64] {
fn as_slice(&self) -> &[f64] {
self
}
}

impl Data for &mut [f64] {
fn as_slice(&self) -> &[f64] {
self
}
}

impl Data for Vec<f64> {
fn as_slice(&self) -> &[f64] {
self.as_slice()
}
}

impl<T> Data for &T
where
T: Data,
{
fn as_slice(&self) -> &[f64] {
(**self).as_slice()
}
}

impl<T> Data for &mut T
where
T: Data,
{
fn as_slice(&self) -> &[f64] {
(**self).as_slice()
}
}

/// Trait for data that can be used in the forecaster.
///
/// This trait is implemented for a number of types including slices, arrays, and
/// vectors. It is also implemented for references to these types.
pub trait MutableData: Data {
/// Update the `y` values to those in the provided slice.
fn set(&mut self, y: Vec<f64>);
}

impl<const N: usize> MutableData for [f64; N] {
fn set(&mut self, y: Vec<f64>) {
self.copy_from_slice(y.as_slice());
}
}

impl MutableData for &mut [f64] {
fn set(&mut self, y: Vec<f64>) {
self.copy_from_slice(y.as_slice());
}
}

impl MutableData for Vec<f64> {
fn set(&mut self, y: Vec<f64>) {
self.copy_from_slice(y.as_slice());
}
}

impl<T> MutableData for &mut T
where
T: MutableData,
{
fn set(&mut self, y: Vec<f64>) {
(**self).set(y);
}
}

/// A new, unfitted time series forecasting model.
pub trait Fit {
/// The type of the training data used to fit the model.
type TrainingData<'a>: Data
where
Self: 'a;

/// The type of the fitted model produced by the `fit` method.
type Fitted: Predict;

/// The type of error returned when fitting the model.
type Error: ModelError;

/// Fit the model to the training data.
fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error>;
fn fit<'a, 'b: 'a>(&'b self, y: Self::TrainingData<'a>) -> Result<Self::Fitted, Self::Error>;
}

impl<F> Fit for Box<F>
where
F: Fit,
{
type Fitted = F::Fitted;
type Error = F::Error;
fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> {
(**self).fit(y)
}
}
// impl<'a, F, TD> Fit for Box<F>
// where
// F: Fit<TrainingData<'a> = TD>,
// TD: Data,
// {
// type TrainingData = TD;
// type Fitted = F::Fitted;
// type Error = F::Error;
// fn fit(&self, y: Self::TrainingData<'a>) -> Result<Self::Fitted, Self::Error> {
// (**self).fit(y)
// }
// }

/// A fitted time series forecasting model.
pub trait Predict {
Expand Down
6 changes: 4 additions & 2 deletions crates/augurs-ets/src/auto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,10 @@ impl AutoETS {
}

impl Fit for AutoETS {
type TrainingData<'a> = &'a [f64];
type Fitted = FittedAutoETS;
type Error = Error;

/// Search for the best model, fitting it to the data.
///
/// The model is stored on the `AutoETS` struct and can be retrieved with
Expand All @@ -470,7 +472,7 @@ impl Fit for AutoETS {
///
/// If no model can be found, or if any parameters are invalid, this function
/// returns an error.
fn fit(&self, y: &[f64]) -> Result<Self::Fitted> {
fn fit<'a, 'b: 'a>(&'b self, y: Self::TrainingData<'a>) -> Result<Self::Fitted> {
let data_positive = y.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0;
if self.spec.error == ErrorSpec::Multiplicative && !data_positive {
return Err(Error::InvalidModelSpec(format!(
Expand Down Expand Up @@ -627,7 +629,7 @@ mod test {
#[test]
fn air_passengers_fit() {
let auto = AutoETS::new(1, "ZZN").unwrap();
let fit = auto.fit(AIR_PASSENGERS).expect("fit failed");
let fit = auto.fit(&mut AIR_PASSENGERS.to_vec()).expect("fit failed");
assert_eq!(fit.model.model_type().error, ErrorComponent::Multiplicative);
assert_eq!(fit.model.model_type().trend, TrendComponent::Additive);
assert_eq!(fit.model.model_type().season, SeasonalComponent::None);
Expand Down
2 changes: 1 addition & 1 deletion crates/augurs-ets/src/trend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl TrendModel for AutoETSTrendModel {

fn fit(
&self,
y: &[f64],
y: &mut [f64],
) -> Result<
Box<dyn FittedTrendModel + Sync + Send>,
Box<dyn std::error::Error + Send + Sync + 'static>,
Expand Down
35 changes: 0 additions & 35 deletions crates/augurs-forecaster/src/data.rs

This file was deleted.

16 changes: 9 additions & 7 deletions crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use augurs_core::{Fit, Forecast, Predict};
use augurs_core::{Data, Fit, Forecast, MutableData, Predict};

use crate::{Data, Error, Result, Transform, Transforms};
use crate::{Error, Result, Transform, Transforms};

/// A high-level API to fit and predict time series forecasting models.
///
Expand All @@ -16,10 +16,11 @@ pub struct Forecaster<M: Fit> {
transforms: Transforms,
}

impl<M> Forecaster<M>
impl<'a, M: 'a> Forecaster<M>
where
M: Fit,
M::Fitted: Predict,
M::TrainingData<'a>: MutableData,
{
/// Create a new `Forecaster` with the given model.
pub fn new(model: M) -> Self {
Expand All @@ -37,12 +38,13 @@ where
}

/// Fit the model to the given time series.
pub fn fit<D: Data + Clone>(&mut self, y: D) -> Result<()> {
pub fn fit<'b: 'a>(&'b mut self, mut td: M::TrainingData<'a>) -> Result<()> {
let data: Vec<_> = self
.transforms
.transform(y.as_slice().iter().copied())
.transform(td.as_slice().iter().copied())
.collect();
self.fitted = Some(self.model.fit(&data).map_err(|e| Error::Fit {
td.set(data);
self.fitted = Some(self.model.fit(td).map_err(|e| Error::Fit {
source: Box::new(e) as _,
})?);
Ok(())
Expand Down Expand Up @@ -105,7 +107,7 @@ mod test {

#[test]
fn test_forecaster() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let data = &mut [1.0_f64, 2.0, 3.0, 4.0, 5.0];
let MinMaxResult::MinMax(min, max) = data
.iter()
.copied()
Expand Down
2 changes: 0 additions & 2 deletions crates/augurs-forecaster/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#![doc = include_str!("../README.md")]

mod data;
mod error;
mod forecaster;
pub mod transforms;

pub use data::Data;
pub use error::Error;
pub use forecaster::Forecaster;
pub use transforms::Transform;
Expand Down
18 changes: 11 additions & 7 deletions crates/augurs-mstl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ impl<T: TrendModel> MSTLModel<T> {
// Determine the differencing term for the trend component.
let trend = fit.trend();
let residual = fit.remainder();
let deseasonalised = trend
let mut deseasonalised = trend
.iter()
.zip(residual)
.map(|(t, r)| (t + r) as f64)
.collect::<Vec<_>>();
let fitted_trend_model = self
.trend_model
.fit(&deseasonalised)
.fit(&mut deseasonalised)
.map_err(Error::TrendModel)?;
tracing::trace!(
trend_model = ?self.trend_model,
Expand Down Expand Up @@ -234,9 +234,13 @@ impl FittedMSTLModel {
impl ModelError for Error {}

impl<T: TrendModel> augurs_core::Fit for MSTLModel<T> {
type TrainingData<'a> = &'a mut [f64] where T: 'a;
type Fitted = FittedMSTLModel;
type Error = Error;
fn fit(&self, y: &[f64]) -> Result<Self::Fitted> {
fn fit<'a, 'b: 'a>(&self, y: Self::TrainingData<'a>) -> Result<Self::Fitted>
where
T: 'a,
{
self.fit_impl(y)
}
}
Expand Down Expand Up @@ -271,7 +275,7 @@ mod tests {

#[test]
fn results_match_r() {
let y = VIC_ELEC.clone();
let mut y = VIC_ELEC.clone();

let mut stl_params = stlrs::params();
stl_params
Expand All @@ -287,7 +291,7 @@ mod tests {
let periods = vec![24, 24 * 7];
let trend_model = NaiveTrend::new();
let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params);
let fit = mstl.fit(&y).unwrap();
let fit = mstl.fit(&mut y).unwrap();

let in_sample = fit.predict_in_sample(0.95).unwrap();
// The first 12 values from R.
Expand Down Expand Up @@ -332,7 +336,7 @@ mod tests {

#[test]
fn predict_zero_horizon() {
let y = VIC_ELEC.clone();
let mut y = VIC_ELEC.clone();

let mut stl_params = stlrs::params();
stl_params
Expand All @@ -348,7 +352,7 @@ mod tests {
let periods = vec![24, 24 * 7];
let trend_model = NaiveTrend::new();
let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params);
let fit = mstl.fit(&y).unwrap();
let fit = mstl.fit(&mut y).unwrap();
let forecast = fit.predict(0, 0.95).unwrap();
assert!(forecast.point.is_empty());
let ForecastIntervals { lower, upper, .. } = forecast.intervals.unwrap();
Expand Down
6 changes: 3 additions & 3 deletions crates/augurs-mstl/src/trend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub trait TrendModel: Debug {
/// Implementations should store any state required for prediction in the struct itself.
fn fit(
&self,
y: &[f64],
y: &mut [f64],
) -> Result<
Box<dyn FittedTrendModel + Sync + Send>,
Box<dyn std::error::Error + Send + Sync + 'static>,
Expand Down Expand Up @@ -121,7 +121,7 @@ impl<T: TrendModel + ?Sized> TrendModel for Box<T> {

fn fit(
&self,
y: &[f64],
y: &mut [f64],
) -> Result<
Box<dyn FittedTrendModel + Sync + Send>,
Box<dyn std::error::Error + Send + Sync + 'static>,
Expand Down Expand Up @@ -196,7 +196,7 @@ impl TrendModel for NaiveTrend {

fn fit(
&self,
y: &[f64],
y: &mut [f64],
) -> Result<
Box<dyn FittedTrendModel + Sync + Send>,
Box<dyn std::error::Error + Send + Sync + 'static>,
Expand Down
1 change: 1 addition & 0 deletions crates/augurs-prophet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include = [

[dependencies]
anyhow.workspace = true
augurs-core.workspace = true
bytemuck = { workspace = true, features = ["derive"], optional = true }
itertools.workspace = true
num-traits.workspace = true
Expand Down
5 changes: 5 additions & 0 deletions crates/augurs-prophet/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,9 @@ pub enum Error {
/// there is no frequency that appears more often than others.
#[error("Unable to infer frequency from dates: {0:?}")]
UnableToInferFrequency(Vec<TimestampSeconds>),
/// The provided horizon was invalid.
///
/// The horizon must be greater than 0.
#[error("Horizon must be > 0, got {0}")]
InvalidHorizon(usize),
}
Loading
Loading