From d9088eb8e2f26e5b38cc45160fdca49786b1a359 Mon Sep 17 00:00:00 2001 From: larsevj Date: Thu, 19 Dec 2024 10:35:56 +0100 Subject: [PATCH] Validate triangular dist parameters on startup --- src/ert/config/gen_kw_config.py | 21 +++++++ .../unit_tests/config/test_gen_kw_config.py | 56 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py index 65cd288c04d..f371c5622d1 100644 --- a/src/ert/config/gen_kw_config.py +++ b/src/ert/config/gen_kw_config.py @@ -205,6 +205,25 @@ def _check_non_negative_parameter(param: str, prior: PriorDict) -> None: ).set_context(self.name) ) + def _check_valid_triangular_parameters(prior: PriorDict) -> None: + key = prior["key"] + dist = prior["function"] + xmin, xmode, xmax = prior["parameters"].values() + if not (xmin < xmax): + errors.append( + ErrorInfo( + f"Minimum {xmin} must be strictly less than the maxiumum {xmax}" + f" for {dist} distributed parameter {key}", + ).set_context(self.name) + ) + if not (xmin <= xmode <= xmax): + errors.append( + ErrorInfo( + f"The mode {xmode} must be between the minimum {xmin} and maximum {xmax}" + f" for {dist} distributed parameter {key}", + ).set_context(self.name) + ) + unique_keys = set() for prior in self.get_priors(): key = prior["key"] @@ -219,6 +238,8 @@ def _check_non_negative_parameter(param: str, prior: PriorDict) -> None: if prior["function"] == "LOGNORMAL": _check_non_negative_parameter("MEAN", prior) _check_non_negative_parameter("STD", prior) + elif prior["function"] == "TRIANGULAR": + _check_valid_triangular_parameters(prior) elif prior["function"] in {"NORMAL", "TRUNCATED_NORMAL"}: _check_non_negative_parameter("STD", prior) if errors: diff --git a/tests/ert/unit_tests/config/test_gen_kw_config.py b/tests/ert/unit_tests/config/test_gen_kw_config.py index 71127de168f..b88d582a9e8 100644 --- a/tests/ert/unit_tests/config/test_gen_kw_config.py +++ b/tests/ert/unit_tests/config/test_gen_kw_config.py @@ -625,3 +625,59 @@ def test_suggestion_on_empty_parameter_file(tmp_path): ], } ) + + +@pytest.mark.parametrize( + "distribution, min, mode, max, error", + [ + ("TRIANGULAR", "0", "2", "3", None), + ( + "TRIANGULAR", + "3.0", + "3.0", + "3.0", + "Minimum 3.0 must be strictly less than the maxiumum 3.0", + ), + ("TRIANGULAR", "-1", "0", "1", None), + ( + "TRIANGULAR", + "3.0", + "6.0", + "5.5", + "The mode 6.0 must be between the minimum 3.0 and maximum 5.5", + ), + ( + "TRIANGULAR", + "3.0", + "-6.0", + "5.5", + "The mode -6.0 must be between the minimum 3.0 and maximum 5.5", + ), + ], +) +def test_validation_triangular_distribution( + tmpdir, distribution, min, mode, max, error +): + with tmpdir.as_cwd(): + config = dedent( + """ + JOBNAME my_name%d + NUM_REALIZATIONS 1 + GEN_KW KW_NAME template.txt kw.txt prior.txt + """ + ) + with open("config.ert", "w", encoding="utf-8") as fh: + fh.writelines(config) + with open("template.txt", "w", encoding="utf-8") as fh: + fh.writelines("MY_KEYWORD ") + with open("prior.txt", "w", encoding="utf-8") as fh: + fh.writelines(f"MY_KEYWORD {distribution} {min} {mode} {max}") + + if error: + with pytest.raises( + ConfigValidationError, + match=error, + ): + ErtConfig.from_file("config.ert") + else: + ErtConfig.from_file("config.ert")