Skip to content

Commit

Permalink
change: use u32 instead of i32 for lower/upper windows
Browse files Browse the repository at this point in the history
The Python Prophet library expects the lower window to be negative
and the upper window to be positive, but here it makes more sense
to restrict both to being positive and just assume that the lower
window refers to 'before the holiday date'.
  • Loading branch information
sd2k committed Nov 21, 2024
1 parent effd7ee commit 157eb4a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
10 changes: 5 additions & 5 deletions crates/augurs-prophet/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub enum FeatureMode {
#[derive(Debug, Clone)]
pub struct Holiday {
pub(crate) ds: Vec<TimestampSeconds>,
pub(crate) lower_window: Option<Vec<i32>>,
pub(crate) upper_window: Option<Vec<i32>>,
pub(crate) lower_window: Option<Vec<u32>>,
pub(crate) upper_window: Option<Vec<u32>>,
pub(crate) prior_scale: Option<PositiveFloat>,
}

Expand All @@ -37,9 +37,9 @@ impl Holiday {
///
/// The lower window is the number of days before the holiday
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the lower window is -1, then the holiday will
/// 2023-01-01 and the lower window is 1, then the holiday will
/// _also_ be observed on 2022-12-31.
pub fn with_lower_window(mut self, lower_window: Vec<i32>) -> Result<Self, Error> {
pub fn with_lower_window(mut self, lower_window: Vec<u32>) -> Result<Self, Error> {
if self.ds.len() != lower_window.len() {
return Err(Error::MismatchedLengths {
a_name: "ds".to_string(),
Expand All @@ -58,7 +58,7 @@ impl Holiday {
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the upper window is 1, then the holiday will
/// _also_ be observed on 2023-01-02.
pub fn with_upper_window(mut self, upper_window: Vec<i32>) -> Result<Self, Error> {
pub fn with_upper_window(mut self, upper_window: Vec<u32>) -> Result<Self, Error> {
if self.ds.len() != upper_window.len() {
return Err(Error::MismatchedLengths {
a_name: "ds".to_string(),
Expand Down
12 changes: 9 additions & 3 deletions crates/augurs-prophet/src/prophet/prep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,12 +664,16 @@ impl<O> Prophet<O> {
let lower = holiday
.lower_window
.as_ref()
.map(|x| Box::new(x.iter().copied()) as Box<dyn Iterator<Item = i32>>)
.map(|x| {
Box::new(x.iter().copied().map(|x| x as i32)) as Box<dyn Iterator<Item = i32>>
})
.unwrap_or_else(|| Box::new(std::iter::repeat(0)));
let upper = holiday
.upper_window
.as_ref()
.map(|x| Box::new(x.iter().copied()) as Box<dyn Iterator<Item = i32>>)
.map(|x| {
Box::new(x.iter().copied().map(|x| x as i32)) as Box<dyn Iterator<Item = i32>>
})
.unwrap_or_else(|| Box::new(std::iter::repeat(0)));

for (dt, lower, upper) in izip!(holiday.ds, lower, upper) {
Expand All @@ -678,7 +682,9 @@ impl<O> Prophet<O> {
let dt_date = dt - remainder;

// Check each of the possible offsets allowed by the lower/upper windows.
for offset in lower..=upper {
// We know that the lower window is always positive since it was originally
// a u32, so we can use `-lower..upper` here.
for offset in -lower..=upper {
let offset_seconds = offset as i64 * ONE_DAY_IN_SECONDS as i64;
let occurrence = dt_date + offset_seconds;
let col_name = FeatureName::Holiday {
Expand Down
6 changes: 3 additions & 3 deletions js/augurs-prophet-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1230,10 +1230,10 @@ pub struct Holiday {
///
/// The lower window is the number of days before the holiday
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the lower window is -1, then the holiday will
/// 2023-01-01 and the lower window is 1, then the holiday will
/// _also_ be observed on 2022-12-31.
#[tsify(optional)]
pub lower_window: Option<Vec<i32>>,
pub lower_window: Option<Vec<u32>>,

/// The upper window for the holiday.
///
Expand All @@ -1242,7 +1242,7 @@ pub struct Holiday {
/// 2023-01-01 and the upper window is 1, then the holiday will
/// _also_ be observed on 2023-01-02.
#[tsify(optional)]
pub upper_window: Option<Vec<i32>>,
pub upper_window: Option<Vec<u32>>,

/// The prior scale for the holiday.
#[tsify(optional)]
Expand Down

0 comments on commit 157eb4a

Please sign in to comment.