diff --git a/crates/augurs-prophet/src/prophet/prep.rs b/crates/augurs-prophet/src/prophet/prep.rs index 2fb5e2f..0ba78ca 100644 --- a/crates/augurs-prophet/src/prophet/prep.rs +++ b/crates/augurs-prophet/src/prophet/prep.rs @@ -121,7 +121,7 @@ impl ComponentColumns { } /// The name of a feature column in the `X` matrix passed to Stan. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(super) enum FeatureName { /// A seasonality feature. Seasonality { @@ -660,6 +660,12 @@ impl Prophet { ) -> HashSet { let mut holiday_names = HashSet::with_capacity(holidays.len()); for (name, holiday) in holidays { + // Keep track of holiday columns here. + // For each day surrounding the holiday (decided by the lower and upper windows), + // plus the holiday itself, we want to create a new feature which is 0.0 for all + // days except that day, and 1.0 for that day. + let mut this_holiday_features: HashMap> = HashMap::new(); + // Default to a window of 0 days either side. let lower = holiday .lower_window @@ -691,7 +697,9 @@ impl Prophet { name: name.clone(), _offset: offset, }; - let mut col = vec![0.0; ds.len()]; + let col = this_holiday_features + .entry(col_name.clone()) + .or_insert_with(|| vec![0.0; ds.len()]); // Get the indices of the ds column that are 'on holiday'. // Set the value of the holiday column 1.0 for those dates. @@ -701,16 +709,18 @@ impl Prophet { { col[loc] = 1.0; } - // Add the holiday column to the features frame, and add a corresponding - // prior scale. - features.push(col_name, col); - prior_scales.push( - holiday - .prior_scale - .unwrap_or(self.opts.holidays_prior_scale), - ); } } + // Add the holiday column to the features frame, and add a corresponding + // prior scale. + for (col_name, col) in this_holiday_features.drain() { + features.push(col_name, col); + prior_scales.push( + holiday + .prior_scale + .unwrap_or(self.opts.holidays_prior_scale), + ); + } holiday_names.insert(name.clone()); modes.insert( self.opts