From 772f1c1be7dff29307211d86d4e9bfa1459b97b6 Mon Sep 17 00:00:00 2001 From: larsevj Date: Fri, 4 Oct 2024 16:56:21 +0200 Subject: [PATCH] Only use parameters from design matrix --- src/ert/config/design_matrix.py | 77 ++++++++----------- .../test_design_matrix.py | 30 +++----- 2 files changed, 40 insertions(+), 67 deletions(-) diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index e44118a8569..c05790a3ffb 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -1,13 +1,12 @@ from __future__ import annotations -from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import pandas as pd -from ert.config.gen_kw_config import GenKwConfig +from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition from ._option_dict import option_dict from .parsing import ( @@ -28,6 +27,8 @@ class DesignMatrix: xls_filename: Path design_sheet: str default_sheet: str + design_matrix_df: Optional[pd.DataFrame] = None + parameter_configuration: Optional[list[ParameterConfig]] = None @classmethod def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": @@ -71,8 +72,7 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": def read_design_matrix( self, - parameter_configurations: List[ParameterConfig], - ) -> pd.DataFrame: + ) -> None: """ Reads out all file content from different files and create dataframes """ @@ -80,8 +80,7 @@ def read_design_matrix( self.xls_filename, self.design_sheet ) if "REAL" in design_matrix_df.columns: - design_matrix_df.set_index(design_matrix_df["REAL"]) - del design_matrix_df["REAL"] + design_matrix_df = design_matrix_df.set_index("REAL", drop=True) try: DesignMatrix._validate_design_matrix_header(design_matrix_df) except ValueError as err: @@ -98,42 +97,30 @@ def read_design_matrix( # ignoring errors here is deprecated in pandas, should find another solution # design_matrix_sheet = design_matrix_sheet.apply(pd.to_numeric, errors="ignore") - - parameter_groups = defaultdict(list) - parameter_map = [] - all_genkw_configs = [ - param_group - for param_group in parameter_configurations - if isinstance(param_group, GenKwConfig) - ] - errors = {} - for param in design_matrix_df.columns: - par_gp = [] - for param_group in all_genkw_configs: - if param in param_group: - par_gp.append(param_group.name) - - if not par_gp: - parameter_name = "DESIGN_MATRIX" - parameter_groups[parameter_name].append(param) - parameter_map.append((parameter_name, param)) - elif len(par_gp) == 1: - parameter_name = par_gp[0] - parameter_groups[parameter_name].append(param) - parameter_map.append((parameter_name, param)) - else: - errors[param] = par_gp - - if errors: - msg = "" - for key, value in errors.items(): - msg += ( - f"The following parameter '{key}' was found in multiple" - f" GenKw parameters groups: {value}." + parameter_configuration = {} + transform_function_definitions: list[TransformFunctionDefinition] = [] + for parameter in design_matrix_df.columns: + transform_function_definitions.append( + TransformFunctionDefinition( + name=parameter, + param_name="RAW", + values=[], ) - raise ValueError(msg) - design_matrix_df.columns = pd.MultiIndex.from_tuples(parameter_map) - return design_matrix_df + ) + parameter_configuration[DESIGN_MATRIX_GROUP] = GenKwConfig( + name=DESIGN_MATRIX_GROUP, + forward_init=False, + template_file=None, + output_file=None, + transform_function_definitions=transform_function_definitions, + update=False, + ) + + design_matrix_df.columns = pd.MultiIndex.from_product( + [[DESIGN_MATRIX_GROUP], design_matrix_df.columns] + ) + self.design_matrix_df = design_matrix_df + self.parameter_configuration = parameter_configuration @staticmethod def _read_excel( @@ -205,8 +192,4 @@ def _read_defaultssheet( "initial or trailing whitespace." ) - default_df = default_df.rename(columns={0: "keys", 1: "defaults"}) - defaults = {} - for _, row in default_df.iterrows(): - defaults[row["keys"]] = row["defaults"] - return defaults + return {row[0]: row[1] for _, row in default_df.iterrows()} diff --git a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py index ffa226e4e99..d610a7f0fd3 100644 --- a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py +++ b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py @@ -1,29 +1,19 @@ -import time - import pandas as pd -import pytest -from ert.config import ErtConfig +from ert.config import DesignMatrix -@pytest.mark.usefixtures("copy_poly_case") -def test_reading_design_matrix(copy_poly_case): +def test_reading_design_matrix(tmp_path): + design_path = tmp_path / "design_matrix.xlsx" design_matrix_df = pd.DataFrame( {"REAL": [0, 1, 2], "a": [1, 2, 3], "b": [0, 2, 0], "c": [3, 1, 3]} ) - default_sheet_df = pd.DataFrame() - with pd.ExcelWriter("design_matrix.xlsx") as xl_write: + default_sheet_df = pd.DataFrame([["one", 1], ["b", 4], ["d", 6]]) + with pd.ExcelWriter(design_path) as xl_write: design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") - default_sheet_df.to_excel(xl_write, index=False, sheet_name="DefaultValues") - - with open("poly.ert", "a", encoding="utf-8") as fhandle: - fhandle.write( - "DESIGN_MATRIX design_matrix.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultValues" + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultValues", header=False ) - ert_config = ErtConfig.from_file("poly.ert") - parameter_configurations = ert_config.ensemble_config.parameter_configuration - t = time.perf_counter() - _design_frame = ert_config.analysis_config.design_matrix.read_design_matrix( - parameter_configurations - ) - print(f"Read design matrix time_used {(time.perf_counter() - t):.4f}s") + design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + design_matrix.read_design_matrix() + print("\n The design matrix:\n", design_matrix.design_matrix_df)