Skip to content

Commit

Permalink
update optimize_lambda to handle errors
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardcqian committed Dec 9, 2024
1 parent 214bee9 commit 3d5ea97
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 12 deletions.
58 changes: 47 additions & 11 deletions crates/augurs-forecaster/src/power_transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ use crate::transforms::box_cox;
use argmin::core::*;
use argmin::solver::brent::BrentOpt;

fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> f64 {
fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> Result<f64, Error> {
let n = data.len() as f64;
assert!(n > 0.0, "Data must not be empty");
if n == 0.0 {
return Err(Error::msg("Data must not be empty"));
}
if data.iter().any(|&x| x <= 0.0) {
return Err(Error::msg("All data must be greater than 0"));
}
let transformed_data: Vec<f64> = data.iter().map(|&x| box_cox(x, lambda)).collect();
let mean_transformed: f64 = transformed_data.iter().copied().sum::<f64>() / n;
let variance: f64 = transformed_data
Expand All @@ -14,9 +19,12 @@ fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> f64 {
/ n;

// Avoid log(0) by ensuring variance is positive
if variance <= 0.0 {
return Err(Error::msg("Variance must be positive"));
}
let log_likelihood =
-0.5 * n * variance.ln() + (lambda - 1.0) * data.iter().map(|&x| x.ln()).sum::<f64>();
log_likelihood
Ok(log_likelihood)
}

#[derive(Clone)]
Expand All @@ -30,12 +38,12 @@ impl CostFunction for BoxCoxProblem<'_> {

// The goal is to minimize the negative log-likelihood
fn cost(&self, lambda: &Self::Param) -> Result<Self::Output, Error> {
Ok(-box_cox_log_likelihood(&self.data, *lambda))
box_cox_log_likelihood(&self.data, *lambda).map(|ll| -ll)
}
}

/// Optimize the lambda parameter for the Box-Cox transformation
pub (crate) fn optimize_lambda(data: &[f64]) -> f64 {
pub (crate) fn optimize_lambda(data: &[f64]) -> Result<f64, Error> {
let cost = BoxCoxProblem { data: data };
let init_param = 0.5;
let solver = BrentOpt::new(-2.0, 2.0);
Expand All @@ -44,10 +52,12 @@ pub (crate) fn optimize_lambda(data: &[f64]) -> f64 {
.configure(|state| state.param(init_param).max_iters(100))
.run();

match result {
Ok(result) => result.state().best_param.unwrap(),
Err(error) => panic!("Optimization failed: {}", error),
}
result
.and_then(|res| {
res.state()
.best_param
.ok_or_else(|| Error::msg("No best parameter found"))
})
}

#[cfg(test)]
Expand All @@ -59,14 +69,40 @@ mod test {
fn correct_optimal_lambda() {
let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_lambda(data);
assert_approx_eq!(got, 0.7123778635679304);
assert!(got.is_ok());
let lambda = got.unwrap();
assert_approx_eq!(lambda, 0.7123778635679304);
}

#[test]
fn optimize_lambda_empty_data() {
let data = &[];
let got = optimize_lambda(data);
assert!(got.is_err());
}

#[test]
fn optimize_lambda_non_positive_data() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_lambda(data);
assert!(got.is_err());
}

#[test]
fn test_boxcox_llf() {
let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 1.0;
let got = box_cox_log_likelihood(data, lambda);
assert_approx_eq!(got, 11.266065387038703);
assert!(got.is_ok());
let llf = got.unwrap();
assert_approx_eq!(llf, 11.266065387038703);
}

#[test]
fn test_boxcox_llf_non_positive() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 0.0;
let got = box_cox_log_likelihood(data, lambda);
assert!(got.is_err());
}
}
2 changes: 1 addition & 1 deletion crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl Transform {
/// The Power transformation is defined as:
///
pub fn power_transform(data: &[f64]) -> Self {
let lambda = optimize_lambda(data);
let lambda = optimize_lambda(data).unwrap_or_else(|_| 0.0);
Self::BoxCox { lambda }
}

Expand Down

0 comments on commit 3d5ea97

Please sign in to comment.