Skip to content

Commit

Permalink
Split ERT 3 parameters into separate records
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkwah committed Jun 18, 2021
1 parent 09ff647 commit 43d2b7c
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 119 deletions.
11 changes: 9 additions & 2 deletions ert3/config/_parameters_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,15 @@ class ParametersConfig(_ParametersConfig):
def __iter__(self) -> Iterator[_ParameterConfig]: # type: ignore
return iter(self.__root__)

def __getitem__(self, item: int) -> _ParameterConfig:
return self.__root__[item]
def __getitem__(self, item: Union[int, str]) -> _ParameterConfig:
if isinstance(item, int):
return self.__root__[item]
elif isinstance(item, str):
for group in self:
if group.name == item:
return group
raise ValueError(f"No parameter group found named: {item}")
raise TypeError(f"Item should be int or str, not {type(item)}")

def __len__(self) -> int:
return len(self.__root__)
Expand Down
12 changes: 1 addition & 11 deletions ert3/engine/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,6 @@ def load_record(workspace: Path, record_name: str, record_file: Path) -> None:
)


def _get_distribution(
parameter_group_name: str, parameters_config: ert3.config.ParametersConfig
) -> ert3.stats.Distribution:
for parameter_group in parameters_config:
if parameter_group.name == parameter_group_name:
return parameter_group.as_distribution()

raise ValueError(f"No parameter group found named: {parameter_group_name}")


# pylint: disable=too-many-arguments
def sample_record(
workspace: Path,
Expand All @@ -38,7 +28,7 @@ def sample_record(
ensemble_size: int,
experiment_name: Optional[str] = None,
) -> None:
distribution = _get_distribution(parameter_group_name, parameters_config)
distribution = parameters_config[parameter_group_name].as_distribution()
ensrecord = ert3.data.EnsembleRecord(
records=[distribution.sample() for _ in range(ensemble_size)]
)
Expand Down
46 changes: 41 additions & 5 deletions ert3/engine/_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import List, Dict
from typing import List, Dict, Set, Union

import ert3

Expand All @@ -9,19 +9,51 @@ def _prepare_experiment(
experiment_name: str,
ensemble: ert3.config.EnsembleConfig,
ensemble_size: int,
parameters_config: ert3.config.ParametersConfig,
) -> None:
if ert3.workspace.experiment_has_run(workspace_root, experiment_name):
raise ValueError(f"Experiment {experiment_name} have been carried out.")

parameter_names = [elem.record for elem in ensemble.input]
parameters: Dict[str, List[str]] = {}
for input_record in ensemble.input:
record_name = input_record.record
record_source = input_record.source.split(".")
parameters[record_name] = _get_experiment_record_indices(
workspace_root, record_name, record_source, parameters_config
)

ert3.storage.init_experiment(
workspace=workspace_root,
experiment_name=experiment_name,
parameters=parameter_names,
parameters=parameters,
ensemble_size=ensemble_size,
)


def _get_experiment_record_indices(
workspace_root: pathlib.Path,
record_name: str,
record_source: List[str],
parameters_config: ert3.config.ParametersConfig,
) -> List[str]:
if record_source[0] == "storage":
assert len(record_source) == 2
ensemble_record = ert3.storage.get_ensemble_record(
workspace=workspace_root, record_name=record_source[1]
)
indices: Set[Union[str, int]] = set()
for record in ensemble_record.records:
assert record.index is not None
indices |= set(record.index)
return [str(x) for x in indices]

elif record_source[0] == "stochastic":
assert len(record_source) == 2
return list(parameters_config[record_source[1]].variables)

raise ValueError("Unknown record source location {}".format(record_source[0]))


# pylint: disable=too-many-arguments
def _prepare_experiment_record(
record_name: str,
Expand Down Expand Up @@ -64,7 +96,9 @@ def _prepare_evaluation(
# This reassures mypy that the ensemble size is defined
assert ensemble.size is not None

_prepare_experiment(workspace_root, experiment_name, ensemble, ensemble.size)
_prepare_experiment(
workspace_root, experiment_name, ensemble, ensemble.size, parameters_config
)

for input_record in ensemble.input:
record_name = input_record.record
Expand Down Expand Up @@ -111,7 +145,9 @@ def _prepare_sensitivity(
)
input_records = ert3.algorithms.one_at_the_time(parameter_distributions)

_prepare_experiment(workspace_root, experiment_name, ensemble, len(input_records))
_prepare_experiment(
workspace_root, experiment_name, ensemble, len(input_records), parameters_config
)

parameters: Dict[str, List[ert3.data.Record]] = {
param.record: [] for param in ensemble.input
Expand Down
Loading

0 comments on commit 43d2b7c

Please sign in to comment.