diff --git a/crates/augurs-forecaster/src/forecaster.rs b/crates/augurs-forecaster/src/forecaster.rs index 7761f1e..d40603f 100644 --- a/crates/augurs-forecaster/src/forecaster.rs +++ b/crates/augurs-forecaster/src/forecaster.rs @@ -155,10 +155,10 @@ mod test { assert_all_approx_eq( &forecasts.point, &[ - 6.205557727170964, - 6.000000132803496, - 6.205557727170964, - 6.000000132803496, + 5.205557727170964, + 5.000000132803496, + 5.205557727170964, + 5.000000132803496, ], ); } diff --git a/crates/augurs-forecaster/src/power_transforms.rs b/crates/augurs-forecaster/src/power_transforms.rs index e15640d..2373bbd 100644 --- a/crates/augurs-forecaster/src/power_transforms.rs +++ b/crates/augurs-forecaster/src/power_transforms.rs @@ -109,14 +109,17 @@ struct OptimizationParams { max_iterations: u64, } - fn optimize_lambda>( cost: T, params: OptimizationParams, ) -> Result { let solver = BrentOpt::new(params.lower_bound, params.upper_bound); let result = Executor::new(cost, solver) - .configure(|state| state.param(params.initial_param).max_iters(params.max_iterations)) + .configure(|state| { + state + .param(params.initial_param) + .max_iters(params.max_iterations) + }) .run(); result.and_then(|res| { diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs index 8f49450..688bea5 100644 --- a/crates/augurs-forecaster/src/transforms.rs +++ b/crates/augurs-forecaster/src/transforms.rs @@ -581,10 +581,20 @@ impl YeoJohnsonExt for T where T: Iterator {} /// Returns the inverse Yeo-Johnson transformation of the given value. fn inverse_yeo_johnson(y: f64, lambda: f64) -> f64 { - if lambda == 0.0 { - y.exp() + const EPSILON: f64 = 1e-6; + + if y >= 0.0 && lambda.abs() < EPSILON { + // For lambda close to 0 (positive values) + (y.exp()) - 1.0 + } else if y >= 0.0 { + // For positive values (lambda not close to 0) + (y * lambda + 1.0).powf(1.0 / lambda) - 1.0 + } else if (lambda - 2.0).abs() < EPSILON { + // For lambda close to 2 (negative values) + -(-y.exp() - 1.0) } else { - (y * lambda + 1.0).powf(1.0 / lambda) + // For negative values (lambda not close to 2) + -((-((2.0 - lambda) * y) + 1.0).powf(1.0 / (2.0 - lambda)) - 1.0) } } @@ -770,4 +780,22 @@ mod test { let actual: Vec<_> = data.into_iter().inverse_box_cox(lambda).collect(); assert_all_close(&expected, &actual); } + + #[test] + fn yeo_johnson_test() { + let data = vec![-1.0, 0.0, 1.0]; + let lambda = 0.5; + let expected = vec![-1.2189514164974602, 0.0, 0.8284271247461903]; + let actual: Vec<_> = data.into_iter().yeo_johnson(lambda).collect(); + assert_all_close(&expected, &actual); + } + + #[test] + fn inverse_yeo_johnson_test() { + let data = vec![-1.2189514164974602, 0.0, 0.8284271247461903]; + let lambda = 0.5; + let expected = vec![-1.0, 0.0, 1.0]; + let actual: Vec<_> = data.into_iter().inverse_yeo_johnson(lambda).collect(); + assert_all_close(&expected, &actual); + } }