diff --git a/crates/augurs-mstl/Cargo.toml b/crates/augurs-mstl/Cargo.toml index 8944f667..25005fba 100644 --- a/crates/augurs-mstl/Cargo.toml +++ b/crates/augurs-mstl/Cargo.toml @@ -12,7 +12,7 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the augurs-core.workspace = true distrs.workspace = true serde = { workspace = true, features = ["derive"], optional = true } -stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "python-lib", version = "0.2.1" } +stlrs = "0.3.0" thiserror.workspace = true tracing.workspace = true diff --git a/crates/augurs-mstl/benches/vic_elec.rs b/crates/augurs-mstl/benches/vic_elec.rs index 42b8f70e..5689a8d6 100644 --- a/crates/augurs-mstl/benches/vic_elec.rs +++ b/crates/augurs-mstl/benches/vic_elec.rs @@ -16,12 +16,14 @@ fn vic_elec(c: &mut Criterion) { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); c.bench_function("vic_elec", |b| { b.iter_batched( - || (y.clone(), vec![24, 24 * 7], stl_params.clone()), + || (y.clone(), vec![24, 24 * 7], mstl_params.clone()), |(y, periods, stl_params)| { MSTLModel::new(periods, NaiveTrend::new()) - .stl_params(stl_params) + .mstl_params(stl_params) .fit(&y) }, BatchSize::SmallInput, diff --git a/crates/augurs-mstl/benches/vic_elec_iai.rs b/crates/augurs-mstl/benches/vic_elec_iai.rs index 843d465c..5dd3c1aa 100644 --- a/crates/augurs-mstl/benches/vic_elec_iai.rs +++ b/crates/augurs-mstl/benches/vic_elec_iai.rs @@ -3,9 +3,9 @@ use iai::{black_box, main}; use augurs_mstl::{MSTLModel, NaiveTrend}; use augurs_testing::data::VIC_ELEC; -fn vic_elec_fit(y: Vec, periods: Vec, params: stlrs::StlParams) { +fn vic_elec_fit(y: Vec, periods: Vec, params: stlrs::MstlParams) { MSTLModel::new(periods, NaiveTrend::new()) - .stl_params(params) + .mstl_params(params) .fit(&y) .ok(); } @@ -22,10 +22,12 @@ fn bench_vic_elec_fit() { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); vic_elec_fit( black_box(y.clone()), black_box(vec![24, 24 * 7]), - black_box(stl_params), + black_box(mstl_params), ); } diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index d5b8598e..c2347d16 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -8,17 +8,17 @@ use std::marker::PhantomData; +use stlrs::MstlResult; use tracing::instrument; use augurs_core::{Forecast, ForecastIntervals}; // mod approx; -pub mod mstl; +// pub mod mstl; // mod stationarity; mod trend; // mod utils; -use crate::mstl::{MSTLDecomposition, MSTL}; pub use crate::trend::{NaiveTrend, TrendModel}; /// A marker struct indicating that a model is fit. @@ -54,11 +54,11 @@ type Result = std::result::Result; pub struct MSTLModel { /// Periodicity of the seasonal components. periods: Vec, - stl_params: stlrs::StlParams, + mstl_params: stlrs::MstlParams, state: PhantomData, - decomposed: Option, + fit: Option, trend_model: T, } @@ -86,15 +86,18 @@ impl MSTLModel { Self { periods, state: PhantomData, - stl_params: stlrs::params(), - decomposed: None, + mstl_params: stlrs::MstlParams::new(), + fit: None, trend_model, } } - /// Set the parameters for the STL algorithm. - pub fn stl_params(mut self, params: stlrs::StlParams) -> Self { - self.stl_params = params; + /// Set the parameters for the MSTL algorithm. + /// + /// This can be used to control the parameters for the inner STL algorithm + /// by using [`MstlParams::stl_params`]. + pub fn mstl_params(mut self, params: stlrs::MstlParams) -> Self { + self.mstl_params = params; self } @@ -109,17 +112,15 @@ impl MSTLModel { /// are also propagated. #[instrument(skip_all)] pub fn fit(mut self, y: &[f64]) -> Result> { - // Run STL for each season length. - let decomposed = MSTL::new(y, &mut self.periods) - .stl_params(self.stl_params.clone()) - .fit()?; + let y = y.iter().copied().map(|y| y as f32).collect::>(); + let fit = self.mstl_params.fit(&y, &self.periods)?; // Determine the differencing term for the trend component. - let trend = decomposed.trend(); - let residual = decomposed.residuals(); + let trend = fit.trend(); + let residual = fit.remainder(); let deseasonalised = trend .iter() .zip(residual) - .map(|(t, r)| t + r) + .map(|(t, r)| (t + r) as f64) .collect::>(); self.trend_model .fit(&deseasonalised) @@ -130,9 +131,9 @@ impl MSTLModel { ); Ok(MSTLModel { periods: self.periods, - stl_params: self.stl_params, + mstl_params: self.mstl_params, state: PhantomData, - decomposed: Some(decomposed), + fit: Some(fit), trend_model: self.trend_model, }) } @@ -196,35 +197,32 @@ impl MSTLModel { } fn add_seasonal_in_sample(&self, trend: &mut Forecast) { - self.decomposed() - .seasonals() - .values() - .for_each(|component| { - let period_contributions = component.iter().zip(trend.point.iter_mut()); - match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += c), - Some(ForecastIntervals { - ref mut lower, - ref mut upper, - .. - }) => { - period_contributions - .zip(lower.iter_mut()) - .zip(upper.iter_mut()) - .for_each(|(((c, p), l), u)| { - *p += c; - *l += c; - *u += c; - }); - } + self.fit().seasonal().iter().for_each(|component| { + let period_contributions = component.iter().zip(trend.point.iter_mut()); + match &mut trend.intervals { + None => period_contributions.for_each(|(c, p)| *p += *c as f64), + Some(ForecastIntervals { + ref mut lower, + ref mut upper, + .. + }) => { + period_contributions + .zip(lower.iter_mut()) + .zip(upper.iter_mut()) + .for_each(|(((c, p), l), u)| { + *p += *c as f64; + *l += *c as f64; + *u += *c as f64; + }); } - }); + } + }); } fn add_seasonal_out_of_sample(&self, trend: &mut Forecast) { - self.decomposed() - .seasonals() + self.periods .iter() + .zip(self.fit().seasonal()) .for_each(|(period, component)| { // For each seasonal period we're going to create a cycle iterator // which will repeat the seasonal component every `period` steps. @@ -238,7 +236,7 @@ impl MSTLModel { .cycle() .zip(trend.point.iter_mut()); match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += c), + None => period_contributions.for_each(|(c, p)| *p += c as f64), Some(ForecastIntervals { ref mut lower, ref mut upper, @@ -248,18 +246,18 @@ impl MSTLModel { .zip(lower.iter_mut()) .zip(upper.iter_mut()) .for_each(|(((c, p), l), u)| { - *p += c; - *l += c; - *u += c; + *p += c as f64; + *l += c as f64; + *u += c as f64; }); } } }); } - /// Return the MSTL decomposition of the training data. - pub fn decomposed(&self) -> &MSTLDecomposition { - self.decomposed.as_ref().unwrap() + /// Return the MSTL fit of the training data. + pub fn fit(&self) -> &MstlResult { + self.fit.as_ref().unwrap() } } @@ -277,7 +275,7 @@ mod tests { if actual.is_nan() { assert!(expected.is_nan()); } else { - assert_approx_eq!(actual, expected, 1e-2); + assert_approx_eq!(actual, expected, 1e-1); } } } @@ -286,8 +284,8 @@ mod tests { fn results_match_r() { let y = VIC_ELEC.clone(); - let mut params = stlrs::params(); - params + let mut stl_params = stlrs::params(); + stl_params .seasonal_degree(0) .seasonal_jump(1) .trend_degree(1) @@ -295,9 +293,11 @@ mod tests { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); let periods = vec![24, 24 * 7]; let trend_model = NaiveTrend::new(); - let mstl = MSTLModel::new(periods, trend_model).stl_params(params); + let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); let fit = mstl.fit(&y).unwrap(); let in_sample = fit.predict_in_sample(0.95).unwrap(); @@ -344,8 +344,8 @@ mod tests { fn predict_zero_horizon() { let y = VIC_ELEC.clone(); - let mut params = stlrs::params(); - params + let mut stl_params = stlrs::params(); + stl_params .seasonal_degree(0) .seasonal_jump(1) .trend_degree(1) @@ -353,9 +353,11 @@ mod tests { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); let periods = vec![24, 24 * 7]; let trend_model = NaiveTrend::new(); - let mstl = MSTLModel::new(periods, trend_model).stl_params(params); + let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); let fit = mstl.fit(&y).unwrap(); let forecast = fit.predict(0, 0.95).unwrap(); assert!(forecast.point.is_empty()); diff --git a/crates/augurs-mstl/src/mstl.rs b/crates/augurs-mstl/src/mstl.rs index bc8f7aab..da406ec6 100644 --- a/crates/augurs-mstl/src/mstl.rs +++ b/crates/augurs-mstl/src/mstl.rs @@ -25,7 +25,7 @@ use crate::{Error, Result}; #[allow(clippy::upper_case_acronyms)] pub struct MSTL<'a> { /// Time series to decompose. - y: &'a [f64], + y: Vec, /// Periodicity of the seasonal components. periods: &'a mut Vec, /// Parameters for the STL decomposition. @@ -36,9 +36,9 @@ impl<'a> MSTL<'a> { /// Create a new MSTL decomposition. /// /// Call `fit` to run the decomposition. - pub fn new(y: &'a [f64], periods: &'a mut Vec) -> Self { + pub fn new(y: impl Iterator, periods: &'a mut Vec) -> Self { Self { - y, + y: y.collect::>(), periods, stl_params: stlrs::params(), } @@ -57,51 +57,59 @@ impl<'a> MSTL<'a> { let seasonal_windows: Vec = self.seasonal_windows(); let iterate = if self.periods.len() == 1 { 1 } else { 2 }; - let mut seasonals: HashMap> = self - .periods - .iter() - .copied() - .map(|p| (p, vec![0.0; self.y.len()])) - .collect(); - let mut deseas = self.y.to_vec(); - let mut res: Option> = None; + let mut seasonals: HashMap = HashMap::with_capacity(self.periods.len()); + // self.periods.iter().copied().map(|p| (p, None)).collect(); + let mut deseas = self.y; + let mut res: Option = None; for i in 0..iterate { let zipped = self.periods.iter().zip(seasonal_windows.iter()); for (period, seasonal_window) in zipped { - let seas = seasonals.get_mut(period).unwrap(); + let seas = seasonals.entry(*period); // Start by adding on the seasonal effect. - deseas - .iter_mut() - .zip(seas.iter()) - .for_each(|(d, s)| *d += *s); + if let std::collections::hash_map::Entry::Occupied(ref seas) = seas { + deseas + .iter_mut() + .zip(seas.get().seasonal().iter()) + .for_each(|(d, s)| *d += *s); + } // Decompose the time series for specific seasonal period. - let mut fit = tracing::debug_span!("STL.fit", i, seasonal_window, period) - .in_scope(|| { + let fit = + tracing::debug_span!("STL.fit", i, seasonal_window, period).in_scope(|| { self.stl_params .seasonal_length(*seasonal_window) .fit(&deseas, *period) })?; - *seas = std::mem::take(&mut fit.seasonal); - res = Some(fit); // Subtract the seasonal effect again. deseas .iter_mut() - .zip(seas.iter()) + .zip(fit.seasonal().iter()) .for_each(|(d, s)| *d -= *s); + match seas { + std::collections::hash_map::Entry::Occupied(mut o) => { + o.insert(fit.clone()); + } + std::collections::hash_map::Entry::Vacant(x) => { + x.insert(fit.clone()); + } + } + res = Some(fit); } } let fit = res.ok_or_else(|| Error::MSTL("no STL fit".to_string()))?; - let trend = fit.trend; + let trend = fit.trend().to_vec(); deseas .iter_mut() .zip(trend.iter()) .for_each(|(d, r)| *d -= *r); - let rw = fit.weights; + let robust_weights = fit.weights().to_vec(); Ok(MSTLDecomposition { trend, - seasonal: seasonals, + seasonal: seasonals + .into_iter() + .map(|(k, v)| (k, v.seasonal().to_vec())) + .collect(), residuals: deseas, - robust_weights: rw, + robust_weights, }) } @@ -142,39 +150,39 @@ impl<'a> MSTL<'a> { #[cfg_attr(test, derive(Default))] pub struct MSTLDecomposition { /// Trend component. - trend: Vec, + trend: Vec, /// Mapping from period to seasonal component. - seasonal: HashMap>, + seasonal: HashMap>, /// Residuals. - residuals: Vec, + residuals: Vec, /// Weights used in the robust fit. - robust_weights: Vec, + robust_weights: Vec, } impl MSTLDecomposition { /// Return the trend component. - pub fn trend(&self) -> &[f64] { + pub fn trend(&self) -> &[f32] { &self.trend } /// Return the seasonal component for a given period, /// or None if the period is not present. - pub fn seasonal(&self, period: usize) -> Option<&[f64]> { + pub fn seasonal(&self, period: usize) -> Option<&[f32]> { self.seasonal.get(&period).map(|v| v.as_slice()) } /// Return a mapping from period to seasonal component. - pub fn seasonals(&self) -> &HashMap> { + pub fn seasonals(&self) -> &HashMap> { &self.seasonal } /// Return the residuals. - pub fn residuals(&self) -> &[f64] { + pub fn residuals(&self) -> &[f32] { &self.residuals } /// Return the robust weights. - pub fn robust_weights(&self) -> &[f64] { + pub fn robust_weights(&self) -> &[f32] { &self.robust_weights } } @@ -224,29 +232,29 @@ mod tests { .inner_loops(2) .outer_loops(0); let mut periods = vec![24, 24 * 7]; - let mstl = MSTL::new(y, &mut periods).stl_params(params); + let mstl = MSTL::new(y.iter().map(|&x| x as f32), &mut periods).stl_params(params); let res = mstl.fit().unwrap(); let expected = vic_elec_results(); res.trend() .iter() .zip(expected.trend().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-2_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1.0)); res.seasonal(24) .unwrap() .iter() .zip(expected.seasonal(24).unwrap().iter()) // Some numeric instability somewhere causes this to differ by // up to 1.0 somewhere :/ - .for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f64)); + .for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f32)); res.seasonal(168) .unwrap() .iter() .zip(expected.seasonal(168).unwrap().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f32)); res.residuals() .iter() .zip(expected.residuals().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f32)); } }