From bc61f89c36bbff50989f382bfb945986863ec787 Mon Sep 17 00:00:00 2001 From: Zohar Malamant Date: Tue, 8 Jun 2021 17:56:05 +0200 Subject: [PATCH] Split ERT 3 parameters into separate records --- ert3/config/_parameters_config.py | 11 +- ert3/engine/_record.py | 12 +- ert3/engine/_run.py | 54 ++++- ert3/storage/_storage.py | 266 ++++++++++++++------- tests/ert3/console/integration/test_cli.py | 10 +- tests/ert3/storage/test_storage.py | 100 ++++++-- tests/ert3/workspace/test_workspace.py | 2 +- 7 files changed, 332 insertions(+), 123 deletions(-) diff --git a/ert3/config/_parameters_config.py b/ert3/config/_parameters_config.py index efc2c279084..f001f5bbb61 100644 --- a/ert3/config/_parameters_config.py +++ b/ert3/config/_parameters_config.py @@ -145,8 +145,15 @@ class ParametersConfig(_ParametersConfig): def __iter__(self) -> Iterator[_ParameterConfig]: # type: ignore return iter(self.__root__) - def __getitem__(self, item: int) -> _ParameterConfig: - return self.__root__[item] + def __getitem__(self, item: Union[int, str]) -> _ParameterConfig: + if isinstance(item, int): + return self.__root__[item] + elif isinstance(item, str): + for group in self: + if group.name == item: + return group + raise ValueError(f"No parameter group found named: {item}") + raise TypeError(f"Item should be int or str, not {type(item)}") def __len__(self) -> int: return len(self.__root__) diff --git a/ert3/engine/_record.py b/ert3/engine/_record.py index cc25ca2ae3d..ba6efa04ea5 100644 --- a/ert3/engine/_record.py +++ b/ert3/engine/_record.py @@ -19,16 +19,6 @@ def load_record(workspace: Path, record_name: str, record_file: Path) -> None: ) -def _get_distribution( - parameter_group_name: str, parameters_config: ert3.config.ParametersConfig -) -> ert3.stats.Distribution: - for parameter_group in parameters_config: - if parameter_group.name == parameter_group_name: - return parameter_group.as_distribution() - - raise ValueError(f"No parameter group found named: {parameter_group_name}") - - # pylint: disable=too-many-arguments def sample_record( workspace: Path, @@ -38,7 +28,7 @@ def sample_record( ensemble_size: int, experiment_name: Optional[str] = None, ) -> None: - distribution = _get_distribution(parameter_group_name, parameters_config) + distribution = parameters_config[parameter_group_name].as_distribution() ensrecord = ert3.data.EnsembleRecord( records=[distribution.sample() for _ in range(ensemble_size)] ) diff --git a/ert3/engine/_run.py b/ert3/engine/_run.py index f30bb3e34b6..e96a8c3870f 100644 --- a/ert3/engine/_run.py +++ b/ert3/engine/_run.py @@ -1,29 +1,65 @@ import pathlib -from typing import List, Dict +from typing import List, Dict, Set, Union import ert3 +# Character used to separate record source "paths". +_SOURCE_SEPARATOR = "." + + def _prepare_experiment( workspace_root: pathlib.Path, experiment_name: str, ensemble: ert3.config.EnsembleConfig, ensemble_size: int, + parameters_config: ert3.config.ParametersConfig, ) -> None: if ert3.workspace.experiment_has_run(workspace_root, experiment_name): raise ValueError(f"Experiment {experiment_name} have been carried out.") - parameter_names = [elem.record for elem in ensemble.input] + parameters: Dict[str, List[str]] = {} + for input_record in ensemble.input: + record_name = input_record.record + record_source = input_record.source.split(_SOURCE_SEPARATOR) + parameters[record_name] = _get_experiment_record_indices( + workspace_root, record_name, record_source, parameters_config + ) responses = [elem.record for elem in ensemble.output] ert3.storage.init_experiment( workspace=workspace_root, experiment_name=experiment_name, - parameters=parameter_names, + parameters=parameters, ensemble_size=ensemble_size, responses=responses, ) +def _get_experiment_record_indices( + workspace_root: pathlib.Path, + record_name: str, + record_source: List[str], + parameters_config: ert3.config.ParametersConfig, +) -> List[str]: + assert len(record_source) == 2 + source, source_record_name = record_source + + if source == "storage": + ensemble_record = ert3.storage.get_ensemble_record( + workspace=workspace_root, record_name=source_record_name + ) + indices: Set[Union[str, int]] = set() + for record in ensemble_record.records: + assert record.index is not None + indices.update(record.index) + return [str(x) for x in indices] + + elif source == "stochastic": + return list(parameters_config[source_record_name].variables) + + raise ValueError("Unknown record source location {}".format(source)) + + # pylint: disable=too-many-arguments def _prepare_experiment_record( record_name: str, @@ -66,11 +102,13 @@ def _prepare_evaluation( # This reassures mypy that the ensemble size is defined assert ensemble.size is not None - _prepare_experiment(workspace_root, experiment_name, ensemble, ensemble.size) + _prepare_experiment( + workspace_root, experiment_name, ensemble, ensemble.size, parameters_config + ) for input_record in ensemble.input: record_name = input_record.record - record_source = input_record.source.split(".") + record_source = input_record.source.split(_SOURCE_SEPARATOR) _prepare_experiment_record( record_name, @@ -94,7 +132,7 @@ def _load_ensemble_parameters( ensemble_parameters = {} for input_record in ensemble.input: record_name = input_record.record - record_source = input_record.source.split(".") + record_source = input_record.source.split(_SOURCE_SEPARATOR) assert len(record_source) == 2 assert record_source[0] == "stochastic" parameter_group_name = record_source[1] @@ -113,7 +151,9 @@ def _prepare_sensitivity( ) input_records = ert3.algorithms.one_at_the_time(parameter_distributions) - _prepare_experiment(workspace_root, experiment_name, ensemble, len(input_records)) + _prepare_experiment( + workspace_root, experiment_name, ensemble, len(input_records), parameters_config + ) parameters: Dict[str, List[ert3.data.Record]] = { param.record: [] for param in ensemble.input diff --git a/ert3/storage/_storage.py b/ert3/storage/_storage.py index d9b41e382f7..57a8e820cad 100644 --- a/ert3/storage/_storage.py +++ b/ert3/storage/_storage.py @@ -1,6 +1,16 @@ import json from pathlib import Path -from typing import Any, Dict, Iterable, Optional, Set +from typing import ( + Any, + Dict, + Iterable, + Mapping, + MutableMapping, + Optional, + Set, + List, + Union, +) import io import logging import pandas as pd @@ -17,6 +27,11 @@ _ENSEMBLE_RECORDS = "__ensemble_records__" _SPECIAL_KEYS = (_ENSEMBLE_RECORDS,) +# Character used as separator for parameter record names. This is used as a +# workaround for webviz-ert, which expects each parameter record to have exactly +# one value per realisation. +_PARAMETER_RECORD_SEPARATOR = "." + class _NumericalMetaData(BaseModel): class Config: @@ -130,7 +145,7 @@ def init(*, workspace: Path) -> None: _init_experiment( workspace=workspace, experiment_name=f"{workspace}.{special_key}", - parameters=[], + parameters={}, ensemble_size=-1, responses=[], ) @@ -140,7 +155,7 @@ def init_experiment( *, workspace: Path, experiment_name: str, - parameters: Iterable[str], + parameters: Mapping[str, Iterable[str]], ensemble_size: int, responses: Iterable[str], ) -> None: @@ -160,7 +175,7 @@ def _init_experiment( *, workspace: Path, experiment_name: str, - parameters: Iterable[str], + parameters: Mapping[str, Iterable[str]], ensemble_size: int, responses: Iterable[str], ) -> None: @@ -172,15 +187,24 @@ def _init_experiment( f"Cannot initialize existing experiment: {experiment_name}" ) + if len(set(parameters.keys()).intersection(responses)) > 0: + raise ert3.exceptions.StorageError( + "Experiment parameters and responses cannot have a name in common" + ) + exp_response = _post_to_server(path="experiments", json={"name": experiment_name}) exp_id = exp_response.json()["id"] response = _post_to_server( - path=f"experiments/{exp_id}/ensembles", + f"experiments/{exp_id}/ensembles", json={ - "parameter_names": list(parameters), + "parameter_names": [ + f"{record}.{param}" + for record, params in parameters.items() + for param in params + ], "response_names": list(responses), "size": ensemble_size, - "metadata": {"name": experiment_name}, + "userdata": {"name": experiment_name}, }, ) if response.status_code != 200: @@ -210,7 +234,8 @@ def _add_numerical_data( workspace: Path, experiment_name: str, record_name: str, - ensemble_record: ert3.data.EnsembleRecord, + record_data: Union[pd.DataFrame, pd.Series], + record_type: ert3.data.RecordType, ) -> None: experiment = _get_experiment_by_name(experiment_name) if experiment is None: @@ -220,75 +245,95 @@ def _add_numerical_data( ) metadata = _NumericalMetaData( - ensemble_size=ensemble_record.ensemble_size, - record_type=_get_record_type(ensemble_record), + ensemble_size=len(record_data), + record_type=record_type, ) ensemble_id = experiment["ensemble_ids"][0] # currently just one ens per exp record_url = f"ensembles/{ensemble_id}/records/{record_name}" - for idx, record in enumerate(ensemble_record.records): - df = pd.DataFrame([record.data], columns=record.index, index=[idx]) - response = _post_to_server( - path=f"{record_url}/matrix", - params={"realization_index": idx}, - data=df.to_csv().encode(), - headers={"content-type": "text/csv"}, - ) + response = _post_to_server( + f"{record_url}/matrix", + data=record_data.to_csv().encode(), + headers={"content-type": "text/csv"}, + ) - if response.status_code == 409: - raise ert3.exceptions.ElementExistsError("Record already exists") + if response.status_code == 409: + raise ert3.exceptions.ElementExistsError("Record already exists") - if response.status_code != 200: - raise ert3.exceptions.StorageError(response.text) + if response.status_code != 200: + raise ert3.exceptions.StorageError(response.text) - meta_response = _put_to_server( - path=f"{record_url}/userdata", - params={"realization_index": idx}, - json=metadata.dict(), - ) + meta_response = _put_to_server(f"{record_url}/userdata", json=metadata.dict()) - if meta_response.status_code != 200: - raise ert3.exceptions.StorageError(meta_response.text) + if meta_response.status_code != 200: + raise ert3.exceptions.StorageError(meta_response.text) -def _response2record( - response_content: bytes, record_type: ert3.data.RecordType, realization_id: int -) -> ert3.data.Record: +def _response2records( + response_content: bytes, record_type: ert3.data.RecordType +) -> ert3.data.EnsembleRecord: dataframe = pd.read_csv( io.BytesIO(response_content), index_col=0, float_precision="round_trip" ) - raw_index = tuple(dataframe.columns) + records: List[ert3.data.Record] if record_type == ert3.data.RecordType.LIST_FLOAT: - array_data = tuple( - float(dataframe.loc[realization_id][raw_idx]) for raw_idx in raw_index - ) - return ert3.data.Record(data=array_data) + records = [ + ert3.data.Record(data=row.to_list()) for _, row in dataframe.iterrows() + ] elif record_type == ert3.data.RecordType.MAPPING_INT_FLOAT: - int_index = tuple(int(e) for e in dataframe.columns) - idata = { - idx: float(dataframe.loc[realization_id][raw_idx]) - for raw_idx, idx in zip(raw_index, int_index) - } - return ert3.data.Record(data=idata) + records = [ + ert3.data.Record(data={int(k): v for k, v in row.to_dict().items()}) + for _, row in dataframe.iterrows() + ] elif record_type == ert3.data.RecordType.MAPPING_STR_FLOAT: - str_index = tuple(str(e) for e in dataframe.columns) - sdata = { - idx: float(dataframe.loc[realization_id][raw_idx]) - for raw_idx, idx in zip(raw_index, str_index) - } - return ert3.data.Record(data=sdata) + records = [ + ert3.data.Record(data=row.to_dict()) for _, row in dataframe.iterrows() + ] else: raise ValueError( f"Unexpected record type when loading numerical record: {record_type}" ) + return ert3.data.EnsembleRecord(records=records) + + +def _combine_records( + ensemble_records: List[ert3.data.EnsembleRecord], +) -> ert3.data.EnsembleRecord: + # Combine records into the first ensemble record + combined_records: List[ert3.data.Record] = [] + for record_idx, _ in enumerate(ensemble_records[0].records): + record0 = ensemble_records[0].records[record_idx] + + if isinstance(record0.data, list): + ldata = [ + val + for data in ( + ensemble_record.records[record_idx].data + for ensemble_record in ensemble_records + ) + if isinstance(data, list) + for val in data + ] + combined_records.append(ert3.data.Record(data=ldata)) + elif isinstance(record0.data, dict): + ddata = { + key: val + for data in ( + ensemble_record.records[record_idx].data + for ensemble_record in ensemble_records + ) + if isinstance(data, dict) + for key, val in data.items() + } + combined_records.append(ert3.data.Record(data=ddata)) + return ert3.data.EnsembleRecord(records=combined_records) def _get_numerical_metadata(ensemble_id: str, record_name: str) -> _NumericalMetaData: response = _get_from_server( - path=f"ensembles/{ensemble_id}/records/{record_name}/userdata", - params={"realization_index": 0}, # This assumes there is a realization 0 + f"ensembles/{ensemble_id}/records/{record_name}/userdata" ) if response.status_code == 404: @@ -314,30 +359,46 @@ def _get_numerical_data( ensemble_id = experiment["ensemble_ids"][0] # currently just one ens per exp metadata = _get_numerical_metadata(ensemble_id, record_name) - records = [] - for real_id in range(metadata.ensemble_size): - response = _get_from_server( - path=f"ensembles/{ensemble_id}/records/{record_name}", - params={"realization_index": real_id}, - headers={"accept": "text/csv"}, + response = _get_from_server( + f"ensembles/{ensemble_id}/records/{record_name}", + headers={"accept": "text/csv"}, + ) + + if response.status_code == 404: + raise ert3.exceptions.ElementMissingError( + f"No {record_name} data for experiment: {experiment_name}" ) - if response.status_code == 404: - raise ert3.exceptions.ElementMissingError( - f"No {record_name} data for experiment: {experiment_name}" - ) + if response.status_code != 200: + raise ert3.exceptions.StorageError(response.text) + + return _response2records( + response.content, + metadata.record_type, + ) - if response.status_code != 200: - raise ert3.exceptions.StorageError(response.text) - record = _response2record( - response.content, - metadata.record_type, - real_id, +def _get_experiment_parameters( + workspace: Path, experiment_name: str +) -> Mapping[str, Iterable[str]]: + experiment = _get_experiment_by_name(experiment_name) + if experiment is None: + raise ert3.exceptions.NonExistantExperiment( + f"Cannot get parameters from non-existing experiment: {experiment_name}" ) - records.append(record) - return ert3.data.EnsembleRecord(records=records) + ensemble_id = experiment["ensemble_ids"][0] # currently just one ens per exp + response = _get_from_server(f"ensembles/{ensemble_id}/parameters") + if response.status_code != 200: + raise ert3.exceptions.StorageError(response.text) + parameters: MutableMapping[str, List[str]] = {} + for name in response.json(): + key, val = name.split(".") + if key in parameters: + parameters[key].append(val) + else: + parameters[key] = [val] + return parameters def add_ensemble_record( @@ -349,8 +410,35 @@ def add_ensemble_record( ) -> None: if experiment_name is None: experiment_name = f"{workspace}.{_ENSEMBLE_RECORDS}" + experiment = _get_experiment_by_name(experiment_name) + if experiment is None: + raise ert3.exceptions.NonExistantExperiment( + f"Cannot add {record_name} data to " + f"non-existing experiment: {experiment_name}" + ) - _add_numerical_data(workspace, experiment_name, record_name, ensemble_record) + dataframe = pd.DataFrame([r.data for r in ensemble_record.records]) + record_type = _get_record_type(ensemble_record) + + parameters = _get_experiment_parameters(workspace, experiment_name) + if record_name in parameters: + # Split by columns + for column_label in dataframe: + _add_numerical_data( + workspace, + experiment_name, + f"{record_name}.{column_label}", + dataframe[column_label], + record_type, + ) + else: + _add_numerical_data( + workspace, + experiment_name, + record_name, + dataframe, + record_type, + ) def get_ensemble_record( @@ -361,13 +449,31 @@ def get_ensemble_record( ) -> ert3.data.EnsembleRecord: if experiment_name is None: experiment_name = f"{workspace}.{_ENSEMBLE_RECORDS}" + experiment = _get_experiment_by_name(experiment_name) + if experiment is None: + raise ert3.exceptions.NonExistantExperiment( + f"Cannot get {record_name} data, no experiment named: {experiment_name}" + ) - return _get_numerical_data(workspace, experiment_name, record_name) + param_names = _get_experiment_parameters(workspace, experiment_name) + if record_name in param_names: + ensemble_records = [ + _get_numerical_data( + workspace, + experiment_name, + record_name + _PARAMETER_RECORD_SEPARATOR + param_name, + ) + for param_name in param_names[record_name] + ] + return _combine_records(ensemble_records) + else: + return _get_numerical_data(workspace, experiment_name, record_name) def get_ensemble_record_names( - *, workspace: Path, experiment_name: Optional[str] = None + *, workspace: Path, experiment_name: Optional[str] = None, _flatten: bool = True ) -> Iterable[str]: + # _flatten is a parameter used only for testing separated parameter records if experiment_name is None: experiment_name = f"{workspace}.{_ENSEMBLE_RECORDS}" experiment = _get_experiment_by_name(experiment_name) @@ -380,23 +486,17 @@ def get_ensemble_record_names( response = _get_from_server(path=f"ensembles/{ensemble_id}/records") if response.status_code != 200: raise ert3.exceptions.StorageError(response.text) + + # Flatten any parameter records that were split + if _flatten: + return {x.split(_PARAMETER_RECORD_SEPARATOR)[0] for x in response.json().keys()} return list(response.json().keys()) def get_experiment_parameters( *, workspace: Path, experiment_name: str ) -> Iterable[str]: - experiment = _get_experiment_by_name(experiment_name) - if experiment is None: - raise ert3.exceptions.NonExistantExperiment( - f"Cannot get parameters from non-existing experiment: {experiment_name}" - ) - - ensemble_id = experiment["ensemble_ids"][0] # currently just one ens per exp - response = _get_from_server(path=f"ensembles/{ensemble_id}/parameters") - if response.status_code != 200: - raise ert3.exceptions.StorageError(response.text) - return list(response.json()) + return list(_get_experiment_parameters(workspace, experiment_name)) def get_experiment_responses(*, workspace: Path, experiment_name: str) -> Iterable[str]: diff --git a/tests/ert3/console/integration/test_cli.py b/tests/ert3/console/integration/test_cli.py index e99cf693a0c..88e0a4cab1b 100644 --- a/tests/ert3/console/integration/test_cli.py +++ b/tests/ert3/console/integration/test_cli.py @@ -217,7 +217,7 @@ def test_cli_status_some_runs(workspace, capsys): ert3.storage.init_experiment( workspace=workspace, experiment_name=experiments[idx], - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -240,7 +240,7 @@ def test_cli_status_all_run(workspace, capsys): ert3.storage.init_experiment( workspace=workspace, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -302,7 +302,7 @@ def test_cli_clean_all(workspace): ert3.storage.init_experiment( workspace=workspace, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -335,7 +335,7 @@ def test_cli_clean_one(workspace): ert3.storage.init_experiment( workspace=workspace, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -375,7 +375,7 @@ def test_cli_clean_non_existant_experiment(workspace, capsys): ert3.storage.init_experiment( workspace=workspace, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) diff --git a/tests/ert3/storage/test_storage.py b/tests/ert3/storage/test_storage.py index 90e7ae163a5..53bca1c0946 100644 --- a/tests/ert3/storage/test_storage.py +++ b/tests/ert3/storage/test_storage.py @@ -25,7 +25,7 @@ def test_ensemble_size_zero(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name="my_experiment", - parameters=[], + parameters={}, ensemble_size=0, responses=[], ) @@ -38,7 +38,7 @@ def test_none_as_experiment_name(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name=None, - parameters=[], + parameters={}, ensemble_size=10, responses=[], ) @@ -50,7 +50,7 @@ def test_double_add_experiment(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name="my_experiment", - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -61,7 +61,7 @@ def test_double_add_experiment(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name="my_experiment", - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -72,14 +72,17 @@ def test_add_experiments(tmpdir, ert_storage): ert3.storage.init(workspace=tmpdir) experiment_names = ["a", "b", "c", "super-experiment", "explosions"] - experiment_parameters = [ + experiment_parameter_records = [ ["x"], ["a", "b"], ["alpha", "beta"], ["oxygen", "heat", "fuel"], ] - experiments = zip(experiment_names, experiment_parameters) - for idx, (experiment_name, experiment_parameters) in enumerate(experiments): + experiments = zip(experiment_names, experiment_parameter_records) + for idx, (experiment_name, experiment_parameter_records) in enumerate(experiments): + experiment_parameters = { + key: ["some_coeff"] for key in experiment_parameter_records + } ert3.storage.init_experiment( workspace=tmpdir, experiment_name=experiment_name, @@ -94,7 +97,7 @@ def test_add_experiments(tmpdir, ert_storage): parameters = ert3.storage.get_experiment_parameters( workspace=tmpdir, experiment_name=experiment_name ) - assert experiment_parameters == parameters + assert experiment_parameter_records == parameters @pytest.mark.requires_ert_storage @@ -139,8 +142,11 @@ def test_add_and_get_ensemble_record(tmpdir, raw_ensrec, ert_storage): ensrecord = ert3.data.EnsembleRecord(records=raw_ensrec) ert3.storage.add_ensemble_record( - workspace=tmpdir, record_name="my_ensemble_record", ensemble_record=ensrecord + workspace=tmpdir, + record_name="my_ensemble_record", + ensemble_record=ensrecord, ) + retrieved_ensrecord = ert3.storage.get_ensemble_record( workspace=tmpdir, record_name="my_ensemble_record" ) @@ -148,6 +154,72 @@ def test_add_and_get_ensemble_record(tmpdir, raw_ensrec, ert_storage): assert ensrecord == retrieved_ensrecord +@pytest.mark.requires_ert_storage +@pytest.mark.parametrize( + "raw_ensrec", + ( + [{"data": [i + 0.5, i + 1.1, i + 2.2]} for i in range(3)], + [{"data": {"a": i + 0.5, "b": i + 1.1, "c": i + 2.2}} for i in range(5)], + [{"data": {2: i + 0.5, 5: i + 1.1, 7: i + 2.2}} for i in range(2)], + ), +) +def test_add_and_get_ensemble_parameter_record(tmpdir, raw_ensrec, ert_storage): + """This tests a workaround so that webviz-ert is able to visualise parameters. + It expects records which are marked as parameter to contain only one value + per realisation, while ERT 3 uses multiple variables per record. That is, in + ERT 3, it is possible to specify a record named "coefficients" which + contains the variables "a", "b" and "c". These are rendered as column + labels, which ERT Storage accepts, but webviz-ert doesn't know how to use. + + The workaround involves splitting up any parameter record into its variable + parts, so "coefficients.a", "coefficients.b", "coefficients.c". This is an + implementation detail that exists entirely within the `ert3.storage` module, + and isn't visible outside of it. + + In order to test this behaviour, we use a testing-only kwarg called + `_flatten` in the function `ert3.storage.get_ensemble_record_names`, which + lets us see the record names that are hidden outside of the `ert3.storage` + module. + + """ + raw_data = raw_ensrec[0]["data"] + assert isinstance(raw_data, (list, dict)) + if isinstance(raw_data, list): + indices = [str(x) for x in range(len(raw_data))] + else: + indices = [str(x) for x in raw_data] + + ert3.storage.init(workspace=tmpdir) + ert3.storage.init_experiment( + workspace=tmpdir, + experiment_name="experiment_name", + parameters={"my_ensemble_record": indices}, + ensemble_size=len(raw_ensrec), + responses=[], + ) + + ensrecord = ert3.data.EnsembleRecord(records=raw_ensrec) + ert3.storage.add_ensemble_record( + workspace=tmpdir, + experiment_name="experiment_name", + record_name="my_ensemble_record", + ensemble_record=ensrecord, + ) + + record_names = ert3.storage.get_ensemble_record_names( + workspace=tmpdir, experiment_name="experiment_name", _flatten=False + ) + assert {f"my_ensemble_record.{x}" for x in indices} == set(record_names) + + retrieved_ensrecord = ert3.storage.get_ensemble_record( + workspace=tmpdir, + experiment_name="experiment_name", + record_name="my_ensemble_record", + ) + + assert ensrecord == retrieved_ensrecord + + @pytest.mark.requires_ert_storage def test_add_ensemble_record_twice(tmpdir, ert_storage): ert3.storage.init(workspace=tmpdir) @@ -188,7 +260,7 @@ def test_add_and_get_experiment_ensemble_record(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=ensemble_size, responses=[], ) @@ -262,7 +334,7 @@ def test_get_record_names(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=ensemble_size, responses=[], ) @@ -304,7 +376,7 @@ def test_delete_experiment(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name="test", - parameters=[], + parameters={}, ensemble_size=42, responses=[], ) @@ -341,7 +413,7 @@ def test_get_ensemble_responses( ert3.storage.init_experiment( workspace=tmpdir, experiment_name=experiment, - parameters=[], + parameters={}, ensemble_size=1, responses=responses, ) @@ -375,7 +447,7 @@ def test_ensemble_responses_and_parameters(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name=experiment, - parameters=["resp1"], + parameters={"resp1": ["some-key"]}, ensemble_size=1, responses=responses, ) diff --git a/tests/ert3/workspace/test_workspace.py b/tests/ert3/workspace/test_workspace.py index 3c5f001bfb2..fe871f37ebb 100644 --- a/tests/ert3/workspace/test_workspace.py +++ b/tests/ert3/workspace/test_workspace.py @@ -80,7 +80,7 @@ def test_workspace_experiment_has_run(tmpdir, ert_storage): ert3.storage.init_experiment( workspace=tmpdir, experiment_name="test1", - parameters=[], + parameters={}, ensemble_size=42, responses=[], )