Skip to content

Commit

Permalink
Accommodate restrictions of released stl-rs
Browse files Browse the repository at this point in the history
This mainly just means we have to clone a little more and cast things
between f32 and f64, because the released version of stl-rs doesn't
allow us to take ownership of the various components and doesn't include
f64 compatibility.

Note that this still currently points towards the main git branch for
stl-rust, but at least all the incorporated changes are more likely to
be merged sooner or later!

Unfortunately this results in some benchmark regressions, probably due
to the extra clones and having to convert to/from f32 :(

```
     Running benches/vic_elec.rs (target/release/deps/vic_elec-36a0fad3091799a8)
vic_elec                time:   [28.569 ms 28.591 ms 28.619 ms]
                        change: [+11.550% +11.750% +11.918%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 8 outliers among 100 measurements (8.00%)
  3 (3.00%) high mild
  5 (5.00%) high severe

```
  • Loading branch information
sd2k committed Sep 18, 2023
1 parent ff418a4 commit a0c4bad
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 51 deletions.
3 changes: 2 additions & 1 deletion crates/augurs-mstl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the
augurs-core.workspace = true
distrs.workspace = true
serde = { workspace = true, features = ["derive"], optional = true }
stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "python-lib", version = "0.2.1" }
stlrs = { git = "https://github.com/ankane/stl-rust", version = "0.2.2" }
# stlrs = "0.2.2"
thiserror.workspace = true
tracing.workspace = true

Expand Down
22 changes: 11 additions & 11 deletions crates/augurs-mstl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<T: TrendModel> MSTLModel<T, Unfit> {
#[instrument(skip_all)]
pub fn fit(mut self, y: &[f64]) -> Result<MSTLModel<T, Fit>> {
// Run STL for each season length.
let decomposed = MSTL::new(y, &mut self.periods)
let decomposed = MSTL::new(y.iter().map(|&x| x as f32), &mut self.periods)
.stl_params(self.stl_params.clone())
.fit()?;
// Determine the differencing term for the trend component.
Expand All @@ -119,7 +119,7 @@ impl<T: TrendModel> MSTLModel<T, Unfit> {
let deseasonalised = trend
.iter()
.zip(residual)
.map(|(t, r)| t + r)
.map(|(t, r)| (t + r) as f64)
.collect::<Vec<_>>();
self.trend_model
.fit(&deseasonalised)
Expand Down Expand Up @@ -202,7 +202,7 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
.for_each(|component| {
let period_contributions = component.iter().zip(trend.point.iter_mut());
match &mut trend.intervals {
None => period_contributions.for_each(|(c, p)| *p += c),
None => period_contributions.for_each(|(c, p)| *p += *c as f64),
Some(ForecastIntervals {
ref mut lower,
ref mut upper,
Expand All @@ -212,9 +212,9 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
.zip(lower.iter_mut())
.zip(upper.iter_mut())
.for_each(|(((c, p), l), u)| {
*p += c;
*l += c;
*u += c;
*p += *c as f64;
*l += *c as f64;
*u += *c as f64;
});
}
}
Expand All @@ -238,7 +238,7 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
.cycle()
.zip(trend.point.iter_mut());
match &mut trend.intervals {
None => period_contributions.for_each(|(c, p)| *p += c),
None => period_contributions.for_each(|(c, p)| *p += c as f64),
Some(ForecastIntervals {
ref mut lower,
ref mut upper,
Expand All @@ -248,9 +248,9 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
.zip(lower.iter_mut())
.zip(upper.iter_mut())
.for_each(|(((c, p), l), u)| {
*p += c;
*l += c;
*u += c;
*p += c as f64;
*l += c as f64;
*u += c as f64;
});
}
}
Expand All @@ -277,7 +277,7 @@ mod tests {
if actual.is_nan() {
assert!(expected.is_nan());
} else {
assert_approx_eq!(actual, expected, 1e-2);
assert_approx_eq!(actual, expected, 1e-1);
}
}
}
Expand Down
86 changes: 47 additions & 39 deletions crates/augurs-mstl/src/mstl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{Error, Result};
#[allow(clippy::upper_case_acronyms)]
pub struct MSTL<'a> {
/// Time series to decompose.
y: &'a [f64],
y: Vec<f32>,
/// Periodicity of the seasonal components.
periods: &'a mut Vec<usize>,
/// Parameters for the STL decomposition.
Expand All @@ -36,9 +36,9 @@ impl<'a> MSTL<'a> {
/// Create a new MSTL decomposition.
///
/// Call `fit` to run the decomposition.
pub fn new(y: &'a [f64], periods: &'a mut Vec<usize>) -> Self {
pub fn new(y: impl Iterator<Item = f32>, periods: &'a mut Vec<usize>) -> Self {
Self {
y,
y: y.collect::<Vec<_>>(),
periods,
stl_params: stlrs::params(),
}
Expand All @@ -57,51 +57,59 @@ impl<'a> MSTL<'a> {
let seasonal_windows: Vec<usize> = self.seasonal_windows();
let iterate = if self.periods.len() == 1 { 1 } else { 2 };

let mut seasonals: HashMap<usize, Vec<f64>> = self
.periods
.iter()
.copied()
.map(|p| (p, vec![0.0; self.y.len()]))
.collect();
let mut deseas = self.y.to_vec();
let mut res: Option<StlResult<f64>> = None;
let mut seasonals: HashMap<usize, StlResult> = HashMap::with_capacity(self.periods.len());
// self.periods.iter().copied().map(|p| (p, None)).collect();
let mut deseas = self.y;
let mut res: Option<StlResult> = None;
for i in 0..iterate {
let zipped = self.periods.iter().zip(seasonal_windows.iter());
for (period, seasonal_window) in zipped {
let seas = seasonals.get_mut(period).unwrap();
let seas = seasonals.entry(*period);
// Start by adding on the seasonal effect.
deseas
.iter_mut()
.zip(seas.iter())
.for_each(|(d, s)| *d += *s);
if let std::collections::hash_map::Entry::Occupied(ref seas) = seas {
deseas
.iter_mut()
.zip(seas.get().seasonal().iter())
.for_each(|(d, s)| *d += *s);
}
// Decompose the time series for specific seasonal period.
let mut fit = tracing::debug_span!("STL.fit", i, seasonal_window, period)
.in_scope(|| {
let fit =
tracing::debug_span!("STL.fit", i, seasonal_window, period).in_scope(|| {
self.stl_params
.seasonal_length(*seasonal_window)
.fit(&deseas, *period)
})?;
*seas = std::mem::take(&mut fit.seasonal);
res = Some(fit);
// Subtract the seasonal effect again.
deseas
.iter_mut()
.zip(seas.iter())
.zip(fit.seasonal().iter())
.for_each(|(d, s)| *d -= *s);
match seas {
std::collections::hash_map::Entry::Occupied(mut o) => {
o.insert(fit.clone());
}
std::collections::hash_map::Entry::Vacant(x) => {
x.insert(fit.clone());
}
}
res = Some(fit);
}
}
let fit = res.ok_or_else(|| Error::MSTL("no STL fit".to_string()))?;
let trend = fit.trend;
let trend = fit.trend().to_vec();
deseas
.iter_mut()
.zip(trend.iter())
.for_each(|(d, r)| *d -= *r);
let rw = fit.weights;
let robust_weights = fit.weights().to_vec();
Ok(MSTLDecomposition {
trend,
seasonal: seasonals,
seasonal: seasonals
.into_iter()
.map(|(k, v)| (k, v.seasonal().to_vec()))
.collect(),
residuals: deseas,
robust_weights: rw,
robust_weights,
})
}

Expand Down Expand Up @@ -142,39 +150,39 @@ impl<'a> MSTL<'a> {
#[cfg_attr(test, derive(Default))]
pub struct MSTLDecomposition {
/// Trend component.
trend: Vec<f64>,
trend: Vec<f32>,
/// Mapping from period to seasonal component.
seasonal: HashMap<usize, Vec<f64>>,
seasonal: HashMap<usize, Vec<f32>>,
/// Residuals.
residuals: Vec<f64>,
residuals: Vec<f32>,
/// Weights used in the robust fit.
robust_weights: Vec<f64>,
robust_weights: Vec<f32>,
}

impl MSTLDecomposition {
/// Return the trend component.
pub fn trend(&self) -> &[f64] {
pub fn trend(&self) -> &[f32] {
&self.trend
}

/// Return the seasonal component for a given period,
/// or None if the period is not present.
pub fn seasonal(&self, period: usize) -> Option<&[f64]> {
pub fn seasonal(&self, period: usize) -> Option<&[f32]> {
self.seasonal.get(&period).map(|v| v.as_slice())
}

/// Return a mapping from period to seasonal component.
pub fn seasonals(&self) -> &HashMap<usize, Vec<f64>> {
pub fn seasonals(&self) -> &HashMap<usize, Vec<f32>> {
&self.seasonal
}

/// Return the residuals.
pub fn residuals(&self) -> &[f64] {
pub fn residuals(&self) -> &[f32] {
&self.residuals
}

/// Return the robust weights.
pub fn robust_weights(&self) -> &[f64] {
pub fn robust_weights(&self) -> &[f32] {
&self.robust_weights
}
}
Expand Down Expand Up @@ -224,29 +232,29 @@ mod tests {
.inner_loops(2)
.outer_loops(0);
let mut periods = vec![24, 24 * 7];
let mstl = MSTL::new(y, &mut periods).stl_params(params);
let mstl = MSTL::new(y.iter().map(|&x| x as f32), &mut periods).stl_params(params);
let res = mstl.fit().unwrap();

let expected = vic_elec_results();
res.trend()
.iter()
.zip(expected.trend().iter())
.for_each(|(a, b)| assert_approx_eq!(a, b, 1e-2_f64));
.for_each(|(a, b)| assert_approx_eq!(a, b, 1.0));
res.seasonal(24)
.unwrap()
.iter()
.zip(expected.seasonal(24).unwrap().iter())
// Some numeric instability somewhere causes this to differ by
// up to 1.0 somewhere :/
.for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f64));
.for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f32));
res.seasonal(168)
.unwrap()
.iter()
.zip(expected.seasonal(168).unwrap().iter())
.for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f64));
.for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f32));
res.residuals()
.iter()
.zip(expected.residuals().iter())
.for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f64));
.for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f32));
}
}

0 comments on commit a0c4bad

Please sign in to comment.