Skip to content

Commit

Permalink
Add run experiment with design matrix to ensemble experiment panel
Browse files Browse the repository at this point in the history
- Prefil active realization box with realizations from design matrix
- Use design_matrix parameters in ensemble experiment
- add test run cli with design matrix and poly example
- add test that save parameters internalize DataFrame parameters in the storage
- add merge function to merge design parameters with existing parameters
 -- Raise Validation error when having multiple overlapping groups
  • Loading branch information
xjules committed Dec 10, 2024
1 parent bf2d515 commit a670fcb
Show file tree
Hide file tree
Showing 10 changed files with 516 additions and 72 deletions.
91 changes: 74 additions & 17 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,10 @@
from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition

from ._option_dict import option_dict
from .parsing import (
ConfigValidationError,
ErrorInfo,
)
from .parsing import ConfigValidationError, ErrorInfo

if TYPE_CHECKING:
from ert.config import (
ParameterConfig,
)
from ert.config import ParameterConfig

DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"

Expand All @@ -31,10 +26,17 @@ class DesignMatrix:
default_sheet: str

def __post_init__(self) -> None:
self.num_realizations: Optional[int] = None
self.active_realizations: Optional[List[bool]] = None
self.design_matrix_df: Optional[pd.DataFrame] = None
self.parameter_configuration: Optional[Dict[str, ParameterConfig]] = None
try:
(
self.active_realizations,
self.design_matrix_df,
self.parameter_configuration,
) = self.read_design_matrix()
except (ValueError, AttributeError) as exc:
raise ConfigValidationError.with_context(
f"Error reading design matrix {self.xls_filename}: {exc}",
str(self.xls_filename),
) from exc

@classmethod
def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
Expand Down Expand Up @@ -76,9 +78,64 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
default_sheet=default_sheet,
)

def merge_with_existing_parameters(
self, existing_parameters: List[ParameterConfig]
) -> tuple[List[ParameterConfig], ParameterConfig | None]:
"""
This method merges the design matrix parameters with the existing parameters and
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
GEN_KW group that was dropped will acquire a new name from the design matrix group.
Additionally, the ParameterConfig which is the design matrix group is returned separately.
Args:
existing_parameters (List[ParameterConfig]): List of existing parameters
Raises:
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
Returns:
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
"""

new_param_config: List[ParameterConfig] = []

design_parameter_group = self.parameter_configuration[DESIGN_MATRIX_GROUP]
design_keys = []
if isinstance(design_parameter_group, GenKwConfig):
design_keys = [e.name for e in design_parameter_group.transform_functions]

design_group_added = False
for parameter_group in existing_parameters:
if not isinstance(parameter_group, GenKwConfig):
new_param_config += [parameter_group]
continue
existing_keys = [e.name for e in parameter_group.transform_functions]
if set(existing_keys) == set(design_keys):
if design_group_added:
raise ConfigValidationError(
(
"Multiple overlapping groups with design matrix found in existing parameters!\n"
f"{design_parameter_group.name} and {parameter_group.name}"
)
)

design_parameter_group.name = parameter_group.name
design_group_added = True
elif set(design_keys) & set(existing_keys):
raise ConfigValidationError(
(
"Overlapping parameter names found in design matrix!\n"
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
"\nThey need to much exactly or not at all."
)
)
else:
new_param_config += [parameter_group]
return new_param_config, design_parameter_group

def read_design_matrix(
self,
) -> None:
) -> tuple[List[bool], pd.DataFrame, Dict[str, ParameterConfig]]:
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
Expand Down Expand Up @@ -142,11 +199,11 @@ def read_design_matrix(
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
)
reals = design_matrix_df.index.tolist()
self.num_realizations = len(reals)
self.active_realizations = [x in reals for x in range(max(reals) + 1)]

self.design_matrix_df = design_matrix_df
self.parameter_configuration = parameter_configuration
return (
[x in reals for x in range(max(reals) + 1)],
design_matrix_df,
parameter_configuration,
)

@staticmethod
def _read_excel(
Expand Down
34 changes: 27 additions & 7 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,17 @@
)

import orjson
import pandas as pd
import xarray as xr
from numpy.random import SeedSequence

from ert.config.ert_config import forward_model_data_to_json
from ert.config.forward_model_step import ForwardModelStep
from ert.config.model_config import ModelConfig
from ert.substitutions import Substitutions, substitute_runpath_name

from .config import (
ExtParamConfig,
Field,
GenKwConfig,
ParameterConfig,
SurfaceConfig,
)
from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig
from .config.design_matrix import DESIGN_MATRIX_GROUP
from .run_arg import RunArg
from .runpaths import Runpaths

Expand Down Expand Up @@ -165,6 +162,29 @@ def _seed_sequence(seed: Optional[int]) -> int:
return int_seed


def save_design_matrix_to_ensemble(
design_matrix_df: pd.DataFrame,
ensemble: Ensemble,
active_realizations: Iterable[int],
design_group_name: str = DESIGN_MATRIX_GROUP,
) -> None:
assert not design_matrix_df.empty
for realization_nr in active_realizations:
row = design_matrix_df.loc[realization_nr][DESIGN_MATRIX_GROUP]
ds = xr.Dataset(
{
"values": ("names", list(row.values)),
"transformed_values": ("names", list(row.values)),
"names": list(row.keys()),
}
)
ensemble.save_parameters(
design_group_name,
realization_nr,
ds,
)


def sample_prior(
ensemble: Ensemble,
active_realizations: Iterable[int],
Expand Down
30 changes: 12 additions & 18 deletions src/ert/gui/simulation/ensemble_experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ert.gui.tools.design_matrix.design_matrix_panel import DesignMatrixPanel
from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE
from ert.run_models import EnsembleExperiment
from ert.validation import RangeStringArgument
from ert.validation import ActiveRange, RangeStringArgument
from ert.validation.proper_name_argument import ExperimentValidation, ProperNameArgument

from .experiment_config_panel import ExperimentConfigPanel
Expand Down Expand Up @@ -85,6 +85,9 @@ def __init__(

design_matrix = analysis_config.design_matrix
if design_matrix is not None:
self._active_realizations_field.setText(
ActiveRange(design_matrix.active_realizations).rangestring
)
show_dm_param_button = QPushButton("Show parameters")
show_dm_param_button.setObjectName("show-dm-parameters")
show_dm_param_button.setMinimumWidth(50)
Expand Down Expand Up @@ -113,23 +116,14 @@ def __init__(
self.notifier.ertChanged.connect(self._update_experiment_name_placeholder)

def on_show_dm_params_clicked(self, design_matrix: DesignMatrix) -> None:
assert design_matrix is not None

if design_matrix.design_matrix_df is None:
design_matrix.read_design_matrix()

if (
design_matrix.design_matrix_df is not None
and not design_matrix.design_matrix_df.empty
):
viewer = DesignMatrixPanel(
design_matrix.design_matrix_df,
design_matrix.xls_filename.name,
)
viewer.setMinimumHeight(500)
viewer.setMinimumWidth(1000)
viewer.adjustSize()
viewer.exec_()
viewer = DesignMatrixPanel(
design_matrix.design_matrix_df,
design_matrix.xls_filename.name,
)
viewer.setMinimumHeight(500)
viewer.setMinimumWidth(1000)
viewer.adjustSize()
viewer.exec_()

@Slot(ExperimentConfigPanel)
def experimentTypeChanged(self, w: ExperimentConfigPanel) -> None:
Expand Down
33 changes: 30 additions & 3 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

import numpy as np

from ert.enkf_main import sample_prior
from ert.config import ConfigValidationError
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage
from ert.trace import tracer

from ..run_arg import create_run_arguments
from .base_run_model import BaseRunModel, StatusEvents
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents

if TYPE_CHECKING:
from ert.config import ErtConfig, QueueConfig
Expand Down Expand Up @@ -64,10 +65,27 @@ def run_experiment(
) -> None:
self.log_at_startup()
self.restart = restart
# If design matrix is present, we try to merge design matrix parameters
# to the experiment parameters and set new active realizations
parameters_config = self.ert_config.ensemble_config.parameter_configuration
design_matrix = self.ert_config.analysis_config.design_matrix
design_matrix_group = None
if design_matrix is not None:
try:
parameters_config, design_matrix_group = (
design_matrix.merge_with_existing_parameters(parameters_config)
)
except ConfigValidationError as exc:
raise ErtRunError(str(exc)) from exc

if not restart:
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
parameters=self.ert_config.ensemble_config.parameter_configuration,
parameters=(
[*parameters_config, design_matrix_group]
if design_matrix_group is not None
else parameters_config
),
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
)
Expand All @@ -90,12 +108,21 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)

sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

if design_matrix_group is not None and design_matrix is not None:
save_design_matrix_to_ensemble(
design_matrix.design_matrix_df,
self.ensemble,
np.where(self.active_realizations)[0],
design_matrix_group.name,
)

self._evaluate_and_postprocess(
run_args,
self.ensemble,
Expand Down
17 changes: 12 additions & 5 deletions src/ert/run_models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,20 @@ def _setup_ensemble_experiment(
args: Namespace,
status_queue: SimpleQueue[StatusEvents],
) -> EnsembleExperiment:
active_realizations = _realizations(args, config.model_config.num_realizations)
active_realizations = _realizations(
args, config.model_config.num_realizations
).tolist()
if (
config.analysis_config.design_matrix is not None
and config.analysis_config.design_matrix.active_realizations is not None
):
active_realizations = config.analysis_config.design_matrix.active_realizations
experiment_name = args.experiment_name
assert experiment_name is not None

return EnsembleExperiment(
random_seed=config.random_seed,
active_realizations=active_realizations.tolist(),
active_realizations=active_realizations,
ensemble_name=args.current_ensemble,
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
experiment_name=experiment_name,
Expand Down Expand Up @@ -271,9 +278,9 @@ def _setup_iterative_ensemble_smoother(
random_seed=config.random_seed,
active_realizations=active_realizations.tolist(),
target_ensemble=_iterative_ensemble_format(args),
number_of_iterations=int(args.num_iterations)
if args.num_iterations is not None
else 4,
number_of_iterations=(
int(args.num_iterations) if args.num_iterations is not None else 4
),
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
num_retries_per_iter=4,
experiment_name=experiment_name,
Expand Down
Loading

0 comments on commit a670fcb

Please sign in to comment.