Skip to content

Commit

Permalink
Support reading design_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Oct 8, 2024
1 parent c258c76 commit d08d4f6
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"matplotlib",
"netCDF4",
"numpy<2",
"openpyxl", # extra dependency for pandas (excel)
"orjson",
"packaging",
"pandas",
Expand Down
2 changes: 2 additions & 0 deletions src/ert/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .analysis_config import AnalysisConfig
from .analysis_module import AnalysisModule, ESSettings, IESSettings
from .capture_validation import capture_validation
from .design_matrix import DesignMatrix
from .enkf_observation_implementation_type import EnkfObservationImplementationType
from .ensemble_config import EnsembleConfig
from .ert_config import ErtConfig
Expand Down Expand Up @@ -48,6 +49,7 @@
"ConfigValidationError",
"ConfigValidationError",
"ConfigWarning",
"DesignMatrix",
"ESSettings",
"EnkfObs",
"EnkfObservationImplementationType",
Expand Down
4 changes: 2 additions & 2 deletions src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AnalysisConfig:
ies_module: IESSettings = field(default_factory=IESSettings)
observation_settings: UpdateSettings = field(default_factory=UpdateSettings)
num_iterations: int = 1
design_matrix_args: Optional[DesignMatrix] = None
design_matrix: Optional[DesignMatrix] = None

@no_type_check
@classmethod
Expand Down Expand Up @@ -194,7 +194,7 @@ def from_dict(cls, config_dict: ConfigDict) -> "AnalysisConfig":
observation_settings=obs_settings,
es_module=es_settings,
ies_module=ies_settings,
design_matrix_args=DesignMatrix.from_config_list(design_matrix_config_list)
design_matrix=DesignMatrix.from_config_list(design_matrix_config_list)
if design_matrix_config_list is not None
else None,
)
Expand Down
162 changes: 161 additions & 1 deletion src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import List
from typing import TYPE_CHECKING, List

import pandas as pd

from ert.config.gen_kw_config import GenKwConfig

from ._option_dict import option_dict
from .parsing import (
ConfigValidationError,
ErrorInfo,
)

if TYPE_CHECKING:
from ert.config import (
ParameterConfig,
)

DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"


@dataclass
class DesignMatrix:
Expand Down Expand Up @@ -41,6 +53,12 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
errors.append(
ErrorInfo("Missing required DEFAULT_SHEET").set_context(config_list)
)
if design_sheet is not None and design_sheet == default_sheet:
errors.append(
ErrorInfo(
"DESIGN_SHEET and DEFAULT_SHEET can not be the same."
).set_context(config_list)
)
if errors:
raise ConfigValidationError.from_collected(errors)
assert design_sheet is not None
Expand All @@ -50,3 +68,145 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
design_sheet=design_sheet,
default_sheet=default_sheet,
)

def read_design_matrix(
self,
parameter_configurations: List[ParameterConfig],
) -> pd.DataFrame:
"""
Reads out all file content from different files and create dataframes
"""
design_matrix_df = DesignMatrix._read_excel(
self.xls_filename, self.design_sheet
)
if "REAL" in design_matrix_df.columns:
design_matrix_df.set_index(design_matrix_df["REAL"])
del design_matrix_df["REAL"]
try:
DesignMatrix._validate_design_matrix_header(design_matrix_df)
except ValueError as err:
raise ValueError(f"Design matrix not valid, error: {err!s}") from err

# Todo: Check for invalid realizations, drop them maybe?
# This should probably handle/(fill in) missing values in design_matrix_sheet as well? Or maybe not.
defaults = DesignMatrix._read_defaultssheet(
self.xls_filename, self.default_sheet
)
for k, v in defaults.items():
if k not in design_matrix_df.columns:
design_matrix_df[k] = v

# ignoring errors here is deprecated in pandas, should find another solution
# design_matrix_sheet = design_matrix_sheet.apply(pd.to_numeric, errors="ignore")

parameter_groups = defaultdict(list)
parameter_map = []
all_genkw_configs = [
param_group
for param_group in parameter_configurations
if isinstance(param_group, GenKwConfig)
]
errors = {}
for param in design_matrix_df.columns:
par_gp = []
for param_group in all_genkw_configs:
if param in param_group:
par_gp.append(param_group.name)

if not par_gp:
parameter_name = "DESIGN_MATRIX"
parameter_groups[parameter_name].append(param)
parameter_map.append((parameter_name, param))
elif len(par_gp) == 1:
parameter_name = par_gp[0]
parameter_groups[parameter_name].append(param)
parameter_map.append((parameter_name, param))
else:
errors[param] = par_gp

if errors:
msg = ""
for key, value in errors.items():
msg += (
f"The following parameter '{key}' was found in multiple"
f" GenKw parameters groups: {value}."
)
raise ValueError(msg)
design_matrix_df.columns = pd.MultiIndex.from_tuples(parameter_map)
return design_matrix_df

@staticmethod
def _read_excel(
file_name: Path | str,
sheet_name: str,
usecols: int | list[int] | None = None,
header: int | None = 0,
) -> pd.DataFrame:
"""
Make dataframe from excel file
:return: Dataframe
:raises: OsError if file not found
:raises: ValueError if file not loaded correctly
"""
dframe: pd.DataFrame = pd.read_excel(
file_name,
sheet_name,
usecols=usecols,
header=header,
)
return dframe.dropna(axis=1, how="all")

def _validate_design_matrix_header(design_matrix: pd.DataFrame) -> None:
"""
Validate header in user inputted design matrix
:raises: ValueError if design matrix contains empty headers
"""
if design_matrix.empty:
return
try:
unnamed = design_matrix.loc[
:, design_matrix.columns.str.contains("^Unnamed")
]
except ValueError as err:
# We catch because int/floats as column headers
# in xlsx gets read as int/float and is not valid to index by.
raise ValueError(
f"Invalid value in design matrix header, error: {err !s}"
) from err
column_indexes = [int(x.split(":")[1]) for x in unnamed.columns.to_numpy()]
if len(column_indexes) > 0:
raise ValueError(f"Column headers not present in column {column_indexes}")

@staticmethod
def _read_defaultssheet(
xlsfilename: Path | str, defaultssheetname: str
) -> dict[str, str]:
"""
Construct a dataframe of keys and values to be used as defaults from the
first two columns in a spreadsheet.
Returns a dict of default values
:raises: ValueError if defaults sheet is non-empty but non-parsable
"""
default_df = DesignMatrix._read_excel(
xlsfilename, defaultssheetname, usecols=[0, 1], header=None
)
if default_df.empty:
return {}
if len(default_df.columns) < 2:
raise ValueError("Defaults sheet must have at least two columns")
# Look for initial or trailing whitespace in parameter names. This
# is disallowed as it can create user confusion and has no use-case.
for paramname in default_df.loc[:, 0]:
if paramname != paramname.strip():
raise ValueError(
f'Parameter name "{paramname}" in default values contains '
"initial or trailing whitespace."
)

default_df = default_df.rename(columns={0: "keys", 1: "defaults"})
defaults = {}
for _, row in default_df.iterrows():
defaults[row["keys"]] = row["defaults"]
return defaults
7 changes: 4 additions & 3 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ class TransformFunctionDefinition:
class GenKwConfig(ParameterConfig):
template_file: Optional[str]
output_file: Optional[str]
transform_function_definitions: (
List[TransformFunctionDefinition] | List[Dict[Any, Any]]
)
transform_function_definitions: List[TransformFunctionDefinition]
forward_init_file: Optional[str] = None

def __post_init__(self) -> None:
Expand All @@ -90,6 +88,9 @@ def __post_init__(self) -> None:
)
self._validate()

def __contains__(self, item: str) -> bool:
return item in [v.name for v in self.transform_function_definitions]

def __len__(self) -> int:
return len(self.transform_functions)

Expand Down
29 changes: 29 additions & 0 deletions tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import time

import pandas as pd
import pytest

from ert.config import ErtConfig


@pytest.mark.usefixtures("copy_poly_case")
def test_reading_design_matrix(copy_poly_case):
design_matrix_df = pd.DataFrame(
{"REAL": [0, 1, 2], "a": [1, 2, 3], "b": [0, 2, 0], "c": [3, 1, 3]}
)
default_sheet_df = pd.DataFrame()
with pd.ExcelWriter("design_matrix.xlsx") as xl_write:
design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01")
default_sheet_df.to_excel(xl_write, index=False, sheet_name="DefaultValues")

with open("poly.ert", "a", encoding="utf-8") as fhandle:
fhandle.write(
"DESIGN_MATRIX design_matrix.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultValues"
)
ert_config = ErtConfig.from_file("poly.ert")
parameter_configurations = ert_config.ensemble_config.parameter_configuration
t = time.perf_counter()
_design_frame = ert_config.analysis_config.design_matrix.read_design_matrix(
parameter_configurations
)
print(f"Read design matrix time_used {(time.perf_counter() - t):.4f}s")

0 comments on commit d08d4f6

Please sign in to comment.