diff --git a/crates/augurs-forecaster/src/power_transforms.rs b/crates/augurs-forecaster/src/power_transforms.rs index adf7f14..8c5afff 100644 --- a/crates/augurs-forecaster/src/power_transforms.rs +++ b/crates/augurs-forecaster/src/power_transforms.rs @@ -2,9 +2,14 @@ use crate::transforms::box_cox; use argmin::core::*; use argmin::solver::brent::BrentOpt; -fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> f64 { +fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> Result { let n = data.len() as f64; - assert!(n > 0.0, "Data must not be empty"); + if n == 0.0 { + return Err(Error::msg("Data must not be empty")); + } + if data.iter().any(|&x| x <= 0.0) { + return Err(Error::msg("All data must be greater than 0")); + } let transformed_data: Vec = data.iter().map(|&x| box_cox(x, lambda)).collect(); let mean_transformed: f64 = transformed_data.iter().copied().sum::() / n; let variance: f64 = transformed_data @@ -14,9 +19,12 @@ fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> f64 { / n; // Avoid log(0) by ensuring variance is positive + if variance <= 0.0 { + return Err(Error::msg("Variance must be positive")); + } let log_likelihood = -0.5 * n * variance.ln() + (lambda - 1.0) * data.iter().map(|&x| x.ln()).sum::(); - log_likelihood + Ok(log_likelihood) } #[derive(Clone)] @@ -30,12 +38,12 @@ impl CostFunction for BoxCoxProblem<'_> { // The goal is to minimize the negative log-likelihood fn cost(&self, lambda: &Self::Param) -> Result { - Ok(-box_cox_log_likelihood(&self.data, *lambda)) + box_cox_log_likelihood(&self.data, *lambda).map(|ll| -ll) } } /// Optimize the lambda parameter for the Box-Cox transformation -pub (crate) fn optimize_lambda(data: &[f64]) -> f64 { +pub (crate) fn optimize_lambda(data: &[f64]) -> Result { let cost = BoxCoxProblem { data: data }; let init_param = 0.5; let solver = BrentOpt::new(-2.0, 2.0); @@ -44,10 +52,12 @@ pub (crate) fn optimize_lambda(data: &[f64]) -> f64 { .configure(|state| state.param(init_param).max_iters(100)) .run(); - match result { - Ok(result) => result.state().best_param.unwrap(), - Err(error) => panic!("Optimization failed: {}", error), - } + result + .and_then(|res| { + res.state() + .best_param + .ok_or_else(|| Error::msg("No best parameter found")) + }) } #[cfg(test)] @@ -59,7 +69,23 @@ mod test { fn correct_optimal_lambda() { let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]; let got = optimize_lambda(data); - assert_approx_eq!(got, 0.7123778635679304); + assert!(got.is_ok()); + let lambda = got.unwrap(); + assert_approx_eq!(lambda, 0.7123778635679304); + } + + #[test] + fn optimize_lambda_empty_data() { + let data = &[]; + let got = optimize_lambda(data); + assert!(got.is_err()); + } + + #[test] + fn optimize_lambda_non_positive_data() { + let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]; + let got = optimize_lambda(data); + assert!(got.is_err()); } #[test] @@ -67,6 +93,16 @@ mod test { let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]; let lambda = 1.0; let got = box_cox_log_likelihood(data, lambda); - assert_approx_eq!(got, 11.266065387038703); + assert!(got.is_ok()); + let llf = got.unwrap(); + assert_approx_eq!(llf, 11.266065387038703); + } + + #[test] + fn test_boxcox_llf_non_positive() { + let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]; + let lambda = 0.0; + let got = box_cox_log_likelihood(data, lambda); + assert!(got.is_err()); } } diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs index c9757c7..3c5d504 100644 --- a/crates/augurs-forecaster/src/transforms.rs +++ b/crates/augurs-forecaster/src/transforms.rs @@ -113,7 +113,7 @@ impl Transform { /// The Power transformation is defined as: /// pub fn power_transform(data: &[f64]) -> Self { - let lambda = optimize_lambda(data); + let lambda = optimize_lambda(data).unwrap_or_else(|_| 0.0); Self::BoxCox { lambda } }