EverestRunModel: minor refactoring
verveerpj committed Dec 18, 2024
1 parent 9477e52 commit fab8972
Showing 1 changed file with 85 additions and 85 deletions.
170 changes: 85 additions & 85 deletions src/ert/run_models/everest_run_model.py
@@ -346,28 +346,28 @@ def _on_before_forward_model_evaluation(
             optimizer.abort_optimization()

     def _forward_model_evaluator(
-        self, control_values: NDArray[np.float64], metadata: EvaluatorContext
+        self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
     ) -> EvaluatorResult:
         # Reset the current run status:
         self._status = None

         # Get cached_results:
-        cached_results = self._get_cached_results(control_values, metadata)
+        cached_results = self._get_cached_results(control_values, evaluator_context)

         # Create the batch to run:
-        case_data = self._init_case_data(control_values, metadata, cached_results)
+        batch_data = _init_case_data(control_values, evaluator_context, cached_results)

         # Initialize a new experiment in storage:
         assert self._experiment is not None
         ensemble = self._experiment.create_ensemble(
             name=f"batch_{self._batch_id}",
-            ensemble_size=len(case_data),
+            ensemble_size=len(batch_data),
         )
-        for sim_id, controls in enumerate(case_data.values()):
+        for sim_id, controls in enumerate(batch_data.values()):
             self._setup_sim(sim_id, controls, ensemble)

         # Evaluate the batch:
-        run_args = self._get_run_args(ensemble, metadata, case_data)
+        run_args = self._get_run_args(ensemble, evaluator_context, batch_data)
         self._context_env.update(
             {
                 "_ERT_EXPERIMENT_ID": str(ensemble.experiment_id),
@@ -384,14 +384,14 @@ def _forward_model_evaluator(
         # Gather the results and create the result for ropt:
         results = self._gather_simulation_results(ensemble)
         evaluator_result = self._make_evaluator_result(
-            control_values, metadata, case_data, results, cached_results
+            control_values, evaluator_context, batch_data, results, cached_results
         )

         # Add the results from the evaluations to the cache:
         self._add_results_to_cache(
             control_values,
-            metadata,
-            case_data,
+            evaluator_context,
+            batch_data,
             evaluator_result.objectives,
             evaluator_result.constraints,
         )
@@ -416,47 +416,6 @@ def _get_cached_results(
                     cached_results[sim_idx] = cached_data
         return cached_results

-    @staticmethod
-    def _init_case_data(
-        control_values: NDArray[np.float64],
-        evaluator_context: EvaluatorContext,
-        cached_results: dict[int, Any],
-    ) -> dict[int, dict[str, Any]]:
-        def add_control(
-            controls: dict[str, Any],
-            control_name: tuple[Any, ...],
-            control_value: float,
-        ) -> None:
-            group_name = control_name[0]
-            variable_name = control_name[1]
-            group = controls.get(group_name, {})
-            if len(control_name) > 2:
-                index_name = str(control_name[2])
-                if variable_name in group:
-                    group[variable_name][index_name] = control_value
-                else:
-                    group[variable_name] = {index_name: control_value}
-            else:
-                group[variable_name] = control_value
-            controls[group_name] = group
-
-        case_data = {}
-        for control_idx in range(control_values.shape[0]):
-            if control_idx not in cached_results and (
-                evaluator_context.active is None
-                or evaluator_context.active[evaluator_context.realizations[control_idx]]
-            ):
-                controls: dict[str, Any] = {}
-                assert evaluator_context.config.variables.names is not None
-                for control_name, control_value in zip(
-                    evaluator_context.config.variables.names,
-                    control_values[control_idx, :],
-                    strict=False,
-                ):
-                    add_control(controls, control_name, control_value)
-                case_data[control_idx] = controls
-        return case_data
-
     def _setup_sim(
         self,
         sim_id: int,
@@ -514,17 +473,17 @@ def _check_suffix(
     def _get_run_args(
         self,
         ensemble: Ensemble,
-        metadata: EvaluatorContext,
-        case_data: dict[int, Any],
+        evaluator_context: EvaluatorContext,
+        batch_data: dict[int, Any],
     ) -> list[RunArg]:
         substitutions = self.ert_config.substitutions
         substitutions["<CASE_NAME>"] = ensemble.name
-        self.active_realizations = [True] * len(case_data)
-        assert metadata.config.realizations.names is not None
-        for sim_id, control_idx in enumerate(case_data.keys()):
-            realization = metadata.realizations[control_idx]
+        self.active_realizations = [True] * len(batch_data)
+        assert evaluator_context.config.realizations.names is not None
+        for sim_id, control_idx in enumerate(batch_data.keys()):
+            realization = evaluator_context.realizations[control_idx]
             substitutions[f"<GEO_ID_{sim_id}_0>"] = str(
-                metadata.config.realizations.names[realization]
+                evaluator_context.config.realizations.names[realization]
             )
         run_paths = Runpaths(
             jobname_format=self.ert_config.model_config.jobname_format_string,
@@ -584,26 +543,26 @@ def _gather_simulation_results(
     def _make_evaluator_result(
         self,
         control_values: NDArray[np.float64],
-        metadata: EvaluatorContext,
-        case_data: dict[int, Any],
+        evaluator_context: EvaluatorContext,
+        batch_data: dict[int, Any],
         results: list[dict[str, NDArray[np.float64]]],
         cached_results: dict[int, Any],
     ) -> EvaluatorResult:
         # We minimize the negative of the objectives:
-        objectives = -self._get_simulation_results(
+        objectives = -_get_simulation_results(
             results,
-            metadata.config.objectives.names,  # type: ignore
+            evaluator_context.config.objectives.names,  # type: ignore
             control_values,
-            case_data,
+            batch_data,
         )

         constraints = None
-        if metadata.config.nonlinear_constraints is not None:
-            constraints = self._get_simulation_results(
+        if evaluator_context.config.nonlinear_constraints is not None:
+            constraints = _get_simulation_results(
                 results,
-                metadata.config.nonlinear_constraints.names,  # type: ignore
+                evaluator_context.config.nonlinear_constraints.names,  # type: ignore
                 control_values,
-                case_data,
+                batch_data,
             )

         if self._simulator_cache is not None:
@@ -617,41 +576,25 @@ def _make_evaluator_result(
                     constraints[control_idx, ...] = cached_constraints

         sim_ids = np.full(control_values.shape[0], -1, dtype=np.intc)
-        sim_ids[list(case_data.keys())] = np.arange(len(case_data), dtype=np.intc)
+        sim_ids[list(batch_data.keys())] = np.arange(len(batch_data), dtype=np.intc)
         return EvaluatorResult(
             objectives=objectives,
             constraints=constraints,
             batch_id=self._batch_id,
             evaluation_ids=sim_ids,
         )

-    @staticmethod
-    def _get_simulation_results(
-        results: list[dict[str, NDArray[np.float64]]],
-        names: tuple[str],
-        controls: NDArray[np.float64],
-        case_data: dict[int, Any],
-    ) -> NDArray[np.float64]:
-        control_indices = list(case_data.keys())
-        values = np.zeros((controls.shape[0], len(names)), dtype=float64)
-        for func_idx, name in enumerate(names):
-            values[control_indices, func_idx] = np.fromiter(
-                (np.nan if not result else result[name][0] for result in results),
-                dtype=np.float64,
-            )
-        return values
-
     def _add_results_to_cache(
         self,
         control_values: NDArray[np.float64],
         evaluator_context: EvaluatorContext,
-        case_data: dict[int, Any],
+        batch_data: dict[int, Any],
         objectives: NDArray[np.float64],
         constraints: NDArray[np.float64] | None,
     ) -> None:
         if self._simulator_cache is not None:
             assert evaluator_context.config.realizations.names is not None
-            for control_idx in case_data:
+            for control_idx in batch_data:
                 realization = evaluator_context.realizations[control_idx]
                 self._simulator_cache.add(
                     evaluator_context.config.realizations.names[realization],
@@ -741,6 +684,63 @@ def _handle_errors(
         fm_logger.error(err_msg.format(error_id, ""))


+def _init_case_data(
+    control_values: NDArray[np.float64],
+    evaluator_context: EvaluatorContext,
+    cached_results: dict[int, Any],
+) -> dict[int, dict[str, Any]]:
+    def add_control(
+        controls: dict[str, Any],
+        control_name: tuple[Any, ...],
+        control_value: float,
+    ) -> None:
+        group_name = control_name[0]
+        variable_name = control_name[1]
+        group = controls.get(group_name, {})
+        if len(control_name) > 2:
+            index_name = str(control_name[2])
+            if variable_name in group:
+                group[variable_name][index_name] = control_value
+            else:
+                group[variable_name] = {index_name: control_value}
+        else:
+            group[variable_name] = control_value
+        controls[group_name] = group
+
+    batch_data = {}
+    for control_idx in range(control_values.shape[0]):
+        if control_idx not in cached_results and (
+            evaluator_context.active is None
+            or evaluator_context.active[evaluator_context.realizations[control_idx]]
+        ):
+            controls: dict[str, Any] = {}
+            assert evaluator_context.config.variables.names is not None
+            for control_name, control_value in zip(
+                evaluator_context.config.variables.names,
+                control_values[control_idx, :],
+                strict=False,
+            ):
+                add_control(controls, control_name, control_value)
+            batch_data[control_idx] = controls
+    return batch_data
+
+
+def _get_simulation_results(
+    results: list[dict[str, NDArray[np.float64]]],
+    names: tuple[str],
+    controls: NDArray[np.float64],
+    batch_data: dict[int, Any],
+) -> NDArray[np.float64]:
+    control_indices = list(batch_data.keys())
+    values = np.zeros((controls.shape[0], len(names)), dtype=float64)
+    for func_idx, name in enumerate(names):
+        values[control_indices, func_idx] = np.fromiter(
+            (np.nan if not result else result[name][0] for result in results),
+            dtype=np.float64,
+        )
+    return values
+
+
 class SimulatorCache:
     EPS = float(np.finfo(np.float32).eps)
