From a0c4badb29198f25103b929b94e94d601a6cd6d8 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 18 Sep 2023 17:57:57 +0100 Subject: [PATCH 1/3] Accommodate restrictions of released stl-rs This mainly just means we have to clone a little more and cast things between f32 and f64, because the released version of stl-rs doesn't allow us to take ownership of the various components and doesn't include f64 compatibility. Note that this still currently points towards the main git branch for stl-rust, but at least all the incorporated changes are more likely to be merged sooner or later! Unfortunately this results in some benchmark regressions, probably due to the extra clones and having to convert to/from f32 :( ``` Running benches/vic_elec.rs (target/release/deps/vic_elec-36a0fad3091799a8) vic_elec time: [28.569 ms 28.591 ms 28.619 ms] change: [+11.550% +11.750% +11.918%] (p = 0.00 < 0.05) Performance has regressed. Found 8 outliers among 100 measurements (8.00%) 3 (3.00%) high mild 5 (5.00%) high severe ``` --- crates/augurs-mstl/Cargo.toml | 3 +- crates/augurs-mstl/src/lib.rs | 22 ++++----- crates/augurs-mstl/src/mstl.rs | 86 +++++++++++++++++++--------------- 3 files changed, 60 insertions(+), 51 deletions(-) diff --git a/crates/augurs-mstl/Cargo.toml b/crates/augurs-mstl/Cargo.toml index 8944f667..456ab380 100644 --- a/crates/augurs-mstl/Cargo.toml +++ b/crates/augurs-mstl/Cargo.toml @@ -12,7 +12,8 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the augurs-core.workspace = true distrs.workspace = true serde = { workspace = true, features = ["derive"], optional = true } -stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "python-lib", version = "0.2.1" } +stlrs = { git = "https://github.com/ankane/stl-rust", version = "0.2.2" } +# stlrs = "0.2.2" thiserror.workspace = true tracing.workspace = true diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index d5b8598e..e0525fb2 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -110,7 +110,7 @@ impl MSTLModel { #[instrument(skip_all)] pub fn fit(mut self, y: &[f64]) -> Result> { // Run STL for each season length. - let decomposed = MSTL::new(y, &mut self.periods) + let decomposed = MSTL::new(y.iter().map(|&x| x as f32), &mut self.periods) .stl_params(self.stl_params.clone()) .fit()?; // Determine the differencing term for the trend component. @@ -119,7 +119,7 @@ impl MSTLModel { let deseasonalised = trend .iter() .zip(residual) - .map(|(t, r)| t + r) + .map(|(t, r)| (t + r) as f64) .collect::>(); self.trend_model .fit(&deseasonalised) @@ -202,7 +202,7 @@ impl MSTLModel { .for_each(|component| { let period_contributions = component.iter().zip(trend.point.iter_mut()); match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += c), + None => period_contributions.for_each(|(c, p)| *p += *c as f64), Some(ForecastIntervals { ref mut lower, ref mut upper, @@ -212,9 +212,9 @@ impl MSTLModel { .zip(lower.iter_mut()) .zip(upper.iter_mut()) .for_each(|(((c, p), l), u)| { - *p += c; - *l += c; - *u += c; + *p += *c as f64; + *l += *c as f64; + *u += *c as f64; }); } } @@ -238,7 +238,7 @@ impl MSTLModel { .cycle() .zip(trend.point.iter_mut()); match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += c), + None => period_contributions.for_each(|(c, p)| *p += c as f64), Some(ForecastIntervals { ref mut lower, ref mut upper, @@ -248,9 +248,9 @@ impl MSTLModel { .zip(lower.iter_mut()) .zip(upper.iter_mut()) .for_each(|(((c, p), l), u)| { - *p += c; - *l += c; - *u += c; + *p += c as f64; + *l += c as f64; + *u += c as f64; }); } } @@ -277,7 +277,7 @@ mod tests { if actual.is_nan() { assert!(expected.is_nan()); } else { - assert_approx_eq!(actual, expected, 1e-2); + assert_approx_eq!(actual, expected, 1e-1); } } } diff --git a/crates/augurs-mstl/src/mstl.rs b/crates/augurs-mstl/src/mstl.rs index bc8f7aab..da406ec6 100644 --- a/crates/augurs-mstl/src/mstl.rs +++ b/crates/augurs-mstl/src/mstl.rs @@ -25,7 +25,7 @@ use crate::{Error, Result}; #[allow(clippy::upper_case_acronyms)] pub struct MSTL<'a> { /// Time series to decompose. - y: &'a [f64], + y: Vec, /// Periodicity of the seasonal components. periods: &'a mut Vec, /// Parameters for the STL decomposition. @@ -36,9 +36,9 @@ impl<'a> MSTL<'a> { /// Create a new MSTL decomposition. /// /// Call `fit` to run the decomposition. - pub fn new(y: &'a [f64], periods: &'a mut Vec) -> Self { + pub fn new(y: impl Iterator, periods: &'a mut Vec) -> Self { Self { - y, + y: y.collect::>(), periods, stl_params: stlrs::params(), } @@ -57,51 +57,59 @@ impl<'a> MSTL<'a> { let seasonal_windows: Vec = self.seasonal_windows(); let iterate = if self.periods.len() == 1 { 1 } else { 2 }; - let mut seasonals: HashMap> = self - .periods - .iter() - .copied() - .map(|p| (p, vec![0.0; self.y.len()])) - .collect(); - let mut deseas = self.y.to_vec(); - let mut res: Option> = None; + let mut seasonals: HashMap = HashMap::with_capacity(self.periods.len()); + // self.periods.iter().copied().map(|p| (p, None)).collect(); + let mut deseas = self.y; + let mut res: Option = None; for i in 0..iterate { let zipped = self.periods.iter().zip(seasonal_windows.iter()); for (period, seasonal_window) in zipped { - let seas = seasonals.get_mut(period).unwrap(); + let seas = seasonals.entry(*period); // Start by adding on the seasonal effect. - deseas - .iter_mut() - .zip(seas.iter()) - .for_each(|(d, s)| *d += *s); + if let std::collections::hash_map::Entry::Occupied(ref seas) = seas { + deseas + .iter_mut() + .zip(seas.get().seasonal().iter()) + .for_each(|(d, s)| *d += *s); + } // Decompose the time series for specific seasonal period. - let mut fit = tracing::debug_span!("STL.fit", i, seasonal_window, period) - .in_scope(|| { + let fit = + tracing::debug_span!("STL.fit", i, seasonal_window, period).in_scope(|| { self.stl_params .seasonal_length(*seasonal_window) .fit(&deseas, *period) })?; - *seas = std::mem::take(&mut fit.seasonal); - res = Some(fit); // Subtract the seasonal effect again. deseas .iter_mut() - .zip(seas.iter()) + .zip(fit.seasonal().iter()) .for_each(|(d, s)| *d -= *s); + match seas { + std::collections::hash_map::Entry::Occupied(mut o) => { + o.insert(fit.clone()); + } + std::collections::hash_map::Entry::Vacant(x) => { + x.insert(fit.clone()); + } + } + res = Some(fit); } } let fit = res.ok_or_else(|| Error::MSTL("no STL fit".to_string()))?; - let trend = fit.trend; + let trend = fit.trend().to_vec(); deseas .iter_mut() .zip(trend.iter()) .for_each(|(d, r)| *d -= *r); - let rw = fit.weights; + let robust_weights = fit.weights().to_vec(); Ok(MSTLDecomposition { trend, - seasonal: seasonals, + seasonal: seasonals + .into_iter() + .map(|(k, v)| (k, v.seasonal().to_vec())) + .collect(), residuals: deseas, - robust_weights: rw, + robust_weights, }) } @@ -142,39 +150,39 @@ impl<'a> MSTL<'a> { #[cfg_attr(test, derive(Default))] pub struct MSTLDecomposition { /// Trend component. - trend: Vec, + trend: Vec, /// Mapping from period to seasonal component. - seasonal: HashMap>, + seasonal: HashMap>, /// Residuals. - residuals: Vec, + residuals: Vec, /// Weights used in the robust fit. - robust_weights: Vec, + robust_weights: Vec, } impl MSTLDecomposition { /// Return the trend component. - pub fn trend(&self) -> &[f64] { + pub fn trend(&self) -> &[f32] { &self.trend } /// Return the seasonal component for a given period, /// or None if the period is not present. - pub fn seasonal(&self, period: usize) -> Option<&[f64]> { + pub fn seasonal(&self, period: usize) -> Option<&[f32]> { self.seasonal.get(&period).map(|v| v.as_slice()) } /// Return a mapping from period to seasonal component. - pub fn seasonals(&self) -> &HashMap> { + pub fn seasonals(&self) -> &HashMap> { &self.seasonal } /// Return the residuals. - pub fn residuals(&self) -> &[f64] { + pub fn residuals(&self) -> &[f32] { &self.residuals } /// Return the robust weights. - pub fn robust_weights(&self) -> &[f64] { + pub fn robust_weights(&self) -> &[f32] { &self.robust_weights } } @@ -224,29 +232,29 @@ mod tests { .inner_loops(2) .outer_loops(0); let mut periods = vec![24, 24 * 7]; - let mstl = MSTL::new(y, &mut periods).stl_params(params); + let mstl = MSTL::new(y.iter().map(|&x| x as f32), &mut periods).stl_params(params); let res = mstl.fit().unwrap(); let expected = vic_elec_results(); res.trend() .iter() .zip(expected.trend().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-2_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1.0)); res.seasonal(24) .unwrap() .iter() .zip(expected.seasonal(24).unwrap().iter()) // Some numeric instability somewhere causes this to differ by // up to 1.0 somewhere :/ - .for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f64)); + .for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f32)); res.seasonal(168) .unwrap() .iter() .zip(expected.seasonal(168).unwrap().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f32)); res.residuals() .iter() .zip(expected.residuals().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f32)); } } From f94f539800a2d17154d20accc36d8dffe0e62f98 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Sat, 23 Sep 2023 08:48:54 +0100 Subject: [PATCH 2/3] Use stlrs MSTL implementation stlrs will have its own implementation of MSTL in v0.3.0 which means less for this crate to maintain. Benchmarks show it runs pretty much identically to the augurs implementation which is nice. The only downside is the f64 -> f32 conversion required before calling the stlrs APIs but that would be required anyway if we want to avoid any forks. --- crates/augurs-mstl/Cargo.toml | 2 +- crates/augurs-mstl/benches/vic_elec.rs | 6 +- crates/augurs-mstl/benches/vic_elec_iai.rs | 8 +- crates/augurs-mstl/src/lib.rs | 102 +++++++++++---------- 4 files changed, 62 insertions(+), 56 deletions(-) diff --git a/crates/augurs-mstl/Cargo.toml b/crates/augurs-mstl/Cargo.toml index 456ab380..0006d01b 100644 --- a/crates/augurs-mstl/Cargo.toml +++ b/crates/augurs-mstl/Cargo.toml @@ -12,7 +12,7 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the augurs-core.workspace = true distrs.workspace = true serde = { workspace = true, features = ["derive"], optional = true } -stlrs = { git = "https://github.com/ankane/stl-rust", version = "0.2.2" } +stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "mstl-debug-clone", version = "0.2.2" } # stlrs = "0.2.2" thiserror.workspace = true tracing.workspace = true diff --git a/crates/augurs-mstl/benches/vic_elec.rs b/crates/augurs-mstl/benches/vic_elec.rs index 42b8f70e..5689a8d6 100644 --- a/crates/augurs-mstl/benches/vic_elec.rs +++ b/crates/augurs-mstl/benches/vic_elec.rs @@ -16,12 +16,14 @@ fn vic_elec(c: &mut Criterion) { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); c.bench_function("vic_elec", |b| { b.iter_batched( - || (y.clone(), vec![24, 24 * 7], stl_params.clone()), + || (y.clone(), vec![24, 24 * 7], mstl_params.clone()), |(y, periods, stl_params)| { MSTLModel::new(periods, NaiveTrend::new()) - .stl_params(stl_params) + .mstl_params(stl_params) .fit(&y) }, BatchSize::SmallInput, diff --git a/crates/augurs-mstl/benches/vic_elec_iai.rs b/crates/augurs-mstl/benches/vic_elec_iai.rs index 843d465c..5dd3c1aa 100644 --- a/crates/augurs-mstl/benches/vic_elec_iai.rs +++ b/crates/augurs-mstl/benches/vic_elec_iai.rs @@ -3,9 +3,9 @@ use iai::{black_box, main}; use augurs_mstl::{MSTLModel, NaiveTrend}; use augurs_testing::data::VIC_ELEC; -fn vic_elec_fit(y: Vec, periods: Vec, params: stlrs::StlParams) { +fn vic_elec_fit(y: Vec, periods: Vec, params: stlrs::MstlParams) { MSTLModel::new(periods, NaiveTrend::new()) - .stl_params(params) + .mstl_params(params) .fit(&y) .ok(); } @@ -22,10 +22,12 @@ fn bench_vic_elec_fit() { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); vic_elec_fit( black_box(y.clone()), black_box(vec![24, 24 * 7]), - black_box(stl_params), + black_box(mstl_params), ); } diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index e0525fb2..c2347d16 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -8,17 +8,17 @@ use std::marker::PhantomData; +use stlrs::MstlResult; use tracing::instrument; use augurs_core::{Forecast, ForecastIntervals}; // mod approx; -pub mod mstl; +// pub mod mstl; // mod stationarity; mod trend; // mod utils; -use crate::mstl::{MSTLDecomposition, MSTL}; pub use crate::trend::{NaiveTrend, TrendModel}; /// A marker struct indicating that a model is fit. @@ -54,11 +54,11 @@ type Result = std::result::Result; pub struct MSTLModel { /// Periodicity of the seasonal components. periods: Vec, - stl_params: stlrs::StlParams, + mstl_params: stlrs::MstlParams, state: PhantomData, - decomposed: Option, + fit: Option, trend_model: T, } @@ -86,15 +86,18 @@ impl MSTLModel { Self { periods, state: PhantomData, - stl_params: stlrs::params(), - decomposed: None, + mstl_params: stlrs::MstlParams::new(), + fit: None, trend_model, } } - /// Set the parameters for the STL algorithm. - pub fn stl_params(mut self, params: stlrs::StlParams) -> Self { - self.stl_params = params; + /// Set the parameters for the MSTL algorithm. + /// + /// This can be used to control the parameters for the inner STL algorithm + /// by using [`MstlParams::stl_params`]. + pub fn mstl_params(mut self, params: stlrs::MstlParams) -> Self { + self.mstl_params = params; self } @@ -109,13 +112,11 @@ impl MSTLModel { /// are also propagated. #[instrument(skip_all)] pub fn fit(mut self, y: &[f64]) -> Result> { - // Run STL for each season length. - let decomposed = MSTL::new(y.iter().map(|&x| x as f32), &mut self.periods) - .stl_params(self.stl_params.clone()) - .fit()?; + let y = y.iter().copied().map(|y| y as f32).collect::>(); + let fit = self.mstl_params.fit(&y, &self.periods)?; // Determine the differencing term for the trend component. - let trend = decomposed.trend(); - let residual = decomposed.residuals(); + let trend = fit.trend(); + let residual = fit.remainder(); let deseasonalised = trend .iter() .zip(residual) @@ -130,9 +131,9 @@ impl MSTLModel { ); Ok(MSTLModel { periods: self.periods, - stl_params: self.stl_params, + mstl_params: self.mstl_params, state: PhantomData, - decomposed: Some(decomposed), + fit: Some(fit), trend_model: self.trend_model, }) } @@ -196,35 +197,32 @@ impl MSTLModel { } fn add_seasonal_in_sample(&self, trend: &mut Forecast) { - self.decomposed() - .seasonals() - .values() - .for_each(|component| { - let period_contributions = component.iter().zip(trend.point.iter_mut()); - match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += *c as f64), - Some(ForecastIntervals { - ref mut lower, - ref mut upper, - .. - }) => { - period_contributions - .zip(lower.iter_mut()) - .zip(upper.iter_mut()) - .for_each(|(((c, p), l), u)| { - *p += *c as f64; - *l += *c as f64; - *u += *c as f64; - }); - } + self.fit().seasonal().iter().for_each(|component| { + let period_contributions = component.iter().zip(trend.point.iter_mut()); + match &mut trend.intervals { + None => period_contributions.for_each(|(c, p)| *p += *c as f64), + Some(ForecastIntervals { + ref mut lower, + ref mut upper, + .. + }) => { + period_contributions + .zip(lower.iter_mut()) + .zip(upper.iter_mut()) + .for_each(|(((c, p), l), u)| { + *p += *c as f64; + *l += *c as f64; + *u += *c as f64; + }); } - }); + } + }); } fn add_seasonal_out_of_sample(&self, trend: &mut Forecast) { - self.decomposed() - .seasonals() + self.periods .iter() + .zip(self.fit().seasonal()) .for_each(|(period, component)| { // For each seasonal period we're going to create a cycle iterator // which will repeat the seasonal component every `period` steps. @@ -257,9 +255,9 @@ impl MSTLModel { }); } - /// Return the MSTL decomposition of the training data. - pub fn decomposed(&self) -> &MSTLDecomposition { - self.decomposed.as_ref().unwrap() + /// Return the MSTL fit of the training data. + pub fn fit(&self) -> &MstlResult { + self.fit.as_ref().unwrap() } } @@ -286,8 +284,8 @@ mod tests { fn results_match_r() { let y = VIC_ELEC.clone(); - let mut params = stlrs::params(); - params + let mut stl_params = stlrs::params(); + stl_params .seasonal_degree(0) .seasonal_jump(1) .trend_degree(1) @@ -295,9 +293,11 @@ mod tests { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); let periods = vec![24, 24 * 7]; let trend_model = NaiveTrend::new(); - let mstl = MSTLModel::new(periods, trend_model).stl_params(params); + let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); let fit = mstl.fit(&y).unwrap(); let in_sample = fit.predict_in_sample(0.95).unwrap(); @@ -344,8 +344,8 @@ mod tests { fn predict_zero_horizon() { let y = VIC_ELEC.clone(); - let mut params = stlrs::params(); - params + let mut stl_params = stlrs::params(); + stl_params .seasonal_degree(0) .seasonal_jump(1) .trend_degree(1) @@ -353,9 +353,11 @@ mod tests { .low_pass_degree(1) .inner_loops(2) .outer_loops(0); + let mut mstl_params = stlrs::MstlParams::new(); + mstl_params.stl_params(stl_params); let periods = vec![24, 24 * 7]; let trend_model = NaiveTrend::new(); - let mstl = MSTLModel::new(periods, trend_model).stl_params(params); + let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params); let fit = mstl.fit(&y).unwrap(); let forecast = fit.predict(0, 0.95).unwrap(); assert!(forecast.point.is_empty()); From f0b35110d5130ed5fe4f72848a213c654d97678b Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 25 Sep 2023 08:58:42 +0100 Subject: [PATCH 3/3] Use released version of stlrs --- crates/augurs-mstl/Cargo.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/augurs-mstl/Cargo.toml b/crates/augurs-mstl/Cargo.toml index 0006d01b..25005fba 100644 --- a/crates/augurs-mstl/Cargo.toml +++ b/crates/augurs-mstl/Cargo.toml @@ -12,8 +12,7 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the augurs-core.workspace = true distrs.workspace = true serde = { workspace = true, features = ["derive"], optional = true } -stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "mstl-debug-clone", version = "0.2.2" } -# stlrs = "0.2.2" +stlrs = "0.3.0" thiserror.workspace = true tracing.workspace = true