From 92ddf3acc1df48353fa92338118f301e6fea9c1f Mon Sep 17 00:00:00 2001 From: larsevj Date: Thu, 26 Sep 2024 15:06:58 +0200 Subject: [PATCH] Support reading design_matrix --- pyproject.toml | 1 + src/ert/config/__init__.py | 2 + src/ert/config/analysis_config.py | 4 +- src/ert/config/design_matrix.py | 162 +++++++++++++++++- src/ert/config/gen_kw_config.py | 7 +- .../test_design_matrix.py | 29 ++++ 6 files changed, 199 insertions(+), 6 deletions(-) create mode 100644 tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py diff --git a/pyproject.toml b/pyproject.toml index 5c8b5d8a55d..20817b4ea05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "matplotlib", "netCDF4", "numpy<2", + "openpyxl", # extra dependency for pandas (excel) "orjson", "packaging", "pandas", diff --git a/src/ert/config/__init__.py b/src/ert/config/__init__.py index ec01d6db8b3..971625cc36d 100644 --- a/src/ert/config/__init__.py +++ b/src/ert/config/__init__.py @@ -1,6 +1,7 @@ from .analysis_config import AnalysisConfig from .analysis_module import AnalysisModule, ESSettings, IESSettings from .capture_validation import capture_validation +from .design_matrix import DesignMatrix from .enkf_observation_implementation_type import EnkfObservationImplementationType from .ensemble_config import EnsembleConfig from .ert_config import ErtConfig @@ -48,6 +49,7 @@ "ConfigValidationError", "ConfigValidationError", "ConfigWarning", + "DesignMatrix", "ESSettings", "EnkfObs", "EnkfObservationImplementationType", diff --git a/src/ert/config/analysis_config.py b/src/ert/config/analysis_config.py index a2afdf9ad6f..5bdd376f8c9 100644 --- a/src/ert/config/analysis_config.py +++ b/src/ert/config/analysis_config.py @@ -41,7 +41,7 @@ class AnalysisConfig: ies_module: IESSettings = field(default_factory=IESSettings) observation_settings: UpdateSettings = field(default_factory=UpdateSettings) num_iterations: int = 1 - design_matrix_args: Optional[DesignMatrix] = None + design_matrix: Optional[DesignMatrix] = None @no_type_check @classmethod @@ -194,7 +194,7 @@ def from_dict(cls, config_dict: ConfigDict) -> "AnalysisConfig": observation_settings=obs_settings, es_module=es_settings, ies_module=ies_settings, - design_matrix_args=DesignMatrix.from_config_list(design_matrix_config_list) + design_matrix=DesignMatrix.from_config_list(design_matrix_config_list) if design_matrix_config_list is not None else None, ) diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index 20b5fd8df0d..e44118a8569 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -1,8 +1,13 @@ from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import List +from typing import TYPE_CHECKING, List + +import pandas as pd + +from ert.config.gen_kw_config import GenKwConfig from ._option_dict import option_dict from .parsing import ( @@ -10,6 +15,13 @@ ErrorInfo, ) +if TYPE_CHECKING: + from ert.config import ( + ParameterConfig, + ) + +DESIGN_MATRIX_GROUP = "DESIGN_MATRIX" + @dataclass class DesignMatrix: @@ -41,6 +53,12 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": errors.append( ErrorInfo("Missing required DEFAULT_SHEET").set_context(config_list) ) + if design_sheet is not None and design_sheet == default_sheet: + errors.append( + ErrorInfo( + "DESIGN_SHEET and DEFAULT_SHEET can not be the same." + ).set_context(config_list) + ) if errors: raise ConfigValidationError.from_collected(errors) assert design_sheet is not None @@ -50,3 +68,145 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": design_sheet=design_sheet, default_sheet=default_sheet, ) + + def read_design_matrix( + self, + parameter_configurations: List[ParameterConfig], + ) -> pd.DataFrame: + """ + Reads out all file content from different files and create dataframes + """ + design_matrix_df = DesignMatrix._read_excel( + self.xls_filename, self.design_sheet + ) + if "REAL" in design_matrix_df.columns: + design_matrix_df.set_index(design_matrix_df["REAL"]) + del design_matrix_df["REAL"] + try: + DesignMatrix._validate_design_matrix_header(design_matrix_df) + except ValueError as err: + raise ValueError(f"Design matrix not valid, error: {err!s}") from err + + # Todo: Check for invalid realizations, drop them maybe? + # This should probably handle/(fill in) missing values in design_matrix_sheet as well? Or maybe not. + defaults = DesignMatrix._read_defaultssheet( + self.xls_filename, self.default_sheet + ) + for k, v in defaults.items(): + if k not in design_matrix_df.columns: + design_matrix_df[k] = v + + # ignoring errors here is deprecated in pandas, should find another solution + # design_matrix_sheet = design_matrix_sheet.apply(pd.to_numeric, errors="ignore") + + parameter_groups = defaultdict(list) + parameter_map = [] + all_genkw_configs = [ + param_group + for param_group in parameter_configurations + if isinstance(param_group, GenKwConfig) + ] + errors = {} + for param in design_matrix_df.columns: + par_gp = [] + for param_group in all_genkw_configs: + if param in param_group: + par_gp.append(param_group.name) + + if not par_gp: + parameter_name = "DESIGN_MATRIX" + parameter_groups[parameter_name].append(param) + parameter_map.append((parameter_name, param)) + elif len(par_gp) == 1: + parameter_name = par_gp[0] + parameter_groups[parameter_name].append(param) + parameter_map.append((parameter_name, param)) + else: + errors[param] = par_gp + + if errors: + msg = "" + for key, value in errors.items(): + msg += ( + f"The following parameter '{key}' was found in multiple" + f" GenKw parameters groups: {value}." + ) + raise ValueError(msg) + design_matrix_df.columns = pd.MultiIndex.from_tuples(parameter_map) + return design_matrix_df + + @staticmethod + def _read_excel( + file_name: Path | str, + sheet_name: str, + usecols: int | list[int] | None = None, + header: int | None = 0, + ) -> pd.DataFrame: + """ + Make dataframe from excel file + :return: Dataframe + :raises: OsError if file not found + :raises: ValueError if file not loaded correctly + """ + dframe: pd.DataFrame = pd.read_excel( + file_name, + sheet_name, + usecols=usecols, + header=header, + ) + return dframe.dropna(axis=1, how="all") + + def _validate_design_matrix_header(design_matrix: pd.DataFrame) -> None: + """ + Validate header in user inputted design matrix + :raises: ValueError if design matrix contains empty headers + """ + if design_matrix.empty: + return + try: + unnamed = design_matrix.loc[ + :, design_matrix.columns.str.contains("^Unnamed") + ] + except ValueError as err: + # We catch because int/floats as column headers + # in xlsx gets read as int/float and is not valid to index by. + raise ValueError( + f"Invalid value in design matrix header, error: {err !s}" + ) from err + column_indexes = [int(x.split(":")[1]) for x in unnamed.columns.to_numpy()] + if len(column_indexes) > 0: + raise ValueError(f"Column headers not present in column {column_indexes}") + + @staticmethod + def _read_defaultssheet( + xlsfilename: Path | str, defaultssheetname: str + ) -> dict[str, str]: + """ + Construct a dataframe of keys and values to be used as defaults from the + first two columns in a spreadsheet. + + Returns a dict of default values + + :raises: ValueError if defaults sheet is non-empty but non-parsable + """ + default_df = DesignMatrix._read_excel( + xlsfilename, defaultssheetname, usecols=[0, 1], header=None + ) + if default_df.empty: + return {} + if len(default_df.columns) < 2: + raise ValueError("Defaults sheet must have at least two columns") + # Look for initial or trailing whitespace in parameter names. This + # is disallowed as it can create user confusion and has no use-case. + for paramname in default_df.loc[:, 0]: + if paramname != paramname.strip(): + raise ValueError( + f'Parameter name "{paramname}" in default values contains ' + "initial or trailing whitespace." + ) + + default_df = default_df.rename(columns={0: "keys", 1: "defaults"}) + defaults = {} + for _, row in default_df.iterrows(): + defaults[row["keys"]] = row["defaults"] + return defaults diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py index a634573c965..459b9b46aad 100644 --- a/src/ert/config/gen_kw_config.py +++ b/src/ert/config/gen_kw_config.py @@ -70,9 +70,7 @@ class TransformFunctionDefinition: class GenKwConfig(ParameterConfig): template_file: Optional[str] output_file: Optional[str] - transform_function_definitions: ( - List[TransformFunctionDefinition] | List[Dict[Any, Any]] - ) + transform_function_definitions: List[TransformFunctionDefinition] forward_init_file: Optional[str] = None def __post_init__(self) -> None: @@ -90,6 +88,9 @@ def __post_init__(self) -> None: ) self._validate() + def __contains__(self, item: str) -> bool: + return item in [v.name for v in self.transform_function_definitions] + def __len__(self) -> int: return len(self.transform_functions) diff --git a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py new file mode 100644 index 00000000000..ffa226e4e99 --- /dev/null +++ b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py @@ -0,0 +1,29 @@ +import time + +import pandas as pd +import pytest + +from ert.config import ErtConfig + + +@pytest.mark.usefixtures("copy_poly_case") +def test_reading_design_matrix(copy_poly_case): + design_matrix_df = pd.DataFrame( + {"REAL": [0, 1, 2], "a": [1, 2, 3], "b": [0, 2, 0], "c": [3, 1, 3]} + ) + default_sheet_df = pd.DataFrame() + with pd.ExcelWriter("design_matrix.xlsx") as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel(xl_write, index=False, sheet_name="DefaultValues") + + with open("poly.ert", "a", encoding="utf-8") as fhandle: + fhandle.write( + "DESIGN_MATRIX design_matrix.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultValues" + ) + ert_config = ErtConfig.from_file("poly.ert") + parameter_configurations = ert_config.ensemble_config.parameter_configuration + t = time.perf_counter() + _design_frame = ert_config.analysis_config.design_matrix.read_design_matrix( + parameter_configurations + ) + print(f"Read design matrix time_used {(time.perf_counter() - t):.4f}s")