diff --git a/crates/augurs-prophet/src/features.rs b/crates/augurs-prophet/src/features.rs index bde803b9..5e5fb23e 100644 --- a/crates/augurs-prophet/src/features.rs +++ b/crates/augurs-prophet/src/features.rs @@ -17,8 +17,8 @@ pub enum FeatureMode { #[derive(Debug, Clone)] pub struct Holiday { pub(crate) ds: Vec, - pub(crate) lower_window: Option>, - pub(crate) upper_window: Option>, + pub(crate) lower_window: Option>, + pub(crate) upper_window: Option>, pub(crate) prior_scale: Option, } @@ -37,9 +37,9 @@ impl Holiday { /// /// The lower window is the number of days before the holiday /// that it is observed. For example, if the holiday is on - /// 2023-01-01 and the lower window is -1, then the holiday will + /// 2023-01-01 and the lower window is 1, then the holiday will /// _also_ be observed on 2022-12-31. - pub fn with_lower_window(mut self, lower_window: Vec) -> Result { + pub fn with_lower_window(mut self, lower_window: Vec) -> Result { if self.ds.len() != lower_window.len() { return Err(Error::MismatchedLengths { a_name: "ds".to_string(), @@ -58,7 +58,7 @@ impl Holiday { /// that it is observed. For example, if the holiday is on /// 2023-01-01 and the upper window is 1, then the holiday will /// _also_ be observed on 2023-01-02. - pub fn with_upper_window(mut self, upper_window: Vec) -> Result { + pub fn with_upper_window(mut self, upper_window: Vec) -> Result { if self.ds.len() != upper_window.len() { return Err(Error::MismatchedLengths { a_name: "ds".to_string(), diff --git a/crates/augurs-prophet/src/prophet/prep.rs b/crates/augurs-prophet/src/prophet/prep.rs index 5893c878..2fb5e2f5 100644 --- a/crates/augurs-prophet/src/prophet/prep.rs +++ b/crates/augurs-prophet/src/prophet/prep.rs @@ -664,12 +664,16 @@ impl Prophet { let lower = holiday .lower_window .as_ref() - .map(|x| Box::new(x.iter().copied()) as Box>) + .map(|x| { + Box::new(x.iter().copied().map(|x| x as i32)) as Box> + }) .unwrap_or_else(|| Box::new(std::iter::repeat(0))); let upper = holiday .upper_window .as_ref() - .map(|x| Box::new(x.iter().copied()) as Box>) + .map(|x| { + Box::new(x.iter().copied().map(|x| x as i32)) as Box> + }) .unwrap_or_else(|| Box::new(std::iter::repeat(0))); for (dt, lower, upper) in izip!(holiday.ds, lower, upper) { @@ -678,7 +682,9 @@ impl Prophet { let dt_date = dt - remainder; // Check each of the possible offsets allowed by the lower/upper windows. - for offset in lower..=upper { + // We know that the lower window is always positive since it was originally + // a u32, so we can use `-lower..upper` here. + for offset in -lower..=upper { let offset_seconds = offset as i64 * ONE_DAY_IN_SECONDS as i64; let occurrence = dt_date + offset_seconds; let col_name = FeatureName::Holiday { diff --git a/js/augurs-prophet-js/src/lib.rs b/js/augurs-prophet-js/src/lib.rs index 633c9e96..26dce638 100644 --- a/js/augurs-prophet-js/src/lib.rs +++ b/js/augurs-prophet-js/src/lib.rs @@ -1230,10 +1230,10 @@ pub struct Holiday { /// /// The lower window is the number of days before the holiday /// that it is observed. For example, if the holiday is on - /// 2023-01-01 and the lower window is -1, then the holiday will + /// 2023-01-01 and the lower window is 1, then the holiday will /// _also_ be observed on 2022-12-31. #[tsify(optional)] - pub lower_window: Option>, + pub lower_window: Option>, /// The upper window for the holiday. /// @@ -1242,7 +1242,7 @@ pub struct Holiday { /// 2023-01-01 and the upper window is 1, then the holiday will /// _also_ be observed on 2023-01-02. #[tsify(optional)] - pub upper_window: Option>, + pub upper_window: Option>, /// The prior scale for the holiday. #[tsify(optional)]