Skip to content

Commit

Permalink
Refactor the everest run model
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Dec 18, 2024
1 parent 01156bb commit 7142c89
Show file tree
Hide file tree
Showing 10 changed files with 477 additions and 486 deletions.
818 changes: 430 additions & 388 deletions src/ert/run_models/everest_run_model.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions src/everest/config/environment_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal
from typing import Literal, Self

from pydantic import BaseModel, Field, field_validator
from numpy.random import SeedSequence
from pydantic import BaseModel, Field, field_validator, model_validator

from everest.config.validation_utils import check_path_valid

Expand Down Expand Up @@ -43,3 +44,9 @@ class EnvironmentConfig(BaseModel, extra="forbid"): # type: ignore
def validate_output_folder(cls, output_folder): # pylint:disable=E0213
check_path_valid(output_folder)
return output_folder

@model_validator(mode="after")
def validate_random_seed(self) -> Self:
if self.random_seed is None:
self.random_seed = SeedSequence().entropy
return self
43 changes: 24 additions & 19 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
HTTPBasic,
HTTPBasicCredentials,
)
from ropt.enums import OptimizerExitCode

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest import export_to_csv, export_with_progress
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, get_opt_status, update_everserver_status
Expand Down Expand Up @@ -373,25 +372,31 @@ def main():


def _get_optimization_status(exit_code, shared_data):
if exit_code == "max_batch_num_reached":
return ServerStatus.completed, "Maximum number of batches reached."

if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED:
return ServerStatus.completed, "Maximum number of function evaluations reached."
match exit_code:
case EverestExitCode.MAX_BATCH_NUM_REACHED:
return ServerStatus.completed, "Maximum number of batches reached."

case EverestExitCode.MAX_FUNCTIONS_REACHED:
return (
ServerStatus.completed,
"Maximum number of function evaluations reached.",
)

if exit_code == OptimizerExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."
case EverestExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."

if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)

return ServerStatus.completed, "Optimization completed."
case EverestExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped
if shared_data[STOP_ENDPOINT]
else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)
case _:
return ServerStatus.completed, "Optimization completed."


def _failed_realizations_messages(shared_data):
Expand Down
3 changes: 0 additions & 3 deletions src/everest/simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from everest.simulator.simulator_cache import SimulatorCache

JOB_SUCCESS = "Finished"
JOB_WAITING = "Waiting"
JOB_RUNNING = "Running"
Expand Down Expand Up @@ -109,5 +107,4 @@
"JOB_RUNNING",
"JOB_SUCCESS",
"JOB_WAITING",
"SimulatorCache",
]
58 changes: 0 additions & 58 deletions src/everest/simulator/simulator_cache.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/everest/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_seed(copy_math_func_test_data_to_tmp):
config.environment.random_seed = random_seed

run_model = EverestRunModel.create(config)
assert random_seed == run_model.everest_config.environment.random_seed
assert random_seed == run_model._everest_config.environment.random_seed

# Res
ert_config = _everest_to_ert_config_dict(config)
Expand All @@ -26,5 +26,5 @@ def test_loglevel(copy_math_func_test_data_to_tmp):
config = EverestConfig.load_file(CONFIG_FILE)
config.environment.log_level = "info"
run_model = EverestRunModel.create(config)
config = run_model.everest_config
config = run_model._everest_config
assert len(EverestConfig.lint_config_dict(config.to_dict())) == 0
8 changes: 1 addition & 7 deletions tests/everest/test_everest_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,7 @@ async def test_everest_output(copy_mocked_test_data_to_tmp):
initial_folders = set(folders)
initial_files = set(files)

# Tests in this class used to fail when a callback was passed in
# Use a callback just to see that everything works fine, even though
# the callback does nothing
def useless_cb(*args, **kwargs):
pass

EverestRunModel.create(config, optimization_callback=useless_cb)
EverestRunModel.create(config)

# Check the output folder is created when stating the optimization
# in everest workflow
Expand Down
8 changes: 4 additions & 4 deletions tests/everest/test_everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pathlib import Path
from unittest.mock import patch

from ropt.enums import OptimizerExitCode
from seba_sqlite.snapshot import SebaSnapshot

from ert.run_models.everest_run_model import EverestExitCode
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, everserver_status
from everest.detached.jobs import everserver
Expand All @@ -33,8 +33,8 @@ def fail_optimization(self, from_ropt=False):
# shared_data (see set_shared_status() below).
self._sim_callback(None)
if from_ropt:
self._exit_code = OptimizerExitCode.TOO_FEW_REALIZATIONS
return OptimizerExitCode.TOO_FEW_REALIZATIONS
self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS
return EverestExitCode.TOO_FEW_REALIZATIONS

raise Exception("Failed optimization")

Expand Down Expand Up @@ -121,7 +121,7 @@ def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp):
"ert.run_models.everest_run_model.EverestRunModel.run_experiment",
autospec=True,
side_effect=lambda self, evaluator_server_config, restart=False: check_status(
ServerConfig.get_hostfile_path(self.everest_config.output_dir),
ServerConfig.get_hostfile_path(self._everest_config.output_dir),
status=ServerStatus.running,
),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/everest/test_simulator_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def new_call(*args):
Path("everest_output/optimization_output/seba.db").unlink()

# The batch_id was used as a stopping criterion, so it must be reset:
run_model.batch_id = 0
run_model._batch_id = 0

run_model.run_experiment(evaluator_server_config)
assert n_evals == 0
Expand Down
8 changes: 6 additions & 2 deletions tests/everest/test_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ def test_random_seed(random_seed):
if random_seed:
config["environment"] = {"random_seed": random_seed}
ever_config = EverestConfig.with_defaults(**config)
assert ever_config.environment.random_seed == random_seed
ert_config = everest_to_ert_config(ever_config)
assert ert_config.random_seed == random_seed
if random_seed is None:
assert ever_config.environment.random_seed > 0
assert ert_config.random_seed > 0
else:
assert ever_config.environment.random_seed == random_seed
assert ert_config.random_seed == random_seed


def test_read_file(tmp_path, monkeypatch):
Expand Down

0 comments on commit 7142c89

Please sign in to comment.