Skip to content

Commit

Permalink
feat!(prophet): support sub-daily & non-UTC holidays (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Nov 22, 2024
1 parent 856be42 commit 6224a6a
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 188 deletions.
278 changes: 230 additions & 48 deletions crates/augurs-prophet/src/features.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! Features used by Prophet, such as seasonality, regressors and holidays.
use std::num::NonZeroU32;

use crate::{positive_float::PositiveFloat, Error, TimestampSeconds};
use crate::{
positive_float::PositiveFloat, prophet::prep::ONE_DAY_IN_SECONDS_INT, TimestampSeconds,
};

/// The mode of a seasonality, regressor, or holiday.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
Expand All @@ -13,71 +15,127 @@ pub enum FeatureMode {
Multiplicative,
}

/// A holiday to be considered by the Prophet model.
/// An occurrence of a holiday.
///
/// Each occurrence has a start and end time, each represented
/// as a Unix timestamp. Holiday occurrences are therefore
/// timezone-unaware, and can span multiple days
/// or even sub-daily periods.
///
/// This differs from the Python and R Prophet implementations,
/// which require all holidays to be day-long events. Some
/// convenience methods are provided to create day-long
/// occurrences: see [`HolidayOccurrence::for_day`] and
/// [`HolidayOccurrence::for_day_in_tz`].
///
/// The caller is responsible for ensuring that the start
/// and end time provided are in the correct timezone.
/// One way to do this is to use [`chrono::FixedOffset`][fo]
/// to create an offset representing the time zone,
/// [`FixedOffset::with_ymd_and_hms`][wyah] to create a
/// [`DateTime`][dt] in that time zone, then [`DateTime::timestamp`][ts]
/// to get the Unix timestamp.
///
/// [fo]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html
/// [wyah]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html#method.with_ymd_and_hms
/// [dt]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html
/// [ts]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html#method.timestamp
#[derive(Debug, Clone)]
pub struct Holiday {
pub(crate) ds: Vec<TimestampSeconds>,
pub(crate) lower_window: Option<Vec<u32>>,
pub(crate) upper_window: Option<Vec<u32>>,
pub(crate) prior_scale: Option<PositiveFloat>,
pub struct HolidayOccurrence {
pub(crate) start: TimestampSeconds,
pub(crate) end: TimestampSeconds,
}

impl Holiday {
/// Create a new holiday.
pub fn new(ds: Vec<TimestampSeconds>) -> Self {
Self {
ds,
lower_window: None,
upper_window: None,
prior_scale: None,
}
impl HolidayOccurrence {
/// Create a new holiday occurrence with the given
/// start and end timestamps.
///
/// Both values are Unix timestamps in seconds. The start is
/// inclusive and the end is exclusive: the crate-internal
/// `contains` check treats the occurrence as the half-open
/// interval `[start, end)`.
///
/// No validation is performed here; if `end <= start` that
/// check can never succeed for any timestamp.
pub fn new(start: TimestampSeconds, end: TimestampSeconds) -> Self {
Self { start, end }
}

/// Set the lower window for the holiday.
/// Create a new holiday occurrence spanning from midnight on
/// the day of the given timestamp to midnight on the following
/// day, in UTC.
///
/// The lower window is the number of days before the holiday
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the lower window is 1, then the holiday will
/// _also_ be observed on 2022-12-31.
pub fn with_lower_window(mut self, lower_window: Vec<u32>) -> Result<Self, Error> {
if self.ds.len() != lower_window.len() {
return Err(Error::MismatchedLengths {
a_name: "ds".to_string(),
a: self.ds.len(),
b_name: "lower_window".to_string(),
b: lower_window.len(),
});
}
self.lower_window = Some(lower_window);
Ok(self)
/// This is a convenience method to reproduce the behaviour
/// of the Python and R Prophet implementations, which require
/// all holidays to be day-long events.
///
/// Note that this will _not_ handle daylight saving time
/// transitions correctly. To handle this correctly, use
/// [`HolidayOccurrence::new`] with the correct start and
/// end times, e.g. by calculating them using [`chrono`].
///
/// [`chrono`]: https://docs.rs/chrono/latest/chrono
pub fn for_day(day: TimestampSeconds) -> Self {
    // UTC is just the timezone whose offset from UTC is zero seconds,
    // so defer to the timezone-aware constructor.
    const UTC_OFFSET_SECONDS: i32 = 0;
    Self::for_day_in_tz(day, UTC_OFFSET_SECONDS)
}

/// Set the upper window for the holiday.
/// Create a new holiday occurrence spanning from midnight on
/// the day of the given timestamp to midnight on the following
/// day, in the timezone represented by `utc_offset_seconds`.
///
/// The UTC offset can be calculated using, for example,
/// [`chrono::FixedOffset::local_minus_utc`]. Alternatively,
/// it's the number of seconds to add to convert from
/// UTC to the local time, so UTC+1 is represented by `3600`
/// and UTC-5 by `-18000`.
///
/// This is a convenience method to reproduce the behaviour
/// of the Python and R Prophet implementations, which require
/// all holidays to be day-long events.
///
/// The upper window is the number of days after the holiday
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the upper window is 1, then the holiday will
/// _also_ be observed on 2023-01-02.
pub fn with_upper_window(mut self, upper_window: Vec<u32>) -> Result<Self, Error> {
if self.ds.len() != upper_window.len() {
return Err(Error::MismatchedLengths {
a_name: "ds".to_string(),
a: self.ds.len(),
b_name: "upper_window".to_string(),
b: upper_window.len(),
});
/// Note that this will _not_ handle daylight saving time
/// transitions correctly. To handle this correctly, use
/// [`HolidayOccurrence::new`] with the correct start and
/// end times, e.g. by calculating them using [`chrono`].
///
/// [`chrono`]: https://docs.rs/chrono/latest/chrono
pub fn for_day_in_tz(day: TimestampSeconds, utc_offset_seconds: i32) -> Self {
    // Snap back to midnight in the requested timezone, then cover
    // exactly one fixed-length day from there.
    let start = floor_day(day, utc_offset_seconds);
    let end = start + ONE_DAY_IN_SECONDS_INT;
    Self { start, end }
}

/// Report whether `ds` falls inside this occurrence.
///
/// The occurrence is treated as the half-open interval
/// `[start, end)`: the start timestamp is included and the
/// end timestamp is not.
pub(crate) fn contains(&self, ds: TimestampSeconds) -> bool {
    (self.start..self.end).contains(&ds)
}
}

/// A holiday to be considered by the Prophet model.
///
/// A holiday is made up of a list of [`HolidayOccurrence`]s and
/// an optional prior scale.
#[derive(Debug, Clone)]
pub struct Holiday {
/// The occurrences that make up this holiday.
pub(crate) occurrences: Vec<HolidayOccurrence>,
/// The prior scale for this holiday; `None` until one is set with
/// `with_prior_scale`.
// NOTE(review): how `None` is resolved to a default is decided by the
// code that reads this field, which is not in this file — confirm there.
pub(crate) prior_scale: Option<PositiveFloat>,
}

impl Holiday {
/// Create a new holiday with the given occurrences.
pub fn new(occurrences: Vec<HolidayOccurrence>) -> Self {
Self {
occurrences,
prior_scale: None,
}
self.upper_window = Some(upper_window);
Ok(self)
}

/// Add a prior scale for the holiday.
/// Set the prior scale for the holiday.
pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
self.prior_scale = Some(prior_scale);
self
}
}

fn floor_day(ds: TimestampSeconds, offset: i32) -> TimestampSeconds {
let adjusted_ds = ds + offset as TimestampSeconds;
let remainder =
((adjusted_ds % ONE_DAY_IN_SECONDS_INT) + ONE_DAY_IN_SECONDS_INT) % ONE_DAY_IN_SECONDS_INT;
// Adjust the date to the holiday's UTC offset.
ds - remainder
}

/// Whether or not to standardize a regressor.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Standardize {
Expand Down Expand Up @@ -232,3 +290,127 @@ impl Seasonality {
self
}
}

#[cfg(test)]
mod test {
    use chrono::{FixedOffset, TimeZone, Utc};

    use super::floor_day;

    /// The calendar date used by most of the tests below.
    const NOV_21_2024: (i32, u32, u32) = (2024, 11, 21);
    /// The start of the day.
    const MIDNIGHT: (u32, u32, u32) = (0, 0, 0);

    /// Unix timestamp of the given wall-clock date and time in `tz`.
    fn timestamp_in<Tz: TimeZone>(tz: &Tz, date: (i32, u32, u32), time: (u32, u32, u32)) -> i64 {
        let (year, month, day) = date;
        let (hour, minute, second) = time;
        tz.with_ymd_and_hms(year, month, day, hour, minute, second)
            .unwrap()
            .timestamp()
    }

    #[test]
    fn floor_day_no_offset() {
        let expected = timestamp_in(&Utc, NOV_21_2024, MIDNIGHT);
        let afternoon = timestamp_in(&Utc, NOV_21_2024, (15, 3, 12));

        // Midnight is already floored; any later time that day floors to it.
        assert_eq!(floor_day(expected, 0), expected);
        assert_eq!(floor_day(afternoon, 0), expected);
    }

    #[test]
    fn floor_day_positive_offset() {
        let tz = FixedOffset::east_opt(60 * 60 * 4).unwrap();
        let expected = timestamp_in(&tz, NOV_21_2024, MIDNIGHT);
        let afternoon = timestamp_in(&tz, NOV_21_2024, (15, 3, 12));

        assert_eq!(floor_day(expected, tz.local_minus_utc()), expected);
        assert_eq!(floor_day(afternoon, tz.local_minus_utc()), expected);
    }

    #[test]
    fn floor_day_negative_offset() {
        let tz = FixedOffset::west_opt(60 * 60 * 3).unwrap();
        let expected = timestamp_in(&tz, NOV_21_2024, MIDNIGHT);
        let afternoon = timestamp_in(&tz, NOV_21_2024, (15, 3, 12));

        assert_eq!(floor_day(expected, tz.local_minus_utc()), expected);
        assert_eq!(floor_day(afternoon, tz.local_minus_utc()), expected);
    }

    #[test]
    fn floor_day_edge_cases() {
        // Test maximum valid offset (UTC+14)
        let max_offset = 14 * 60 * 60;
        let tz = FixedOffset::east_opt(max_offset).unwrap();
        let expected = timestamp_in(&tz, NOV_21_2024, MIDNIGHT);
        let noon = timestamp_in(&tz, NOV_21_2024, (12, 0, 0));
        assert_eq!(floor_day(noon, tz.local_minus_utc()), expected);

        // Test near day boundary
        let tz = FixedOffset::east_opt(60).unwrap();
        let expected = timestamp_in(&tz, NOV_21_2024, MIDNIGHT);
        let last_second = timestamp_in(&tz, NOV_21_2024, (23, 59, 59));
        assert_eq!(floor_day(last_second, tz.local_minus_utc()), expected);

        // Test when the day is before the epoch.
        let tz = FixedOffset::west_opt(3600).unwrap();
        let new_year_1969 = (1969, 1, 1);
        let expected = timestamp_in(&tz, new_year_1969, MIDNIGHT);
        let half_past_midnight = timestamp_in(&tz, new_year_1969, (0, 30, 0));
        assert_eq!(
            floor_day(half_past_midnight, tz.local_minus_utc()),
            expected
        );
    }
}
2 changes: 1 addition & 1 deletion crates/augurs-prophet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub type TimestampSeconds = i64;
// navigate the module hierarchy.
pub use data::{PredictionData, TrainingData};
pub use error::Error;
pub use features::{FeatureMode, Holiday, Regressor, Seasonality, Standardize};
pub use features::{FeatureMode, Holiday, HolidayOccurrence, Regressor, Seasonality, Standardize};
pub use optimizer::{Algorithm, Optimizer, TrendIndicator};
pub use positive_float::{PositiveFloat, TryFromFloatError};
pub use prophet::{
Expand Down
44 changes: 20 additions & 24 deletions crates/augurs-prophet/src/prophet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ mod test_custom_seasonal {
optimizer::mock_optimizer::MockOptimizer,
prophet::prep::{FeatureName, Features},
testdata::daily_univariate_ts,
FeatureMode, Holiday, ProphetOptions, Seasonality, SeasonalityOption,
FeatureMode, Holiday, HolidayOccurrence, ProphetOptions, Seasonality, SeasonalityOption,
};

use super::Prophet;
Expand All @@ -534,12 +534,14 @@ mod test_custom_seasonal {
let holiday_dates = ["2017-01-02"]
.iter()
.map(|s| {
s.parse::<NaiveDate>()
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap()
.and_utc()
.timestamp()
HolidayOccurrence::for_day(
s.parse::<NaiveDate>()
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap()
.and_utc()
.timestamp(),
)
})
.collect();

Expand Down Expand Up @@ -703,33 +705,27 @@ mod test_holidays {
use chrono::NaiveDate;

use crate::{
optimizer::mock_optimizer::MockOptimizer, testdata::daily_univariate_ts, Holiday, Prophet,
ProphetOptions,
optimizer::mock_optimizer::MockOptimizer, testdata::daily_univariate_ts, Holiday,
HolidayOccurrence, Prophet, ProphetOptions,
};

#[test]
fn fit_predict_holiday() {
let holiday_dates = ["2012-10-09", "2013-10-09"]
.iter()
.map(|s| {
s.parse::<NaiveDate>()
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap()
.and_utc()
.timestamp()
HolidayOccurrence::for_day(
s.parse::<NaiveDate>()
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap()
.and_utc()
.timestamp(),
)
})
.collect();
let opts = ProphetOptions {
holidays: [(
"bens-bday".to_string(),
Holiday::new(holiday_dates)
.with_lower_window(vec![0, 0])
.unwrap()
.with_upper_window(vec![1, 1])
.unwrap(),
)]
.into(),
holidays: [("bens-bday".to_string(), Holiday::new(holiday_dates))].into(),
..Default::default()
};
let data = daily_univariate_ts();
Expand Down
Loading

0 comments on commit 6224a6a

Please sign in to comment.