Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change: use u32 instead of i32 for lower/upper windows #177

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions crates/augurs-prophet/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub enum FeatureMode {
#[derive(Debug, Clone)]
pub struct Holiday {
pub(crate) ds: Vec<TimestampSeconds>,
pub(crate) lower_window: Option<Vec<i32>>,
pub(crate) upper_window: Option<Vec<i32>>,
pub(crate) lower_window: Option<Vec<u32>>,
pub(crate) upper_window: Option<Vec<u32>>,
pub(crate) prior_scale: Option<PositiveFloat>,
}

Expand All @@ -37,9 +37,9 @@ impl Holiday {
///
/// The lower window is the number of days before the holiday
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the lower window is -1, then the holiday will
/// 2023-01-01 and the lower window is 1, then the holiday will
/// _also_ be observed on 2022-12-31.
pub fn with_lower_window(mut self, lower_window: Vec<i32>) -> Result<Self, Error> {
pub fn with_lower_window(mut self, lower_window: Vec<u32>) -> Result<Self, Error> {
if self.ds.len() != lower_window.len() {
return Err(Error::MismatchedLengths {
a_name: "ds".to_string(),
Expand All @@ -58,7 +58,7 @@ impl Holiday {
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the upper window is 1, then the holiday will
/// _also_ be observed on 2023-01-02.
pub fn with_upper_window(mut self, upper_window: Vec<i32>) -> Result<Self, Error> {
pub fn with_upper_window(mut self, upper_window: Vec<u32>) -> Result<Self, Error> {
if self.ds.len() != upper_window.len() {
return Err(Error::MismatchedLengths {
a_name: "ds".to_string(),
Expand Down
12 changes: 9 additions & 3 deletions crates/augurs-prophet/src/prophet/prep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,12 +664,16 @@ impl<O> Prophet<O> {
let lower = holiday
.lower_window
.as_ref()
.map(|x| Box::new(x.iter().copied()) as Box<dyn Iterator<Item = i32>>)
.map(|x| {
Box::new(x.iter().copied().map(|x| x as i32)) as Box<dyn Iterator<Item = i32>>
})
.unwrap_or_else(|| Box::new(std::iter::repeat(0)));
let upper = holiday
.upper_window
.as_ref()
.map(|x| Box::new(x.iter().copied()) as Box<dyn Iterator<Item = i32>>)
.map(|x| {
Box::new(x.iter().copied().map(|x| x as i32)) as Box<dyn Iterator<Item = i32>>
})
.unwrap_or_else(|| Box::new(std::iter::repeat(0)));

for (dt, lower, upper) in izip!(holiday.ds, lower, upper) {
Expand All @@ -678,7 +682,9 @@ impl<O> Prophet<O> {
let dt_date = dt - remainder;

// Check each of the possible offsets allowed by the lower/upper windows.
for offset in lower..=upper {
// We know that the lower window is always positive since it was originally
// a u32, so we can use `-lower..upper` here.
for offset in -lower..=upper {
let offset_seconds = offset as i64 * ONE_DAY_IN_SECONDS as i64;
let occurrence = dt_date + offset_seconds;
let col_name = FeatureName::Holiday {
Expand Down
6 changes: 3 additions & 3 deletions js/augurs-prophet-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1230,10 +1230,10 @@ pub struct Holiday {
///
/// The lower window is the number of days before the holiday
/// that it is observed. For example, if the holiday is on
/// 2023-01-01 and the lower window is -1, then the holiday will
/// 2023-01-01 and the lower window is 1, then the holiday will
/// _also_ be observed on 2022-12-31.
#[tsify(optional)]
pub lower_window: Option<Vec<i32>>,
pub lower_window: Option<Vec<u32>>,

/// The upper window for the holiday.
///
Expand All @@ -1242,7 +1242,7 @@ pub struct Holiday {
/// 2023-01-01 and the upper window is 1, then the holiday will
/// _also_ be observed on 2023-01-02.
#[tsify(optional)]
pub upper_window: Option<Vec<i32>>,
pub upper_window: Option<Vec<u32>>,

/// The prior scale for the holiday.
#[tsify(optional)]
Expand Down
Loading