Skip to content

Commit

Permalink
Refactor predict_features
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k committed Oct 10, 2024
1 parent 6c4aec1 commit 467a531
Showing 1 changed file with 35 additions and 61 deletions.
96 changes: 35 additions & 61 deletions crates/augurs-prophet/src/prophet/predict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
optimizer::OptimizedParams, util::FloatIterExt, Error, GrowthType, Prophet, TimestampSeconds,
};

use super::prep::{ComponentName, Features, FeaturesFrame, ProcessedData};
use super::prep::{ComponentName, Features, FeaturesFrame, Modes, ProcessedData};

/// The prediction for a feature.
///
Expand Down Expand Up @@ -252,6 +252,9 @@ impl<O> Prophet<O> {
modes,
..
} = features;
let predict_feature = |col, f: fn(String) -> ComponentName| {
Self::predict_components(col, &features.data, &params.beta, y_scale, modes, f)
};
Ok(FeaturePredictions {
additive: Self::predict_feature(
&component_columns.additive,
Expand All @@ -267,69 +270,40 @@ impl<O> Prophet<O> {
y_scale,
false,
),
holidays: component_columns
.holidays
.iter()
.map(|(name, holiday)| {
(
name.clone(),
Self::predict_feature(
holiday,
&features.data,
&params.beta,
y_scale,
modes
.additive
// Annoying that we have to clone here, we could work around it
// if we made a `ComponentNameRef` type but that's a lot of work.
.contains(&ComponentName::Holiday(name.clone())),
),
)
})
.collect(),
seasonalities: component_columns
.seasonalities
.iter()
.map(|(name, seasonality)| {
(
name.clone(),
Self::predict_feature(
seasonality,
&features.data,
&params.beta,
y_scale,
modes
.additive
// Annoying that we have to clone here, we could work around it
// if we made a `ComponentNameRef` type but that's a lot of work.
.contains(&ComponentName::Holiday(name.clone())),
),
)
})
.collect(),
regressors: component_columns
.regressors
.iter()
.map(|(name, regressor)| {
(
name.clone(),
Self::predict_feature(
regressor,
&features.data,
&params.beta,
y_scale,
modes
.additive
// Annoying that we have to clone here, we could work around it
// if we made a `ComponentNameRef` type but that's a lot of work.
.contains(&ComponentName::Holiday(name.clone())),
),
)
})
.collect(),
holidays: predict_feature(&component_columns.holidays, ComponentName::Holiday),
seasonalities: predict_feature(
&component_columns.seasonalities,
ComponentName::Seasonality,
),
regressors: predict_feature(&component_columns.regressors, ComponentName::Regressor),
})
}

fn predict_components(
component_columns: &HashMap<String, Vec<i32>>,
#[allow(non_snake_case)] X: &[Vec<f64>],
beta: &[f64],
y_scale: f64,
modes: &Modes,
make_mode: impl Fn(String) -> ComponentName,
) -> HashMap<String, FeaturePrediction> {
component_columns
.iter()
.map(|(name, component_col)| {
(
name.clone(),
Self::predict_feature(
component_col,
X,
beta,
y_scale,
modes.additive.contains(&make_mode(name.clone())),
),
)
})
.collect()
}

pub(super) fn predict_feature(
component_col: &[i32],
#[allow(non_snake_case)] X: &[Vec<f64>],
Expand Down

0 comments on commit 467a531

Please sign in to comment.