-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add property test of ensemble_smoother
- Loading branch information
1 parent
c453403
commit a87a4d4
Showing
1 changed file
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
import stat | ||
import warnings | ||
from pathlib import Path | ||
|
||
import hypothesis.strategies as st | ||
import numpy as np | ||
import pytest | ||
from hypothesis import given, note, settings | ||
from pytest import MonkeyPatch, TempPathFactory | ||
|
||
from ert.cli.main import ErtCliError | ||
from ert.config.gen_kw_config import DISTRIBUTION_PARAMETERS | ||
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE | ||
from ert.storage import open_storage | ||
|
||
from .run_cli import run_cli_with_pm | ||
|
||
names = st.text( | ||
min_size=1, | ||
max_size=8, | ||
alphabet=st.characters( | ||
min_codepoint=ord("!"), | ||
max_codepoint=ord("~"), | ||
exclude_characters="\"'$,:%", # These have specific meaning in configs | ||
), | ||
) | ||
|
||
|
||
@st.composite | ||
def distribution_values(draw, k, vs): | ||
d = {} | ||
biggest = 100.0 | ||
if "LOG" in k: | ||
biggest = 10.0 | ||
epsilon = biggest / 1000.0 | ||
if "MIN" in vs: | ||
d["MIN"] = draw(st.floats(min_value=epsilon, max_value=biggest / 10.0)) | ||
if "MAX" in vs: | ||
d["MAX"] = draw(st.floats(min_value=d["MIN"] + 5 * epsilon, max_value=biggest)) | ||
if "MEAN" in vs: | ||
d["MEAN"] = draw( | ||
st.floats( | ||
min_value=d.get("MIN", 2 * epsilon) + epsilon, | ||
max_value=d.get("MAX", biggest) - epsilon, | ||
) | ||
) | ||
if "MODE" in vs: | ||
d["MODE"] = draw( | ||
st.floats( | ||
min_value=d.get("MIN", 2 * epsilon) + epsilon, | ||
max_value=d.get("MAX", biggest) - epsilon, | ||
) | ||
) | ||
if "STEPS" in vs: | ||
d["STEPS"] = draw(st.integers(min_value=2, max_value=10)) | ||
return [d.get(v, draw(st.floats(min_value=0.1, max_value=1.0))) for v in vs] | ||
|
||
|
||
distributions = st.one_of( | ||
[ | ||
st.tuples( | ||
st.just(k), | ||
distribution_values(k, vs), | ||
) | ||
for k, vs in DISTRIBUTION_PARAMETERS.items() | ||
] | ||
) | ||
|
||
config_contents = """\ | ||
NUM_REALIZATIONS {num_realizations} | ||
QUEUE_SYSTEM LOCAL | ||
QUEUE_OPTION LOCAL MAX_RUNNING {num_realizations} | ||
ENSPATH storage | ||
RANDOM_SEED 1234 | ||
OBS_CONFIG observations | ||
GEN_KW COEFFS coeff_priors | ||
GEN_DATA POLY_RES RESULT_FILE:poly.out | ||
INSTALL_JOB poly_eval POLY_EVAL | ||
FORWARD_MODEL poly_eval | ||
ANALYSIS_SET_VAR OBSERVATIONS AUTO_SCALE * | ||
""" | ||
|
||
coeff_priors = """\ | ||
coeff_0 {distribution0} 0 1 | ||
coeff_1 {distribution1} 0 2 | ||
coeff_2 {distribution2} 0 5 | ||
""" | ||
|
||
observation = """ | ||
GENERAL_OBSERVATION POLY_OBS_{i} {{ | ||
DATA = POLY_RES; | ||
INDEX_FILE = index_{i}.txt; | ||
OBS_FILE = poly_obs_{i}.txt; | ||
}}; | ||
""" | ||
|
||
poly_eval = """\ | ||
#!/usr/bin/env python3 | ||
import json | ||
import numpy as np | ||
coeffs = json.load(open("parameters.json"))["COEFFS"] | ||
c = [np.array(coeffs[f"coeff_" + str(i)]) for i in range(len(coeffs))] | ||
with open("poly.out", "w", encoding="utf-8") as f: | ||
f.write("\\n".join(map(str, [np.polyval(c, x) for x in range({num_points})]))) | ||
""" | ||
|
||
POLY_EVAL = "EXECUTABLE poly_eval.py" | ||
|
||
|
||
@pytest.mark.timeout(None) | ||
@settings(max_examples=1000) | ||
@given( | ||
num_realizations=st.integers(min_value=20, max_value=40), | ||
num_points=st.integers(min_value=1, max_value=20), | ||
distributions=st.lists(distributions, min_size=1, max_size=10), | ||
data=st.data(), | ||
) | ||
def test_update_lowers_generalized_variance_or_deactives_observations( | ||
tmp_path_factory: TempPathFactory, | ||
num_realizations: int, | ||
num_points: int, | ||
distributions: list[tuple[str, list[float]]], | ||
data, | ||
): | ||
indecies = data.draw( | ||
st.lists( | ||
st.integers(min_value=0, max_value=num_points - 1), | ||
min_size=1, | ||
max_size=num_points, | ||
unique=True, | ||
) | ||
) | ||
values = data.draw( | ||
st.lists( | ||
st.floats(min_value=-10.0, max_value=10.0), | ||
min_size=len(indecies), | ||
max_size=len(indecies), | ||
) | ||
) | ||
errs = data.draw( | ||
st.lists( | ||
st.floats(min_value=0.1, max_value=0.5), | ||
min_size=len(indecies), | ||
max_size=len(indecies), | ||
) | ||
) | ||
num_groups = data.draw(st.integers(min_value=1, max_value=num_points)) | ||
per_group = num_points // num_groups | ||
print(num_groups, num_points, per_group) | ||
|
||
tmp_path = tmp_path_factory.mktemp("parameter_example") | ||
note(f"Running in directory {tmp_path}") | ||
with MonkeyPatch.context() as patch: | ||
patch.chdir(tmp_path) | ||
contents = config_contents.format( | ||
num_realizations=num_realizations, | ||
) | ||
note(f"config file: {contents}") | ||
Path("config.ert").write_text(contents, encoding="utf-8") | ||
py = Path("poly_eval.py") | ||
py.write_text(poly_eval.format(num_points=num_points)) | ||
mode = os.stat(py) | ||
os.chmod(py, mode.st_mode | stat.S_IEXEC) | ||
|
||
for i in range(num_groups): | ||
print(f"{i * per_group } {(i + 1) * per_group}") | ||
with open("observations", mode="a", encoding="utf-8") as f: | ||
f.write(observation.format(i=i)) | ||
Path(f"poly_obs_{i}.txt").write_text( | ||
"\n".join( | ||
f"{x} {y}" | ||
for x, y in zip( | ||
values[i * per_group : (i + 1) * per_group], | ||
errs[i * per_group : (i + 1) * per_group], | ||
strict=False, | ||
) | ||
), | ||
encoding="utf-8", | ||
) | ||
Path(f"index_{i}.txt").write_text( | ||
"\n".join( | ||
f"{x}" for x in indecies[i * per_group : (i + 1) * per_group] | ||
), | ||
encoding="utf-8", | ||
) | ||
|
||
Path("coeff_priors").write_text( | ||
"\n".join( | ||
f"coeff_{i} {d} {' '.join(str(p) for p in v)}" | ||
for i, (d, v) in enumerate(distributions) | ||
), | ||
encoding="utf-8", | ||
) | ||
Path("POLY_EVAL").write_text(POLY_EVAL, encoding="utf-8") | ||
|
||
success = True | ||
with warnings.catch_warnings(record=True) as all_warnings: | ||
warnings.simplefilter("always") | ||
try: | ||
run_cli_with_pm( | ||
[ | ||
ENSEMBLE_SMOOTHER_MODE, | ||
"--disable-monitor", | ||
"--experiment-name", | ||
"experiment", | ||
"config.ert", | ||
] | ||
) | ||
except ErtCliError as err: | ||
success = False | ||
se = str(err) | ||
# "No active observations" is expected | ||
# when std deviation is too low, the | ||
# other errors are discussed here: | ||
# https://github.com/equinor/ert/issues/9581 | ||
# https://github.com/equinor/ert/issues/9585 | ||
assert ( | ||
"No active observations" in se | ||
or "math domain error" in se | ||
or "math range error" in se | ||
) | ||
|
||
if any("Ill-conditioned matrix" not in str(w.message) for w in all_warnings): | ||
success = False | ||
|
||
if success: | ||
with open_storage("storage") as storage: | ||
experiment = storage.get_experiment_by_name("experiment") | ||
prior = experiment.get_ensemble_by_name("iter-0").load_all_gen_kw_data() | ||
posterior = experiment.get_ensemble_by_name( | ||
"iter-1" | ||
).load_all_gen_kw_data() | ||
|
||
assert ( | ||
np.linalg.det(posterior.cov().to_numpy()) | ||
<= np.linalg.det(prior.cov().to_numpy()) + 0.001 | ||
) |