From 6224a6a402b2c2c38a3d4aa27bc8f3d75163c350 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Fri, 22 Nov 2024 16:12:25 +0000 Subject: [PATCH] feat!(prophet): support sub-daily & non-UTC holidays (#181) --- crates/augurs-prophet/src/features.rs | 278 ++++++++++++++++++---- crates/augurs-prophet/src/lib.rs | 2 +- crates/augurs-prophet/src/prophet.rs | 44 ++-- crates/augurs-prophet/src/prophet/prep.rs | 263 +++++++++++++------- js/augurs-prophet-js/src/lib.rs | 61 ++--- js/testpkg/package-lock.json | 2 +- js/testpkg/prophet.test.ts | 14 +- justfile | 2 +- 8 files changed, 478 insertions(+), 188 deletions(-) diff --git a/crates/augurs-prophet/src/features.rs b/crates/augurs-prophet/src/features.rs index 5e5fb23e..96ec5325 100644 --- a/crates/augurs-prophet/src/features.rs +++ b/crates/augurs-prophet/src/features.rs @@ -1,7 +1,9 @@ //! Features used by Prophet, such as seasonality, regressors and holidays. use std::num::NonZeroU32; -use crate::{positive_float::PositiveFloat, Error, TimestampSeconds}; +use crate::{ + positive_float::PositiveFloat, prophet::prep::ONE_DAY_IN_SECONDS_INT, TimestampSeconds, +}; /// The mode of a seasonality, regressor, or holiday. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] @@ -13,71 +15,127 @@ pub enum FeatureMode { Multiplicative, } -/// A holiday to be considered by the Prophet model. +/// An occurrence of a holiday. +/// +/// Each occurrence has a start and end time represented as +/// a Unix timestamp. Holiday occurrences are therefore +/// timestamp-unaware and can therefore span multiple days +/// or even sub-daily periods. +/// +/// This differs from the Python and R Prophet implementations, +/// which require all holidays to be day-long events. Some +/// convenience methods are provided to create day-long +/// occurrences: see [`HolidayOccurrence::for_day`] and +/// [`HolidayOccurrence::for_day_in_tz`]. +/// +/// The caller is responsible for ensuring that the start +/// and end time provided are in the correct timezone. +/// One way to do this is to use [`chrono::FixedOffset`][fo] +/// to create an offset representing the time zone, +/// [`FixedOffset::with_ymd_and_hms`][wyah] to create a +/// [`DateTime`][dt] in that time zone, then [`DateTime::timestamp`][ts] +/// to get the Unix timestamp. +/// +/// [fo]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html +/// [wyah]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html#method.with_ymd_and_hms +/// [dt]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html +/// [ts]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html#method.timestamp #[derive(Debug, Clone)] -pub struct Holiday { - pub(crate) ds: Vec, - pub(crate) lower_window: Option>, - pub(crate) upper_window: Option>, - pub(crate) prior_scale: Option, +pub struct HolidayOccurrence { + pub(crate) start: TimestampSeconds, + pub(crate) end: TimestampSeconds, } -impl Holiday { - /// Create a new holiday. - pub fn new(ds: Vec) -> Self { - Self { - ds, - lower_window: None, - upper_window: None, - prior_scale: None, - } +impl HolidayOccurrence { + /// Create a new holiday occurrence with the given + /// start and end timestamp. + pub fn new(start: TimestampSeconds, end: TimestampSeconds) -> Self { + Self { start, end } } - /// Set the lower window for the holiday. + /// Create a new holiday encompassing midnight on the day + /// of the given timestamp to midnight on the following day, + /// in UTC. /// - /// The lower window is the number of days before the holiday - /// that it is observed. For example, if the holiday is on - /// 2023-01-01 and the lower window is 1, then the holiday will - /// _also_ be observed on 2022-12-31. - pub fn with_lower_window(mut self, lower_window: Vec) -> Result { - if self.ds.len() != lower_window.len() { - return Err(Error::MismatchedLengths { - a_name: "ds".to_string(), - a: self.ds.len(), - b_name: "lower_window".to_string(), - b: lower_window.len(), - }); - } - self.lower_window = Some(lower_window); - Ok(self) + /// This is a convenience method to reproduce the behaviour + /// of the Python and R Prophet implementations, which require + /// all holidays to be day-long events. + /// + /// Note that this will _not_ handle daylight saving time + /// transitions correctly. To handle this correctly, use + /// [`HolidayOccurrence::new`] with the correct start and + /// end times, e.g. by calculating them using [`chrono`]. + /// + /// [`chrono`]: https://docs.rs/chrono/latest/chrono + pub fn for_day(day: TimestampSeconds) -> Self { + Self::for_day_in_tz(day, 0) } - /// Set the upper window for the holiday. + /// Create a new holiday encompassing midnight on the day + /// of the given timestamp to midnight on the following day, + /// in a timezone represented by the `utc_offset_seconds`. + /// + /// The UTC offset can be calculated using, for example, + /// [`chrono::FixedOffset::local_minus_utc`]. Alternatively + /// it's the number of seconds to add to convert from the + /// local time to UTC, so UTC+1 is represented by `3600` + /// and UTC-5 by `-18000`. + /// + /// This is a convenience method to reproduce the behaviour + /// of the Python and R Prophet implementations, which require + /// all holidays to be day-long events. /// - /// The upper window is the number of days after the holiday - /// that it is observed. For example, if the holiday is on - /// 2023-01-01 and the upper window is 1, then the holiday will - /// _also_ be observed on 2023-01-02. - pub fn with_upper_window(mut self, upper_window: Vec) -> Result { - if self.ds.len() != upper_window.len() { - return Err(Error::MismatchedLengths { - a_name: "ds".to_string(), - a: self.ds.len(), - b_name: "upper_window".to_string(), - b: upper_window.len(), - }); + /// Note that this will _not_ handle daylight saving time + /// transitions correctly. To handle this correctly, use + /// [`HolidayOccurrence::new`] with the correct start and + /// end times, e.g. by calculating them using [`chrono`]. + /// + /// [`chrono`]: https://docs.rs/chrono/latest/chrono + pub fn for_day_in_tz(day: TimestampSeconds, utc_offset_seconds: i32) -> Self { + let day = floor_day(day, utc_offset_seconds); + Self { + start: day, + end: day + ONE_DAY_IN_SECONDS_INT, + } + } + + /// Check if the given timestamp is within this occurrence. + pub(crate) fn contains(&self, ds: TimestampSeconds) -> bool { + self.start <= ds && ds < self.end + } +} + +/// A holiday to be considered by the Prophet model. +#[derive(Debug, Clone)] +pub struct Holiday { + pub(crate) occurrences: Vec, + pub(crate) prior_scale: Option, +} + +impl Holiday { + /// Create a new holiday with the given occurrences. + pub fn new(occurrences: Vec) -> Self { + Self { + occurrences, + prior_scale: None, } - self.upper_window = Some(upper_window); - Ok(self) } - /// Add a prior scale for the holiday. + /// Set the prior scale for the holiday. pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self { self.prior_scale = Some(prior_scale); self } } +fn floor_day(ds: TimestampSeconds, offset: i32) -> TimestampSeconds { + let adjusted_ds = ds + offset as TimestampSeconds; + let remainder = + ((adjusted_ds % ONE_DAY_IN_SECONDS_INT) + ONE_DAY_IN_SECONDS_INT) % ONE_DAY_IN_SECONDS_INT; + // Adjust the date to the holiday's UTC offset. + ds - remainder +} + /// Whether or not to standardize a regressor. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] pub enum Standardize { @@ -232,3 +290,127 @@ impl Seasonality { self } } + +#[cfg(test)] +mod test { + use chrono::{FixedOffset, TimeZone, Utc}; + + use crate::features::floor_day; + + #[test] + fn floor_day_no_offset() { + let offset = Utc; + let expected = offset + .with_ymd_and_hms(2024, 11, 21, 0, 0, 0) + .unwrap() + .timestamp(); + assert_eq!(floor_day(expected, 0), expected); + assert_eq!( + floor_day( + offset + .with_ymd_and_hms(2024, 11, 21, 15, 3, 12) + .unwrap() + .timestamp(), + 0 + ), + expected + ); + } + + #[test] + fn floor_day_positive_offset() { + let offset = FixedOffset::east_opt(60 * 60 * 4).unwrap(); + let expected = offset + .with_ymd_and_hms(2024, 11, 21, 0, 0, 0) + .unwrap() + .timestamp(); + + assert_eq!(floor_day(expected, offset.local_minus_utc()), expected); + assert_eq!( + floor_day( + offset + .with_ymd_and_hms(2024, 11, 21, 15, 3, 12) + .unwrap() + .timestamp(), + offset.local_minus_utc() + ), + expected + ); + } + + #[test] + fn floor_day_negative_offset() { + let offset = FixedOffset::west_opt(60 * 60 * 3).unwrap(); + let expected = offset + .with_ymd_and_hms(2024, 11, 21, 0, 0, 0) + .unwrap() + .timestamp(); + + assert_eq!(floor_day(expected, offset.local_minus_utc()), expected); + assert_eq!( + floor_day( + offset + .with_ymd_and_hms(2024, 11, 21, 15, 3, 12) + .unwrap() + .timestamp(), + offset.local_minus_utc() + ), + expected + ); + } + + #[test] + fn floor_day_edge_cases() { + // Test maximum valid offset (UTC+14) + let max_offset = 14 * 60 * 60; + let offset = FixedOffset::east_opt(max_offset).unwrap(); + let expected = offset + .with_ymd_and_hms(2024, 11, 21, 0, 0, 0) + .unwrap() + .timestamp(); + assert_eq!( + floor_day( + offset + .with_ymd_and_hms(2024, 11, 21, 12, 0, 0) + .unwrap() + .timestamp(), + offset.local_minus_utc() + ), + expected + ); + + // Test near day boundary + let offset = FixedOffset::east_opt(60).unwrap(); + let expected = offset + .with_ymd_and_hms(2024, 11, 21, 0, 0, 0) + .unwrap() + .timestamp(); + assert_eq!( + floor_day( + offset + .with_ymd_and_hms(2024, 11, 21, 23, 59, 59) + .unwrap() + .timestamp(), + offset.local_minus_utc() + ), + expected + ); + + // Test when the day is before the epoch. + let offset = FixedOffset::west_opt(3600).unwrap(); + let expected = offset + .with_ymd_and_hms(1969, 1, 1, 0, 0, 0) + .unwrap() + .timestamp(); + assert_eq!( + floor_day( + offset + .with_ymd_and_hms(1969, 1, 1, 0, 30, 0) + .unwrap() + .timestamp(), + offset.local_minus_utc() + ), + expected + ); + } +} diff --git a/crates/augurs-prophet/src/lib.rs b/crates/augurs-prophet/src/lib.rs index edf2e71b..457bb2ac 100644 --- a/crates/augurs-prophet/src/lib.rs +++ b/crates/augurs-prophet/src/lib.rs @@ -23,7 +23,7 @@ pub type TimestampSeconds = i64; // navigate the module hierarchy. pub use data::{PredictionData, TrainingData}; pub use error::Error; -pub use features::{FeatureMode, Holiday, Regressor, Seasonality, Standardize}; +pub use features::{FeatureMode, Holiday, HolidayOccurrence, Regressor, Seasonality, Standardize}; pub use optimizer::{Algorithm, Optimizer, TrendIndicator}; pub use positive_float::{PositiveFloat, TryFromFloatError}; pub use prophet::{ diff --git a/crates/augurs-prophet/src/prophet.rs b/crates/augurs-prophet/src/prophet.rs index d651678c..9cb5ff64 100644 --- a/crates/augurs-prophet/src/prophet.rs +++ b/crates/augurs-prophet/src/prophet.rs @@ -524,7 +524,7 @@ mod test_custom_seasonal { optimizer::mock_optimizer::MockOptimizer, prophet::prep::{FeatureName, Features}, testdata::daily_univariate_ts, - FeatureMode, Holiday, ProphetOptions, Seasonality, SeasonalityOption, + FeatureMode, Holiday, HolidayOccurrence, ProphetOptions, Seasonality, SeasonalityOption, }; use super::Prophet; @@ -534,12 +534,14 @@ mod test_custom_seasonal { let holiday_dates = ["2017-01-02"] .iter() .map(|s| { - s.parse::() - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .and_utc() - .timestamp() + HolidayOccurrence::for_day( + s.parse::() + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .timestamp(), + ) }) .collect(); @@ -703,8 +705,8 @@ mod test_holidays { use chrono::NaiveDate; use crate::{ - optimizer::mock_optimizer::MockOptimizer, testdata::daily_univariate_ts, Holiday, Prophet, - ProphetOptions, + optimizer::mock_optimizer::MockOptimizer, testdata::daily_univariate_ts, Holiday, + HolidayOccurrence, Prophet, ProphetOptions, }; #[test] @@ -712,24 +714,18 @@ mod test_holidays { let holiday_dates = ["2012-10-09", "2013-10-09"] .iter() .map(|s| { - s.parse::() - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .and_utc() - .timestamp() + HolidayOccurrence::for_day( + s.parse::() + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .timestamp(), + ) }) .collect(); let opts = ProphetOptions { - holidays: [( - "bens-bday".to_string(), - Holiday::new(holiday_dates) - .with_lower_window(vec![0, 0]) - .unwrap() - .with_upper_window(vec![1, 1]) - .unwrap(), - )] - .into(), + holidays: [("bens-bday".to_string(), Holiday::new(holiday_dates))].into(), ..Default::default() }; let data = daily_univariate_ts(); diff --git a/crates/augurs-prophet/src/prophet/prep.rs b/crates/augurs-prophet/src/prophet/prep.rs index 0ba78ca8..8c1565fa 100644 --- a/crates/augurs-prophet/src/prophet/prep.rs +++ b/crates/augurs-prophet/src/prophet/prep.rs @@ -3,7 +3,7 @@ use std::{ num::NonZeroU32, }; -use itertools::{izip, Either, Itertools, MinMaxResult}; +use itertools::{Either, Itertools, MinMaxResult}; use crate::{ features::RegressorScale, @@ -16,7 +16,7 @@ use crate::{ const ONE_YEAR_IN_SECONDS: f64 = 365.25 * 24.0 * 60.0 * 60.0; const ONE_WEEK_IN_SECONDS: f64 = 7.0 * 24.0 * 60.0 * 60.0; const ONE_DAY_IN_SECONDS: f64 = 24.0 * 60.0 * 60.0; -const ONE_DAY_IN_SECONDS_INT: i64 = 24 * 60 * 60; +pub(crate) const ONE_DAY_IN_SECONDS_INT: i64 = 24 * 60 * 60; #[derive(Debug, Clone, Default)] pub(super) struct Scales { @@ -138,9 +138,6 @@ pub(super) enum FeatureName { Holiday { /// The name of the holiday. name: String, - /// The offset from the holiday date, as permitted - /// by the lower or upper window. - _offset: i32, }, Dummy, } @@ -666,49 +663,17 @@ impl Prophet { // days except that day, and 1.0 for that day. let mut this_holiday_features: HashMap> = HashMap::new(); - // Default to a window of 0 days either side. - let lower = holiday - .lower_window - .as_ref() - .map(|x| { - Box::new(x.iter().copied().map(|x| x as i32)) as Box> - }) - .unwrap_or_else(|| Box::new(std::iter::repeat(0))); - let upper = holiday - .upper_window - .as_ref() - .map(|x| { - Box::new(x.iter().copied().map(|x| x as i32)) as Box> - }) - .unwrap_or_else(|| Box::new(std::iter::repeat(0))); - - for (dt, lower, upper) in izip!(holiday.ds, lower, upper) { - // Round down the original timestamps to the nearest day. - let remainder = dt % ONE_DAY_IN_SECONDS_INT; - let dt_date = dt - remainder; - - // Check each of the possible offsets allowed by the lower/upper windows. - // We know that the lower window is always positive since it was originally - // a u32, so we can use `-lower..upper` here. - for offset in -lower..=upper { - let offset_seconds = offset as i64 * ONE_DAY_IN_SECONDS as i64; - let occurrence = dt_date + offset_seconds; - let col_name = FeatureName::Holiday { - name: name.clone(), - _offset: offset, - }; - let col = this_holiday_features - .entry(col_name.clone()) - .or_insert_with(|| vec![0.0; ds.len()]); - - // Get the indices of the ds column that are 'on holiday'. - // Set the value of the holiday column 1.0 for those dates. - for loc in ds - .iter() - .positions(|x| (x - (x % ONE_DAY_IN_SECONDS_INT)) == occurrence) - { - col[loc] = 1.0; - } + for occurrence in holiday.occurrences { + let col_name = FeatureName::Holiday { name: name.clone() }; + + let col = this_holiday_features + .entry(col_name.clone()) + .or_insert_with(|| vec![0.0; ds.len()]); + + // Get the indices of the ds column that are 'on holiday'. + // Set the value of the holiday column to 1.0 for those dates. + for loc in ds.iter().positions(|&x| occurrence.contains(x)) { + col[loc] = 1.0; } } // Add the holiday column to the features frame, and add a corresponding @@ -1086,6 +1051,7 @@ impl Preprocessed { #[cfg(test)] mod test { use crate::{ + features::HolidayOccurrence, optimizer::mock_optimizer::MockOptimizer, testdata::{daily_univariate_ts, train_test_split}, util::FloatIterExt, @@ -1094,9 +1060,19 @@ mod test { use super::*; use augurs_testing::assert_approx_eq; - use chrono::NaiveDate; + use chrono::{Days, FixedOffset, NaiveDate, TimeZone, Utc}; use pretty_assertions::assert_eq; + macro_rules! concat_all { + ($($x:expr),+ $(,)?) => {{ + let mut result = Vec::new(); + $( + result.extend($x.iter().cloned()); + )+ + result + }}; + } + #[test] fn setup_dataframe() { let (data, _) = train_test_split(daily_univariate_ts(), 0.5); @@ -1213,29 +1189,156 @@ mod test { ); } + #[test] + fn make_holiday_features() { + // Create some hourly data between 2024-01-01 and 2024-01-07. + let start = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let end = Utc.with_ymd_and_hms(2024, 1, 7, 0, 0, 0).unwrap(); + let ds = std::iter::successors(Some(start), |d| { + d.checked_add_signed(chrono::Duration::hours(1)) + }) + .take_while(|d| *d < end) + .map(|d| d.timestamp()) + .collect_vec(); + // Create two holidays: one in UTC on 2024-01-02 and 2024-01-04; + // one in UTC-3 on the same dates. + // The holidays may appear more than once since the data is hourly, + // and this shouldn't affect the results. + // Ignore windows for now. + let non_utc_tz = FixedOffset::west_opt(3600 * 3).unwrap(); + let holidays: HashMap = [ + ( + "UTC holiday".to_string(), + Holiday::new(vec![ + HolidayOccurrence::for_day( + Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0) + .unwrap() + .timestamp(), + ), + HolidayOccurrence::for_day( + Utc.with_ymd_and_hms(2024, 1, 2, 12, 0, 0) + .unwrap() + .timestamp(), + ), + HolidayOccurrence::for_day( + Utc.with_ymd_and_hms(2024, 1, 4, 0, 0, 0) + .unwrap() + .timestamp(), + ), + ]), + ), + ( + "Non-UTC holiday".to_string(), + Holiday::new(vec![ + HolidayOccurrence::for_day_in_tz( + non_utc_tz + .with_ymd_and_hms(2024, 1, 2, 0, 0, 0) + .unwrap() + .timestamp(), + -3 * 3600, + ), + HolidayOccurrence::for_day_in_tz( + non_utc_tz + .with_ymd_and_hms(2024, 1, 2, 12, 0, 0) + .unwrap() + .timestamp(), + -3 * 3600, + ), + HolidayOccurrence::for_day_in_tz( + non_utc_tz + .with_ymd_and_hms(2024, 1, 4, 0, 0, 0) + .unwrap() + .timestamp(), + -3 * 3600, + ), + ]), + ), + ] + .into(); + let opts = ProphetOptions { + holidays: holidays.clone(), + ..Default::default() + }; + let prophet = Prophet::new(opts, MockOptimizer::new()); + let mut features_frame = FeaturesFrame::new(); + let mut prior_scales = Vec::new(); + let mut modes = Modes::default(); + + let holiday_names = prophet.make_holiday_features( + &ds, + holidays, + &mut features_frame, + &mut prior_scales, + &mut modes, + ); + assert_eq!( + holiday_names, + HashSet::from(["UTC holiday".to_string(), "Non-UTC holiday".to_string(),]) + ); + + assert_eq!(features_frame.names.len(), 2); + let utc_idx = features_frame + .names + .iter() + .position(|x| matches!(x, FeatureName::Holiday { name } if name == "UTC holiday")) + .unwrap(); + assert_eq!( + features_frame.data[utc_idx], + concat_all!( + &[0.0; 24], // 2024-01-01 - off holiday + &[1.0; 24], // 2024-01-02 - on holiday + &[0.0; 24], // 2024-01-03 - off holiday + &[1.0; 24], // 2024-01-04 - on holiday + &[0.0; 48], // 2024-01-05 and 2024-01-06 - off holiday + ), + ); + let non_utc_idx = features_frame + .names + .iter() + .position(|x| matches!(x, FeatureName::Holiday { name } if name == "Non-UTC holiday")) + .unwrap(); + assert_eq!( + features_frame.data[non_utc_idx], + concat_all!( + &[0.0; 24], // 2024-01-01 - off holiday + &[0.0; 3], // first 3 hours of 2024-01-02 in UTC are off holiday + &[1.0; 24], // rest of 2024-01-02 in UTC, and first 3 hours of the next day, are on holiday + &[0.0; 24], // continue the cycle... + &[1.0; 24], + &[0.0; 21 + 24], + ), + ); + } + #[test] fn regressor_column_matrix() { let holiday_dates = ["2012-10-09", "2013-10-09"] .iter() - .map(|s| { - s.parse::() - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .and_utc() - .timestamp() + .flat_map(|s| { + [ + HolidayOccurrence::for_day( + s.parse::() + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .timestamp(), + ), + HolidayOccurrence::for_day( + s.parse::() + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .checked_add_days(Days::new(1)) + .unwrap() + .timestamp(), + ), + ] }) .collect(); let opts = ProphetOptions { - holidays: [( - "bens-bday".to_string(), - Holiday::new(holiday_dates) - .with_lower_window(vec![0, 0]) - .unwrap() - .with_upper_window(vec![1, 1]) - .unwrap(), - )] - .into(), + holidays: [("bens-bday".to_string(), Holiday::new(holiday_dates))].into(), ..Default::default() }; let mut prophet = Prophet::new(opts, MockOptimizer::new()); @@ -1296,11 +1399,6 @@ mod test { }, FeatureName::Holiday { name: "bens-bday".to_string(), - _offset: 0, - }, - FeatureName::Holiday { - name: "bens-bday".to_string(), - _offset: 1, }, FeatureName::Regressor("binary_feature".to_string()), FeatureName::Regressor("numeric_feature".to_string()), @@ -1310,46 +1408,43 @@ mod test { &["bens-bday".to_string()].into_iter().collect(), &mut modes, ); - assert_eq!(cols.additive, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1]); - assert_eq!( - cols.multiplicative, - vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] - ); - assert_eq!(cols.all_holidays, vec![0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]); + assert_eq!(cols.additive, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1]); + assert_eq!(cols.multiplicative, vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]); + assert_eq!(cols.all_holidays, vec![0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]); assert_eq!( cols.regressors_additive, - vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1] + vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1] ); assert_eq!( cols.regressors_multiplicative, - vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] + vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] ); assert_eq!(cols.seasonalities.len(), 1); assert_eq!( cols.seasonalities["weekly"], - &[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + &[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] ); assert_eq!(cols.holidays.len(), 1); assert_eq!( cols.holidays["bens-bday"], - &[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0] + &[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0] ); assert_eq!(cols.regressors.len(), 4); assert_eq!( cols.regressors["binary_feature"], - &[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] + &[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] ); assert_eq!( cols.regressors["numeric_feature"], - &[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] + &[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] ); assert_eq!( cols.regressors["numeric_feature2"], - &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] + &[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] ); assert_eq!( cols.regressors["binary_feature2"], - &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] + &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] ); assert_eq!( modes, diff --git a/js/augurs-prophet-js/src/lib.rs b/js/augurs-prophet-js/src/lib.rs index 26dce638..34af3ffe 100644 --- a/js/augurs-prophet-js/src/lib.rs +++ b/js/augurs-prophet-js/src/lib.rs @@ -1217,32 +1217,43 @@ impl From for augurs_prophet::FeatureMode { } } -/// A holiday to be considered by the Prophet model. +/// An occurrence of a holiday. +/// +/// Each occurrence has a start and end time represented as +/// a Unix timestamp. Holiday occurrences are therefore +/// timestamp-unaware and can therefore span multiple days +/// or even sub-daily periods. +/// +/// This differs from the Python and R Prophet implementations, +/// which require all holidays to be day-long events. +/// +/// The caller is responsible for ensuring that the start +/// and end time provided are in the correct timezone. #[derive(Clone, Debug, Deserialize, Tsify)] #[serde(rename_all = "camelCase")] #[tsify(from_wasm_abi, type_prefix = "Prophet")] -pub struct Holiday { - /// The dates of the holiday. - #[tsify(type = "TimestampSeconds[]")] - pub ds: Vec, +pub struct HolidayOccurrence { + /// The start of the holiday, as a Unix timestamp in seconds. + #[tsify(type = "TimestampSeconds")] + pub start: TimestampSeconds, + /// The end of the holiday, as a Unix timestamp in seconds. + #[tsify(type = "TimestampSeconds")] + pub end: TimestampSeconds, +} - /// The lower window for the holiday. - /// - /// The lower window is the number of days before the holiday - /// that it is observed. For example, if the holiday is on - /// 2023-01-01 and the lower window is 1, then the holiday will - /// _also_ be observed on 2022-12-31. - #[tsify(optional)] - pub lower_window: Option>, +impl From for augurs_prophet::HolidayOccurrence { + fn from(value: HolidayOccurrence) -> Self { + Self::new(value.start, value.end) + } +} - /// The upper window for the holiday. - /// - /// The upper window is the number of days after the holiday - /// that it is observed. For example, if the holiday is on - /// 2023-01-01 and the upper window is 1, then the holiday will - /// _also_ be observed on 2023-01-02. - #[tsify(optional)] - pub upper_window: Option>, +/// A holiday to be considered by the Prophet model. +#[derive(Clone, Debug, Deserialize, Tsify)] +#[serde(rename_all = "camelCase")] +#[tsify(from_wasm_abi, type_prefix = "Prophet")] +pub struct Holiday { + /// The occurrences of the holiday. + pub occurrences: Vec, /// The prior scale for the holiday. #[tsify(optional)] @@ -1253,13 +1264,7 @@ impl TryFrom for augurs_prophet::Holiday { type Error = JsError; fn try_from(value: Holiday) -> Result { - let mut holiday = Self::new(value.ds); - if let Some(lower_window) = value.lower_window { - holiday = holiday.with_lower_window(lower_window)?; - } - if let Some(upper_window) = value.upper_window { - holiday = holiday.with_upper_window(upper_window)?; - } + let mut holiday = Self::new(value.occurrences.into_iter().map(|x| x.into()).collect()); if let Some(prior_scale) = value.prior_scale { holiday = holiday.with_prior_scale(prior_scale.try_into()?); } diff --git a/js/testpkg/package-lock.json b/js/testpkg/package-lock.json index 9cd80871..b65ada92 100644 --- a/js/testpkg/package-lock.json +++ b/js/testpkg/package-lock.json @@ -15,7 +15,7 @@ }, "../augurs": { "name": "@bsull/augurs", - "version": "0.5.0", + "version": "0.6.3", "dev": true, "license": "MIT OR Apache-2.0" }, diff --git a/js/testpkg/prophet.test.ts b/js/testpkg/prophet.test.ts index d23b9406..8448e565 100644 --- a/js/testpkg/prophet.test.ts +++ b/js/testpkg/prophet.test.ts @@ -1,7 +1,7 @@ import { webcrypto } from 'node:crypto' import { readFileSync } from "node:fs"; -import { Prophet, initSync } from '@bsull/augurs/prophet'; +import { Prophet, ProphetHoliday, ProphetHolidayOccurrence, initSync } from '@bsull/augurs/prophet'; import { optimizer } from '@bsull/augurs-prophet-wasmstan'; import { describe, expect, it } from 'vitest'; @@ -54,4 +54,16 @@ describe('Prophet', () => { expect(preds.yhat.point).toHaveLength(y.length); expect(preds.yhat.point).toBeInstanceOf(Array); }); + + describe('holidays', () => { + it('can be set', () => { + const occurrences: ProphetHolidayOccurrence[] = [ + { start: new Date('2024-12-25').getTime() / 1000, end: new Date('2024-12-26').getTime() / 1000 }, + ] + const holidays: Map = new Map([ + ["Christmas", { occurrences }], + ]); + new Prophet({ optimizer, holidays }); + }); + }) }); diff --git a/justfile b/justfile index d443f6f6..0601fcb0 100644 --- a/justfile +++ b/justfile @@ -40,7 +40,7 @@ doctest: --exclude pyaugurs \ doc: - cargo doc --all-features --workspace --exclude augurs-js --exclude pyaugurs --open + cargo doc --all-features --workspace --exclude *-js --exclude pyaugurs --open watch: bacon