Skip to content

Commit

Permalink
fix: add a separate feature for each holiday's lower/upper windows (#179
Browse files Browse the repository at this point in the history
)
  • Loading branch information
sd2k authored Nov 21, 2024
1 parent 4c58635 commit 856be42
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions crates/augurs-prophet/src/prophet/prep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl ComponentColumns {
}

/// The name of a feature column in the `X` matrix passed to Stan.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(super) enum FeatureName {
/// A seasonality feature.
Seasonality {
Expand Down Expand Up @@ -660,6 +660,12 @@ impl<O> Prophet<O> {
) -> HashSet<String> {
let mut holiday_names = HashSet::with_capacity(holidays.len());
for (name, holiday) in holidays {
// Keep track of holiday columns here.
// For each day surrounding the holiday (decided by the lower and upper windows),
// plus the holiday itself, we want to create a new feature which is 0.0 for all
// days except that day, and 1.0 for that day.
let mut this_holiday_features: HashMap<FeatureName, Vec<f64>> = HashMap::new();

// Default to a window of 0 days either side.
let lower = holiday
.lower_window
Expand Down Expand Up @@ -691,7 +697,9 @@ impl<O> Prophet<O> {
name: name.clone(),
_offset: offset,
};
let mut col = vec![0.0; ds.len()];
let col = this_holiday_features
.entry(col_name.clone())
.or_insert_with(|| vec![0.0; ds.len()]);

// Get the indices of the ds column that are 'on holiday'.
// Set the value of the holiday column 1.0 for those dates.
Expand All @@ -701,16 +709,18 @@ impl<O> Prophet<O> {
{
col[loc] = 1.0;
}
// Add the holiday column to the features frame, and add a corresponding
// prior scale.
features.push(col_name, col);
prior_scales.push(
holiday
.prior_scale
.unwrap_or(self.opts.holidays_prior_scale),
);
}
}
// Add the holiday column to the features frame, and add a corresponding
// prior scale.
for (col_name, col) in this_holiday_features.drain() {
features.push(col_name, col);
prior_scales.push(
holiday
.prior_scale
.unwrap_or(self.opts.holidays_prior_scale),
);
}
holiday_names.insert(name.clone());
modes.insert(
self.opts
Expand Down

0 comments on commit 856be42

Please sign in to comment.