Skip to content

Commit

Permalink
fix yeo johnson implementation and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardcqian committed Dec 10, 2024
1 parent 934e431 commit 6e82356
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
8 changes: 4 additions & 4 deletions crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ mod test {
assert_all_approx_eq(
&forecasts.point,
&[
6.205557727170964,
6.000000132803496,
6.205557727170964,
6.000000132803496,
5.205557727170964,
5.000000132803496,
5.205557727170964,
5.000000132803496,
],
);
}
Expand Down
7 changes: 5 additions & 2 deletions crates/augurs-forecaster/src/power_transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,17 @@ struct OptimizationParams {
max_iterations: u64,
}


fn optimize_lambda<T: CostFunction<Param = f64, Output = f64>>(
cost: T,
params: OptimizationParams,
) -> Result<f64, Error> {
let solver = BrentOpt::new(params.lower_bound, params.upper_bound);
let result = Executor::new(cost, solver)
.configure(|state| state.param(params.initial_param).max_iters(params.max_iterations))
.configure(|state| {
state
.param(params.initial_param)
.max_iters(params.max_iterations)
})
.run();

result.and_then(|res| {
Expand Down
34 changes: 31 additions & 3 deletions crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,20 @@ impl<T> YeoJohnsonExt for T where T: Iterator<Item = f64> {}

/// Returns the inverse Yeo-Johnson transformation of the given value.
fn inverse_yeo_johnson(y: f64, lambda: f64) -> f64 {
if lambda == 0.0 {
y.exp()
const EPSILON: f64 = 1e-6;

if y >= 0.0 && lambda.abs() < EPSILON {
// For lambda close to 0 (positive values)
(y.exp()) - 1.0
} else if y >= 0.0 {
// For positive values (lambda not close to 0)
(y * lambda + 1.0).powf(1.0 / lambda) - 1.0
} else if (lambda - 2.0).abs() < EPSILON {
// For lambda close to 2 (negative values)
-(-y.exp() - 1.0)
} else {
(y * lambda + 1.0).powf(1.0 / lambda)
// For negative values (lambda not close to 2)
-((-((2.0 - lambda) * y) + 1.0).powf(1.0 / (2.0 - lambda)) - 1.0)
}
}

Expand Down Expand Up @@ -770,4 +780,22 @@ mod test {
let actual: Vec<_> = data.into_iter().inverse_box_cox(lambda).collect();
assert_all_close(&expected, &actual);
}

#[test]
fn yeo_johnson_test() {
let data = vec![-1.0, 0.0, 1.0];
let lambda = 0.5;
let expected = vec![-1.2189514164974602, 0.0, 0.8284271247461903];
let actual: Vec<_> = data.into_iter().yeo_johnson(lambda).collect();
assert_all_close(&expected, &actual);
}

#[test]
fn inverse_yeo_johnson_test() {
let data = vec![-1.2189514164974602, 0.0, 0.8284271247461903];
let lambda = 0.5;
let expected = vec![-1.0, 0.0, 1.0];
let actual: Vec<_> = data.into_iter().inverse_yeo_johnson(lambda).collect();
assert_all_close(&expected, &actual);
}
}

0 comments on commit 6e82356

Please sign in to comment.