From d58f6d9b911f161a314eec884587b24dc20143e1 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:32:25 -0500 Subject: [PATCH] Update tests for `Statistics` to pass --- .../gempyor_pkg/src/gempyor/statistics.py | 3 ++- .../tests/statistics/test_statistic_class.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 324131601..c96d87821 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -253,7 +253,8 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: x, loc=loc, scale=scale * loc.where(loc > 5, 5) ), # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - # NEW: names of distributions: `norm` --> `norm_homoskedastic`, `norm_cov` --> `norm_heteroskedastic`; names of input `scale` --> `sd` + # NEW: names of distributions: `norm` --> `norm_homoskedastic`, `norm_cov` + # --> `norm_heteroskedastic`; names of input `scale` --> `sd` "norm_homoskedastic": lambda x, loc, sd: scipy.stats.norm.logpdf( x, loc=loc, scale=self.params.get("sd", sd) ), # scale = standard deviation diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 4843986c1..7c2e13196 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -468,7 +468,7 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: mock_inputs.gt_data[mock_inputs.config["data_var"]].coords ) dist_name = mock_inputs.config["likelihood"]["dist"] - if dist_name in {"absolute_error", "rmse"}: + if dist_name == "absolute_error": # MAE produces a single repeated number assert np.allclose( log_likelihood.values, @@ -481,6 +481,21 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: ) ), ) + elif dist_name == "rmse": + assert np.allclose( + log_likelihood.values, + -np.log( + np.sqrt( + np.nansum( + np.power( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + - mock_inputs.gt_data[mock_inputs.config["data_var"]], + 2.0, + ) + ) + ) + ), + ) elif dist_name == "pois": assert np.allclose( log_likelihood.values, @@ -489,7 +504,7 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: mock_inputs.model_data[mock_inputs.config["data_var"]].values, ), ) - elif dist_name == {"norm", "norm_cov"}: + elif dist_name in {"norm", "norm_cov"}: scale = mock_inputs.config["likelihood"]["params"]["scale"] if dist_name == "norm_cov": scale *= mock_inputs.model_data[mock_inputs.config["sim_var"]].where(