From 510654ca92037443c66b43e1211fad1f53512331 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 11 Nov 2024 11:45:01 +0000 Subject: [PATCH 1/2] feat: add Forecaster wrapper for Prophet This isn't ideal because the APIs don't match at all - the forecaster traits don't pass timestamps or anything else down. We could modify them to use an associated type for the data? Not sure how that would affect dyn compatibility (if it exists). --- crates/augurs-prophet/Cargo.toml | 1 + crates/augurs-prophet/src/forecaster.rs | 149 ++++++++++++++++++++++++ crates/augurs-prophet/src/lib.rs | 1 + crates/augurs-prophet/src/optimizer.rs | 13 ++- crates/augurs-prophet/src/prophet.rs | 13 ++- 5 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 crates/augurs-prophet/src/forecaster.rs diff --git a/crates/augurs-prophet/Cargo.toml b/crates/augurs-prophet/Cargo.toml index f8159d84..f8a128c9 100644 --- a/crates/augurs-prophet/Cargo.toml +++ b/crates/augurs-prophet/Cargo.toml @@ -28,6 +28,7 @@ include = [ [dependencies] anyhow.workspace = true +augurs-core.workspace = true bytemuck = { workspace = true, features = ["derive"], optional = true } itertools.workspace = true num-traits.workspace = true diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs new file mode 100644 index 00000000..656b65b8 --- /dev/null +++ b/crates/augurs-prophet/src/forecaster.rs @@ -0,0 +1,149 @@ +//! [`Fit`] and [`Predict`] implementations for the Prophet algorithm. +use std::{cell::RefCell, num::NonZeroU32, sync::Arc}; + +use augurs_core::{Fit, ModelError, Predict}; + +use crate::{ + optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, ProphetOptions, + TrainingData, +}; + +impl ModelError for Error {} + +/// A forecaster that uses the Prophet algorithm. +/// +/// This is a wrapper around the [`Prophet`] struct that provides +/// a simpler API for fitting and predicting. Notably it implements +/// the [`Fit`] trait from `augurs_core`, so it can be +/// used with the `augurs` framework (e.g. with the `Forecaster` struct +/// in the `augurs::forecaster` module). +#[derive(Debug)] +pub struct ProphetForecaster { + opts: ProphetOptions, + optimizer: Arc, + optimize_opts: OptimizeOpts, +} + +impl ProphetForecaster { + /// Create a new Prophet forecaster. + /// + /// # Parameters + /// + /// - `opts`: The options to use for fitting the model. + /// - `optimizer`: The optimizer to use for fitting the model. + /// - `optimize_opts`: The options to use for optimizing the model. + pub fn new( + mut opts: ProphetOptions, + optimizer: Arc, + optimize_opts: OptimizeOpts, + ) -> Self { + if opts.uncertainty_samples == 0 { + opts.uncertainty_samples = 1000; + } + Self { + opts, + optimizer, + optimize_opts, + } + } +} + +impl Fit for ProphetForecaster { + type Fitted = FittedProphetForecaster; + type Error = Error; + + fn fit(&self, y: &[f64]) -> Result { + let ds = vec![]; + let training_data = TrainingData::new(ds, y.to_vec())?; + let mut model = Prophet::new(self.opts.clone(), self.optimizer.clone()); + model.fit(training_data, self.optimize_opts.clone())?; + Ok(FittedProphetForecaster { + model: RefCell::new(model), + training_n: y.len(), + }) + } +} + +/// A fitted Prophet forecaster. +#[derive(Debug)] +pub struct FittedProphetForecaster { + model: RefCell>>, + training_n: usize, +} + +impl Predict for FittedProphetForecaster { + type Error = Error; + + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + if let Some(level) = level { + self.model + .borrow_mut() + .set_interval_width(level.try_into()?); + } + let predictions = self.model.borrow().predict(None)?; + forecast.point = predictions.yhat.point; + if let Some(intervals) = forecast.intervals.as_mut() { + intervals.lower = predictions + .yhat + .lower + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + intervals.upper = predictions + .yhat + .upper + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + } + Ok(()) + } + + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + if horizon == 0 { + return Ok(()); + } + if let Some(level) = level { + self.model + .borrow_mut() + .set_interval_width(level.try_into()?); + } + let predictions = { + let model = self.model.borrow(); + let prediction_data = model.make_future_dataframe( + NonZeroU32::try_from(horizon as u32).expect("horizon should be > 0"), + IncludeHistory::No, + )?; + model.predict(prediction_data)? + }; + forecast.point = predictions.yhat.point; + if let Some(intervals) = forecast.intervals.as_mut() { + intervals.lower = predictions + .yhat + .lower + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + intervals.upper = predictions + .yhat + .upper + // This `expect` is OK because we've set uncertainty_samples > 0 in the + // `ProphetForecaster` constructor. + .expect("uncertainty_samples should be > 0"); + } + Ok(()) + } + + fn training_data_size(&self) -> usize { + self.training_n + } +} diff --git a/crates/augurs-prophet/src/lib.rs b/crates/augurs-prophet/src/lib.rs index 457bb2ac..1c085172 100644 --- a/crates/augurs-prophet/src/lib.rs +++ b/crates/augurs-prophet/src/lib.rs @@ -2,6 +2,7 @@ mod data; mod error; mod features; +pub mod forecaster; // Export the optimizer module so that users can implement their own // optimizers. pub mod optimizer; diff --git a/crates/augurs-prophet/src/optimizer.rs b/crates/augurs-prophet/src/optimizer.rs index 17268be1..99899a0c 100644 --- a/crates/augurs-prophet/src/optimizer.rs +++ b/crates/augurs-prophet/src/optimizer.rs @@ -21,7 +21,7 @@ // WASM Components? // TODO: write a pure Rust optimizer for the default case. -use std::fmt; +use std::{fmt, sync::Arc}; use crate::positive_float::PositiveFloat; @@ -299,6 +299,17 @@ pub trait Optimizer: std::fmt::Debug { ) -> Result; } +impl Optimizer for Arc { + fn optimize( + &self, + init: &InitialParams, + data: &Data, + opts: &OptimizeOpts, + ) -> Result { + (**self).optimize(init, data, opts) + } +} + #[cfg(test)] pub(crate) mod mock_optimizer { use std::cell::RefCell; diff --git a/crates/augurs-prophet/src/prophet.rs b/crates/augurs-prophet/src/prophet.rs index 9cb5ff64..954f49a2 100644 --- a/crates/augurs-prophet/src/prophet.rs +++ b/crates/augurs-prophet/src/prophet.rs @@ -13,8 +13,8 @@ use prep::{ComponentColumns, Modes, Preprocessed, Scales}; use crate::{ optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer}, - Error, EstimationMode, FeaturePrediction, IncludeHistory, PredictionData, Predictions, - Regressor, Seasonality, TimestampSeconds, TrainingData, + Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData, + Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData, }; /// The Prophet time series forecasting model. @@ -233,6 +233,15 @@ impl Prophet { Ok(PredictionData::new(ds)) } + /// Set the width of the uncertainty intervals. + /// + /// The interval width does not affect training, only predictions, + /// so this can be called after fitting the model to obtain predictions + /// with different levels of uncertainty. + pub fn set_interval_width(&mut self, interval_width: IntervalWidth) { + self.opts.interval_width = interval_width; + } + fn infer_freq(history_dates: &[TimestampSeconds]) -> Result { const INFER_N: usize = 5; let get_tried = || { From 7e4eee60d684096275a4ae2f56d7fea0bcdeaa19 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Tue, 3 Dec 2024 11:31:56 +0000 Subject: [PATCH 2/2] Rejig traits so Prophet forecaster actually works I'm not a huge fan of the API here, may rethink it before merging. --- crates/augurs-core/src/lib.rs | 2 +- crates/augurs-core/src/traits.rs | 116 +++++++++++++++++++-- crates/augurs-ets/src/auto.rs | 6 +- crates/augurs-ets/src/trend.rs | 2 +- crates/augurs-forecaster/src/data.rs | 35 ------- crates/augurs-forecaster/src/forecaster.rs | 16 +-- crates/augurs-forecaster/src/lib.rs | 2 - crates/augurs-mstl/src/lib.rs | 18 ++-- crates/augurs-mstl/src/trend.rs | 6 +- crates/augurs-prophet/src/error.rs | 5 + crates/augurs-prophet/src/forecaster.rs | 26 +++-- crates/pyaugurs/src/ets.rs | 6 +- crates/pyaugurs/src/mstl.rs | 6 +- crates/pyaugurs/src/trend.rs | 2 +- js/augurs-ets-js/src/lib.rs | 2 +- js/augurs-mstl-js/src/lib.rs | 2 +- 16 files changed, 168 insertions(+), 84 deletions(-) delete mode 100644 crates/augurs-forecaster/src/data.rs diff --git a/crates/augurs-core/src/lib.rs b/crates/augurs-core/src/lib.rs index 63e74988..516400f7 100644 --- a/crates/augurs-core/src/lib.rs +++ b/crates/augurs-core/src/lib.rs @@ -15,7 +15,7 @@ use std::convert::Infallible; pub use distance::DistanceMatrix; pub use forecast::{Forecast, ForecastIntervals}; -pub use traits::{Fit, Predict}; +pub use traits::{Data, Fit, MutableData, Predict}; /// An error produced by a time series forecasting model. pub trait ModelError: std::error::Error + Sync + Send + 'static {} diff --git a/crates/augurs-core/src/traits.rs b/crates/augurs-core/src/traits.rs index af266385..7c6c254e 100644 --- a/crates/augurs-core/src/traits.rs +++ b/crates/augurs-core/src/traits.rs @@ -1,7 +1,99 @@ use crate::{Forecast, ModelError}; +/// Trait for data that can be used as an input to [`Fit`]. +/// +/// This trait is implemented for a number of types including slices, arrays, and +/// vectors. It is also implemented for references to these types. +pub trait Data { + /// Return the data as a slice of `f64`. + fn as_slice(&self) -> &[f64]; +} + +impl Data for [f64; N] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for &[f64] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for &mut [f64] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for Vec { + fn as_slice(&self) -> &[f64] { + self.as_slice() + } +} + +impl Data for &T +where + T: Data, +{ + fn as_slice(&self) -> &[f64] { + (**self).as_slice() + } +} + +impl Data for &mut T +where + T: Data, +{ + fn as_slice(&self) -> &[f64] { + (**self).as_slice() + } +} + +/// Trait for data that can be used in the forecaster. +/// +/// This trait is implemented for a number of types including slices, arrays, and +/// vectors. It is also implemented for references to these types. +pub trait MutableData: Data { + /// Update the `y` values to those in the provided slice. + fn set(&mut self, y: Vec); +} + +impl MutableData for [f64; N] { + fn set(&mut self, y: Vec) { + self.copy_from_slice(y.as_slice()); + } +} + +impl MutableData for &mut [f64] { + fn set(&mut self, y: Vec) { + self.copy_from_slice(y.as_slice()); + } +} + +impl MutableData for Vec { + fn set(&mut self, y: Vec) { + self.copy_from_slice(y.as_slice()); + } +} + +impl MutableData for &mut T +where + T: MutableData, +{ + fn set(&mut self, y: Vec) { + (**self).set(y); + } +} + /// A new, unfitted time series forecasting model. pub trait Fit { + /// The type of the training data used to fit the model. + type TrainingData<'a>: Data + where + Self: 'a; + /// The type of the fitted model produced by the `fit` method. type Fitted: Predict; @@ -9,19 +101,21 @@ pub trait Fit { type Error: ModelError; /// Fit the model to the training data. - fn fit(&self, y: &[f64]) -> Result; + fn fit<'a, 'b: 'a>(&'b self, y: Self::TrainingData<'a>) -> Result; } -impl Fit for Box -where - F: Fit, -{ - type Fitted = F::Fitted; - type Error = F::Error; - fn fit(&self, y: &[f64]) -> Result { - (**self).fit(y) - } -} +// impl<'a, F, TD> Fit for Box +// where +// F: Fit = TD>, +// TD: Data, +// { +// type TrainingData = TD; +// type Fitted = F::Fitted; +// type Error = F::Error; +// fn fit(&self, y: Self::TrainingData<'a>) -> Result { +// (**self).fit(y) +// } +// } /// A fitted time series forecasting model. pub trait Predict { diff --git a/crates/augurs-ets/src/auto.rs b/crates/augurs-ets/src/auto.rs index 81bd9a84..ac55fdb0 100644 --- a/crates/augurs-ets/src/auto.rs +++ b/crates/augurs-ets/src/auto.rs @@ -459,8 +459,10 @@ impl AutoETS { } impl Fit for AutoETS { + type TrainingData<'a> = &'a [f64]; type Fitted = FittedAutoETS; type Error = Error; + /// Search for the best model, fitting it to the data. /// /// The model is stored on the `AutoETS` struct and can be retrieved with @@ -470,7 +472,7 @@ impl Fit for AutoETS { /// /// If no model can be found, or if any parameters are invalid, this function /// returns an error. - fn fit(&self, y: &[f64]) -> Result { + fn fit<'a, 'b: 'a>(&'b self, y: Self::TrainingData<'a>) -> Result { let data_positive = y.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0; if self.spec.error == ErrorSpec::Multiplicative && !data_positive { return Err(Error::InvalidModelSpec(format!( @@ -627,7 +629,7 @@ mod test { #[test] fn air_passengers_fit() { let auto = AutoETS::new(1, "ZZN").unwrap(); - let fit = auto.fit(AIR_PASSENGERS).expect("fit failed"); + let fit = auto.fit(&mut AIR_PASSENGERS.to_vec()).expect("fit failed"); assert_eq!(fit.model.model_type().error, ErrorComponent::Multiplicative); assert_eq!(fit.model.model_type().trend, TrendComponent::Additive); assert_eq!(fit.model.model_type().season, SeasonalComponent::None); diff --git a/crates/augurs-ets/src/trend.rs b/crates/augurs-ets/src/trend.rs index 06d6498e..a9583b1b 100644 --- a/crates/augurs-ets/src/trend.rs +++ b/crates/augurs-ets/src/trend.rs @@ -33,7 +33,7 @@ impl TrendModel for AutoETSTrendModel { fn fit( &self, - y: &[f64], + y: &mut [f64], ) -> Result< Box, Box, diff --git a/crates/augurs-forecaster/src/data.rs b/crates/augurs-forecaster/src/data.rs deleted file mode 100644 index 91d82899..00000000 --- a/crates/augurs-forecaster/src/data.rs +++ /dev/null @@ -1,35 +0,0 @@ -/// Trait for data that can be used in the forecaster. -/// -/// This trait is implemented for a number of types including slices, arrays, and -/// vectors. It is also implemented for references to these types. -pub trait Data { - /// Return the data as a slice of `f64`. - fn as_slice(&self) -> &[f64]; -} - -impl Data for [f64; N] { - fn as_slice(&self) -> &[f64] { - self - } -} - -impl Data for &[f64] { - fn as_slice(&self) -> &[f64] { - self - } -} - -impl Data for Vec { - fn as_slice(&self) -> &[f64] { - self.as_slice() - } -} - -impl Data for &T -where - T: Data, -{ - fn as_slice(&self) -> &[f64] { - (*self).as_slice() - } -} diff --git a/crates/augurs-forecaster/src/forecaster.rs b/crates/augurs-forecaster/src/forecaster.rs index 6c27b0f3..3823a759 100644 --- a/crates/augurs-forecaster/src/forecaster.rs +++ b/crates/augurs-forecaster/src/forecaster.rs @@ -1,6 +1,6 @@ -use augurs_core::{Fit, Forecast, Predict}; +use augurs_core::{Data, Fit, Forecast, MutableData, Predict}; -use crate::{Data, Error, Result, Transform, Transforms}; +use crate::{Error, Result, Transform, Transforms}; /// A high-level API to fit and predict time series forecasting models. /// @@ -16,10 +16,11 @@ pub struct Forecaster { transforms: Transforms, } -impl Forecaster +impl<'a, M: 'a> Forecaster where M: Fit, M::Fitted: Predict, + M::TrainingData<'a>: MutableData, { /// Create a new `Forecaster` with the given model. pub fn new(model: M) -> Self { @@ -37,12 +38,13 @@ where } /// Fit the model to the given time series. - pub fn fit(&mut self, y: D) -> Result<()> { + pub fn fit<'b: 'a>(&'b mut self, mut td: M::TrainingData<'a>) -> Result<()> { let data: Vec<_> = self .transforms - .transform(y.as_slice().iter().copied()) + .transform(td.as_slice().iter().copied()) .collect(); - self.fitted = Some(self.model.fit(&data).map_err(|e| Error::Fit { + td.set(data); + self.fitted = Some(self.model.fit(td).map_err(|e| Error::Fit { source: Box::new(e) as _, })?); Ok(()) @@ -105,7 +107,7 @@ mod test { #[test] fn test_forecaster() { - let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0]; + let data = &mut [1.0_f64, 2.0, 3.0, 4.0, 5.0]; let MinMaxResult::MinMax(min, max) = data .iter() .copied() diff --git a/crates/augurs-forecaster/src/lib.rs b/crates/augurs-forecaster/src/lib.rs index d8c2c689..897064cc 100644 --- a/crates/augurs-forecaster/src/lib.rs +++ b/crates/augurs-forecaster/src/lib.rs @@ -1,11 +1,9 @@ #![doc = include_str!("../README.md")] -mod data; mod error; mod forecaster; pub mod transforms; -pub use data::Data; pub use error::Error; pub use forecaster::Forecaster; pub use transforms::Transform; diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index 19f1d07a..1711c147 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -105,14 +105,14 @@ impl MSTLModel { // Determine the differencing term for the trend component. let trend = fit.trend(); let residual = fit.remainder(); - let deseasonalised = trend + let mut deseasonalised = trend .iter() .zip(residual) .map(|(t, r)| (t + r) as f64) .collect::>(); let fitted_trend_model = self .trend_model - .fit(&deseasonalised) + .fit(&mut deseasonalised) .map_err(Error::TrendModel)?; tracing::trace!( trend_model = ?self.trend_model, @@ -234,9 +234,13 @@ impl FittedMSTLModel { impl ModelError for Error {} impl augurs_core::Fit for MSTLModel { + type TrainingData<'a> = &'a mut [f64] where T: 'a; type Fitted = FittedMSTLModel; type Error = Error; - fn fit(&self, y: &[f64]) -> Result { + fn fit<'a, 'b: 'a>(&self, y: Self::TrainingData<'a>) -> Result + where + T: 'a, + { self.fit_impl(y) } } @@ -271,7 +275,7 @@ mod tests { #[test] fn results_match_r() { - let y = VIC_ELEC.clone(); + let mut y = VIC_ELEC.clone(); let mut stl_params = stlrs::params(); stl_params @@ -287,7 +291,7 @@ mod tests { let periods = vec![24, 24 * 7]; let trend_model = NaiveTrend::new(); let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); - let fit = mstl.fit(&y).unwrap(); + let fit = mstl.fit(&mut y).unwrap(); let in_sample = fit.predict_in_sample(0.95).unwrap(); // The first 12 values from R. @@ -332,7 +336,7 @@ mod tests { #[test] fn predict_zero_horizon() { - let y = VIC_ELEC.clone(); + let mut y = VIC_ELEC.clone(); let mut stl_params = stlrs::params(); stl_params @@ -348,7 +352,7 @@ mod tests { let periods = vec![24, 24 * 7]; let trend_model = NaiveTrend::new(); let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); - let fit = mstl.fit(&y).unwrap(); + let fit = mstl.fit(&mut y).unwrap(); let forecast = fit.predict(0, 0.95).unwrap(); assert!(forecast.point.is_empty()); let ForecastIntervals { lower, upper, .. } = forecast.intervals.unwrap(); diff --git a/crates/augurs-mstl/src/trend.rs b/crates/augurs-mstl/src/trend.rs index 24afb667..cb9a4ff5 100644 --- a/crates/augurs-mstl/src/trend.rs +++ b/crates/augurs-mstl/src/trend.rs @@ -28,7 +28,7 @@ pub trait TrendModel: Debug { /// Implementations should store any state required for prediction in the struct itself. fn fit( &self, - y: &[f64], + y: &mut [f64], ) -> Result< Box, Box, @@ -121,7 +121,7 @@ impl TrendModel for Box { fn fit( &self, - y: &[f64], + y: &mut [f64], ) -> Result< Box, Box, @@ -196,7 +196,7 @@ impl TrendModel for NaiveTrend { fn fit( &self, - y: &[f64], + y: &mut [f64], ) -> Result< Box, Box, diff --git a/crates/augurs-prophet/src/error.rs b/crates/augurs-prophet/src/error.rs index e6599a33..8ebec7d4 100644 --- a/crates/augurs-prophet/src/error.rs +++ b/crates/augurs-prophet/src/error.rs @@ -82,4 +82,9 @@ pub enum Error { /// there is no frequency that appears more often than others. #[error("Unable to infer frequency from dates: {0:?}")] UnableToInferFrequency(Vec), + /// The provided horizon was invalid. + /// + /// The horizon must be greater than 0. + #[error("Horizon must be > 0, got {0}")] + InvalidHorizon(usize), } diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs index 656b65b8..26b681aa 100644 --- a/crates/augurs-prophet/src/forecaster.rs +++ b/crates/augurs-prophet/src/forecaster.rs @@ -1,7 +1,7 @@ //! [`Fit`] and [`Predict`] implementations for the Prophet algorithm. use std::{cell::RefCell, num::NonZeroU32, sync::Arc}; -use augurs_core::{Fit, ModelError, Predict}; +use augurs_core::{Data, Fit, ModelError, MutableData, Predict}; use crate::{ optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, ProphetOptions, @@ -48,18 +48,32 @@ impl ProphetForecaster { } } +impl Data for TrainingData { + fn as_slice(&self) -> &[f64] { + self.y.as_slice() + } +} +impl MutableData for TrainingData { + fn set(&mut self, y: Vec) { + self.y = y; + } +} + impl Fit for ProphetForecaster { + type TrainingData<'a> = TrainingData; type Fitted = FittedProphetForecaster; type Error = Error; - fn fit(&self, y: &[f64]) -> Result { - let ds = vec![]; - let training_data = TrainingData::new(ds, y.to_vec())?; + fn fit<'a, 'b: 'a>( + &'b self, + training_data: Self::TrainingData<'b>, + ) -> Result { + let training_n = training_data.y.len(); let mut model = Prophet::new(self.opts.clone(), self.optimizer.clone()); model.fit(training_data, self.optimize_opts.clone())?; Ok(FittedProphetForecaster { model: RefCell::new(model), - training_n: y.len(), + training_n, }) } } @@ -120,7 +134,7 @@ impl Predict for FittedProphetForecaster { let predictions = { let model = self.model.borrow(); let prediction_data = model.make_future_dataframe( - NonZeroU32::try_from(horizon as u32).expect("horizon should be > 0"), + NonZeroU32::try_from(horizon as u32).map_err(|_| Error::InvalidHorizon(horizon))?, IncludeHistory::No, )?; model.predict(prediction_data)? diff --git a/crates/pyaugurs/src/ets.rs b/crates/pyaugurs/src/ets.rs index cc86a20d..fc23055d 100644 --- a/crates/pyaugurs/src/ets.rs +++ b/crates/pyaugurs/src/ets.rs @@ -1,6 +1,6 @@ //! Bindings for AutoETS model search. use augurs_core::{Fit, Predict}; -use numpy::PyReadonlyArrayDyn; +use numpy::PyReadwriteArrayDyn; use pyo3::{exceptions::PyException, prelude::*}; use crate::Forecast; @@ -48,9 +48,9 @@ impl AutoETS { /// /// If no model can be found, or if any parameters are invalid, this function /// returns an error. - pub fn fit(&mut self, y: PyReadonlyArrayDyn<'_, f64>) -> PyResult<()> { + pub fn fit(&mut self, mut y: PyReadwriteArrayDyn<'_, f64>) -> PyResult<()> { self.inner - .fit(y.as_slice()?) + .fit(y.as_slice_mut()?) .map_err(|e| PyException::new_err(e.to_string()))?; Ok(()) } diff --git a/crates/pyaugurs/src/mstl.rs b/crates/pyaugurs/src/mstl.rs index 37a132d2..1f42cd51 100644 --- a/crates/pyaugurs/src/mstl.rs +++ b/crates/pyaugurs/src/mstl.rs @@ -1,6 +1,6 @@ //! Bindings for Multiple Seasonal Trend using LOESS (MSTL). -use numpy::PyReadonlyArray1; +use numpy::PyReadwriteArray1; use pyo3::{exceptions::PyException, prelude::*, types::PyType}; use augurs_ets::{trend::AutoETSTrendModel, AutoETS}; @@ -74,9 +74,9 @@ impl MSTL { } /// Fit the model to the given time series. - pub fn fit(&mut self, y: PyReadonlyArray1<'_, f64>) -> PyResult<()> { + pub fn fit(&mut self, mut y: PyReadwriteArray1<'_, f64>) -> PyResult<()> { self.forecaster - .fit(y.as_slice()?) + .fit(y.as_slice_mut()?) .map_err(|e| PyException::new_err(format!("error fitting model: {e}")))?; self.fit = true; Ok(()) diff --git a/crates/pyaugurs/src/trend.rs b/crates/pyaugurs/src/trend.rs index e73aa052..f08e4003 100644 --- a/crates/pyaugurs/src/trend.rs +++ b/crates/pyaugurs/src/trend.rs @@ -64,7 +64,7 @@ impl TrendModel for PyTrendModel { fn fit( &self, - y: &[f64], + y: &mut [f64], ) -> Result< Box, Box, diff --git a/js/augurs-ets-js/src/lib.rs b/js/augurs-ets-js/src/lib.rs index 4b227820..47e4dc5d 100644 --- a/js/augurs-ets-js/src/lib.rs +++ b/js/augurs-ets-js/src/lib.rs @@ -42,7 +42,7 @@ impl AutoETS { /// returns an error. #[wasm_bindgen] pub fn fit(&mut self, y: VecF64) -> Result<(), JsError> { - self.fitted = Some(self.inner.fit(&y.convert()?)?); + self.fitted = Some(self.inner.fit(&mut y.convert()?)?); Ok(()) } diff --git a/js/augurs-mstl-js/src/lib.rs b/js/augurs-mstl-js/src/lib.rs index 59590f6f..2c0f652d 100644 --- a/js/augurs-mstl-js/src/lib.rs +++ b/js/augurs-mstl-js/src/lib.rs @@ -56,7 +56,7 @@ impl MSTL { #[wasm_bindgen] pub fn fit(&mut self, y: VecF64) -> Result<(), JsValue> { self.forecaster - .fit(y.convert()?) + .fit(&mut y.convert()?) .map_err(|e| e.to_string())?; Ok(()) }