Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use upstream stlrs crate #37

Merged
merged 3 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/augurs-mstl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the
augurs-core.workspace = true
distrs.workspace = true
serde = { workspace = true, features = ["derive"], optional = true }
stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "python-lib", version = "0.2.1" }
stlrs = "0.3.0"
thiserror.workspace = true
tracing.workspace = true

Expand Down
6 changes: 4 additions & 2 deletions crates/augurs-mstl/benches/vic_elec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ fn vic_elec(c: &mut Criterion) {
.low_pass_degree(1)
.inner_loops(2)
.outer_loops(0);
let mut mstl_params = stlrs::MstlParams::new();
mstl_params.stl_params(stl_params);
c.bench_function("vic_elec", |b| {
b.iter_batched(
|| (y.clone(), vec![24, 24 * 7], stl_params.clone()),
|| (y.clone(), vec![24, 24 * 7], mstl_params.clone()),
|(y, periods, stl_params)| {
MSTLModel::new(periods, NaiveTrend::new())
.stl_params(stl_params)
.mstl_params(stl_params)
.fit(&y)
},
BatchSize::SmallInput,
Expand Down
8 changes: 5 additions & 3 deletions crates/augurs-mstl/benches/vic_elec_iai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use iai::{black_box, main};
use augurs_mstl::{MSTLModel, NaiveTrend};
use augurs_testing::data::VIC_ELEC;

fn vic_elec_fit(y: Vec<f64>, periods: Vec<usize>, params: stlrs::StlParams) {
fn vic_elec_fit(y: Vec<f64>, periods: Vec<usize>, params: stlrs::MstlParams) {
MSTLModel::new(periods, NaiveTrend::new())
.stl_params(params)
.mstl_params(params)
.fit(&y)
.ok();
}
Expand All @@ -22,10 +22,12 @@ fn bench_vic_elec_fit() {
.low_pass_degree(1)
.inner_loops(2)
.outer_loops(0);
let mut mstl_params = stlrs::MstlParams::new();
mstl_params.stl_params(stl_params);
vic_elec_fit(
black_box(y.clone()),
black_box(vec![24, 24 * 7]),
black_box(stl_params),
black_box(mstl_params),
);
}

Expand Down
114 changes: 58 additions & 56 deletions crates/augurs-mstl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@

use std::marker::PhantomData;

use stlrs::MstlResult;
use tracing::instrument;

use augurs_core::{Forecast, ForecastIntervals};

// mod approx;
pub mod mstl;
// pub mod mstl;
// mod stationarity;
mod trend;
// mod utils;

use crate::mstl::{MSTLDecomposition, MSTL};
pub use crate::trend::{NaiveTrend, TrendModel};

/// A marker struct indicating that a model is fit.
Expand Down Expand Up @@ -54,11 +54,11 @@ type Result<T> = std::result::Result<T, Error>;
pub struct MSTLModel<T, F> {
/// Periodicity of the seasonal components.
periods: Vec<usize>,
stl_params: stlrs::StlParams,
mstl_params: stlrs::MstlParams,

state: PhantomData<F>,

decomposed: Option<MSTLDecomposition>,
fit: Option<MstlResult>,
trend_model: T,
}

Expand Down Expand Up @@ -86,15 +86,18 @@ impl<T: TrendModel> MSTLModel<T, Unfit> {
Self {
periods,
state: PhantomData,
stl_params: stlrs::params(),
decomposed: None,
mstl_params: stlrs::MstlParams::new(),
fit: None,
trend_model,
}
}

/// Set the parameters for the STL algorithm.
pub fn stl_params(mut self, params: stlrs::StlParams) -> Self {
self.stl_params = params;
/// Set the parameters for the MSTL algorithm.
///
/// This can be used to control the parameters for the inner STL algorithm
/// by using [`MstlParams::stl_params`].
pub fn mstl_params(mut self, params: stlrs::MstlParams) -> Self {
self.mstl_params = params;
self
}

Expand All @@ -109,17 +112,15 @@ impl<T: TrendModel> MSTLModel<T, Unfit> {
/// are also propagated.
#[instrument(skip_all)]
pub fn fit(mut self, y: &[f64]) -> Result<MSTLModel<T, Fit>> {
// Run STL for each season length.
let decomposed = MSTL::new(y, &mut self.periods)
.stl_params(self.stl_params.clone())
.fit()?;
let y = y.iter().copied().map(|y| y as f32).collect::<Vec<_>>();
let fit = self.mstl_params.fit(&y, &self.periods)?;
// Determine the differencing term for the trend component.
let trend = decomposed.trend();
let residual = decomposed.residuals();
let trend = fit.trend();
let residual = fit.remainder();
let deseasonalised = trend
.iter()
.zip(residual)
.map(|(t, r)| t + r)
.map(|(t, r)| (t + r) as f64)
.collect::<Vec<_>>();
self.trend_model
.fit(&deseasonalised)
Expand All @@ -130,9 +131,9 @@ impl<T: TrendModel> MSTLModel<T, Unfit> {
);
Ok(MSTLModel {
periods: self.periods,
stl_params: self.stl_params,
mstl_params: self.mstl_params,
state: PhantomData,
decomposed: Some(decomposed),
fit: Some(fit),
trend_model: self.trend_model,
})
}
Expand Down Expand Up @@ -196,35 +197,32 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
}

fn add_seasonal_in_sample(&self, trend: &mut Forecast) {
self.decomposed()
.seasonals()
.values()
.for_each(|component| {
let period_contributions = component.iter().zip(trend.point.iter_mut());
match &mut trend.intervals {
None => period_contributions.for_each(|(c, p)| *p += c),
Some(ForecastIntervals {
ref mut lower,
ref mut upper,
..
}) => {
period_contributions
.zip(lower.iter_mut())
.zip(upper.iter_mut())
.for_each(|(((c, p), l), u)| {
*p += c;
*l += c;
*u += c;
});
}
self.fit().seasonal().iter().for_each(|component| {
let period_contributions = component.iter().zip(trend.point.iter_mut());
match &mut trend.intervals {
None => period_contributions.for_each(|(c, p)| *p += *c as f64),
Some(ForecastIntervals {
ref mut lower,
ref mut upper,
..
}) => {
period_contributions
.zip(lower.iter_mut())
.zip(upper.iter_mut())
.for_each(|(((c, p), l), u)| {
*p += *c as f64;
*l += *c as f64;
*u += *c as f64;
});
}
});
}
});
}

fn add_seasonal_out_of_sample(&self, trend: &mut Forecast) {
self.decomposed()
.seasonals()
self.periods
.iter()
.zip(self.fit().seasonal())
.for_each(|(period, component)| {
// For each seasonal period we're going to create a cycle iterator
// which will repeat the seasonal component every `period` steps.
Expand All @@ -238,7 +236,7 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
.cycle()
.zip(trend.point.iter_mut());
match &mut trend.intervals {
None => period_contributions.for_each(|(c, p)| *p += c),
None => period_contributions.for_each(|(c, p)| *p += c as f64),
Some(ForecastIntervals {
ref mut lower,
ref mut upper,
Expand All @@ -248,18 +246,18 @@ impl<T: TrendModel> MSTLModel<T, Fit> {
.zip(lower.iter_mut())
.zip(upper.iter_mut())
.for_each(|(((c, p), l), u)| {
*p += c;
*l += c;
*u += c;
*p += c as f64;
*l += c as f64;
*u += c as f64;
});
}
}
});
}

/// Return the MSTL decomposition of the training data.
pub fn decomposed(&self) -> &MSTLDecomposition {
self.decomposed.as_ref().unwrap()
/// Return the MSTL fit of the training data.
pub fn fit(&self) -> &MstlResult {
self.fit.as_ref().unwrap()
}
}

Expand All @@ -277,7 +275,7 @@ mod tests {
if actual.is_nan() {
assert!(expected.is_nan());
} else {
assert_approx_eq!(actual, expected, 1e-2);
assert_approx_eq!(actual, expected, 1e-1);
}
}
}
Expand All @@ -286,18 +284,20 @@ mod tests {
fn results_match_r() {
let y = VIC_ELEC.clone();

let mut params = stlrs::params();
params
let mut stl_params = stlrs::params();
stl_params
.seasonal_degree(0)
.seasonal_jump(1)
.trend_degree(1)
.trend_jump(1)
.low_pass_degree(1)
.inner_loops(2)
.outer_loops(0);
let mut mstl_params = stlrs::MstlParams::new();
mstl_params.stl_params(stl_params);
let periods = vec![24, 24 * 7];
let trend_model = NaiveTrend::new();
let mstl = MSTLModel::new(periods, trend_model).stl_params(params);
let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params);
let fit = mstl.fit(&y).unwrap();

let in_sample = fit.predict_in_sample(0.95).unwrap();
Expand Down Expand Up @@ -344,18 +344,20 @@ mod tests {
fn predict_zero_horizon() {
let y = VIC_ELEC.clone();

let mut params = stlrs::params();
params
let mut stl_params = stlrs::params();
stl_params
.seasonal_degree(0)
.seasonal_jump(1)
.trend_degree(1)
.trend_jump(1)
.low_pass_degree(1)
.inner_loops(2)
.outer_loops(0);
let mut mstl_params = stlrs::MstlParams::new();
mstl_params.stl_params(stl_params);
let periods = vec![24, 24 * 7];
let trend_model = NaiveTrend::new();
let mstl = MSTLModel::new(periods, trend_model).stl_params(params);
let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params);
let fit = mstl.fit(&y).unwrap();
let forecast = fit.predict(0, 0.95).unwrap();
assert!(forecast.point.is_empty());
Expand Down
Loading
Loading