Skip to content

Commit

Permalink
Add warning when everest-models file outputs do not match everest obj…
Browse files Browse the repository at this point in the history
…ectives
  • Loading branch information
DanSava committed Nov 12, 2024
1 parent 44ece92 commit 8d47dce
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 14 deletions.
35 changes: 23 additions & 12 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
model_validator,
)
from ruamel.yaml import YAML, YAMLError
from typing_extensions import Annotated
from typing_extensions import Annotated, Self

from ert.config import ErtConfig
from everest.config.control_variable_config import ControlVariableGuessListConfig
Expand All @@ -42,6 +42,9 @@
validate_forward_model_configs,
)
from everest.jobs import script_names
from everest.util.forward_models import (
check_forward_model_objective,
)

from ..config_file_loader import yaml_file_to_substituted_config_dict
from ..strings import (
Expand Down Expand Up @@ -221,7 +224,7 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore
model_config = ConfigDict(extra="forbid")

@model_validator(mode="after")
def validate_install_job_sources(self): # pylint: disable=E0213
def validate_install_job_sources(self) -> Self: # pylint: disable=E0213
model = self.model
config_path = self.config_path
if not model or not config_path:
Expand Down Expand Up @@ -286,7 +289,7 @@ def validate_install_job_sources(self): # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_forward_model_job_name_installed(self): # pylint: disable=E0213
def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213
install_jobs = self.install_jobs
forward_model_jobs = self.forward_model
if install_jobs is None:
Expand All @@ -308,7 +311,7 @@ def validate_forward_model_job_name_installed(self): # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_workflow_name_installed(self): # pylint: disable=E0213
def validate_workflow_name_installed(self) -> Self: # pylint: disable=E0213
workflows = self.workflows
if workflows is None:
return self
Expand Down Expand Up @@ -344,7 +347,7 @@ def validate_install_templates_unique_output_files(cls, install_templates): # p
return install_templates

@model_validator(mode="after")
def validate_install_templates_are_existing_files(self):
def validate_install_templates_are_existing_files(self) -> Self:
install_templates = self.install_templates

if not install_templates:
Expand Down Expand Up @@ -374,7 +377,7 @@ def validate_install_templates_are_existing_files(self):
return self

@model_validator(mode="after")
def validate_cvar_nreals_interval(self): # pylint: disable=E0213
def validate_cvar_nreals_interval(self) -> Self: # pylint: disable=E0213
optimization = self.optimization
if not optimization:
return self
Expand Down Expand Up @@ -402,7 +405,7 @@ def validate_cvar_nreals_interval(self): # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_install_data_source_exists(self):
def validate_install_data_source_exists(self) -> Self:
install_data = self.install_data or []
if not install_data:
return self
Expand All @@ -417,7 +420,7 @@ def validate_install_data_source_exists(self):
return self

@model_validator(mode="after")
def validate_model_data_file_exists(self): # pylint: disable=E0213
def validate_model_data_file_exists(self) -> Self: # pylint: disable=E0213
model = self.model
if not model:
return self
Expand All @@ -429,7 +432,7 @@ def validate_model_data_file_exists(self): # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_maintained_forward_models(self):
def validate_maintained_forward_models(self) -> Self:
install_data = self.install_data
model = self.model
realizations = model.realizations if model else [0]
Expand All @@ -440,9 +443,17 @@ def validate_maintained_forward_models(self):
validate_forward_model_configs(self.forward_model, self.install_jobs)
return self

@model_validator(mode="after")
def validate_maintained_forward_model_write_objectives(self) -> Self:
if not self.objective_functions:
return self
objectives = {objective.name for objective in self.objective_functions}
check_forward_model_objective(self.forward_model, objectives)
return self

@model_validator(mode="after")
# pylint: disable=E0213
def validate_input_constraints_weight_definition(self):
def validate_input_constraints_weight_definition(self) -> Self:
input_constraints = self.input_constraints
if not input_constraints:
return self
Expand Down Expand Up @@ -479,7 +490,7 @@ def validate_input_constraints_weight_definition(self):
return self

@model_validator(mode="after")
def validate_variable_name_match_well_name(self): # pylint: disable=E0213
def validate_variable_name_match_well_name(self) -> Self: # pylint: disable=E0213
controls = self.controls
wells = self.wells
if controls is None or wells is None:
Expand All @@ -497,7 +508,7 @@ def validate_variable_name_match_well_name(self): # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_that_environment_sim_folder_is_writeable(self):
def validate_that_environment_sim_folder_is_writeable(self) -> Self:
environment = self.environment
config_path = self.config_path
if environment is None or config_path is None:
Expand Down
10 changes: 9 additions & 1 deletion src/everest/plugins/hook_specs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Type, TypeVar
from typing import List, Sequence, Type, TypeVar

from everest.plugins import hookspec

Expand Down Expand Up @@ -103,3 +103,11 @@ def add_log_handle_to_root():
@hookspec
def get_forward_model_documentations():
""" """


@hookspec(firstresult=True)
def custom_forward_model_outputs(forward_model_steps: List[str]):
"""
Check if the given forward model steps will output to a file maching the
defined everest objective
"""
21 changes: 20 additions & 1 deletion src/everest/util/forward_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List, Type, TypeVar
from typing import List, Set, Type, TypeVar

from pydantic import BaseModel, ValidationError

from ert.config import ConfigWarning
from everest.plugins.everest_plugin_manager import EverestPluginManager

pm = EverestPluginManager()
Expand All @@ -19,6 +20,24 @@ def lint_forward_model_job(job: str, args) -> List[str]:
return pm.hook.lint_forward_model(job=job, args=args)


def check_forward_model_objective(
forward_model_steps: List[str], objectives: Set[str]
) -> None:
fm_outputs = pm.hook.custom_forward_model_outputs(
forward_model_steps=forward_model_steps,
objectives=objectives,
)
if fm_outputs is None:
return
unaccounted_objectives = objectives.difference(fm_outputs)
if unaccounted_objectives:
add_s = "s" if len(unaccounted_objectives) > 1 else ""
ConfigWarning.warn(
f"Warning: Forward model might not write the required output file{add_s}"
f" for {sorted(unaccounted_objectives)}"
)


def parse_forward_model_file(path: str, schema: Type[T], message: str) -> T:
try:
res = pm.hook.parse_forward_model_schema(path=path, schema=schema)
Expand Down
42 changes: 42 additions & 0 deletions tests/everest/test_config_validation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os
import pathlib
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Union

import pytest
from pydantic import ValidationError

from ert.config import ConfigWarning
from everest.config import EverestConfig, ModelConfig
from everest.config.control_variable_config import ControlVariableConfig
from everest.config.sampler_config import SamplerConfig
from tests.everest.utils import skipif_no_everest_models


def has_error(error: Union[ValidationError, List[dict]], match: str):
Expand Down Expand Up @@ -944,3 +947,42 @@ def test_that_non_existing_workflow_jobs_cause_error():
]
},
)


@skipif_no_everest_models
@pytest.mark.everest_models_test
@pytest.mark.parametrize(
["objective", "warning_msg"],
[
(
["npv", "rf"],
"Warning: Forward model might not write the required output file for \\['npv'\\]",
),
(
["npv", "npv2"],
"Warning: Forward model might not write the required output files for \\['npv', 'npv2'\\]",
),
(["rf"], None),
],
)
def test_warning_forward_model_write_objectives(objective, warning_msg):
fm_steps = [
"well_constraints -i files/well_readydate.json -c files/wc_config.yml -rc well_rate.json -o wc_wells.json",
"add_templates -i wc_wells.json -c files/at_config.yml -o at_wells.json",
"schmerge -s eclipse/include/schedule/schedule.tmpl -i at_wells.json -o eclipse/include/schedule/schedule.sch",
"eclipse100 TEST.DATA --version 2020.2",
"rf -s TEST -o rf",
]
if warning_msg is not None:
with pytest.warns(ConfigWarning, match=warning_msg):
EverestConfig.with_defaults(
objective_functions=[{"name": o} for o in objective],
forward_model=fm_steps,
)
else:
with warnings.catch_warnings():
warnings.simplefilter("error", category=ConfigWarning)
EverestConfig.with_defaults(
objective_functions=[{"name": o} for o in objective],
forward_model=fm_steps,
)

0 comments on commit 8d47dce

Please sign in to comment.