From 2ce1d41ddbda1be31ccbc792bd544396104e24f0 Mon Sep 17 00:00:00 2001 From: larsevj Date: Tue, 8 Oct 2024 17:24:04 +0200 Subject: [PATCH] Add more tests --- src/ert/config/design_matrix.py | 15 +++- .../test_design_matrix.py | 70 ++++++++++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index 641233631d5..e353e8efc32 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional import pandas as pd +from pandas.api.types import is_integer_dtype from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition @@ -27,6 +28,8 @@ class DesignMatrix: xls_filename: Path design_sheet: str default_sheet: str + num_realizations: Optional[int] = None + active_realizations: Optional[list[int]] = None design_matrix_df: Optional[pd.DataFrame] = None parameter_configuration: Optional[dict[str, ParameterConfig]] = None @@ -80,7 +83,13 @@ def read_design_matrix( self.xls_filename, self.design_sheet ) if "REAL" in design_matrix_df.columns: - design_matrix_df = design_matrix_df.set_index("REAL", drop=True) + if not is_integer_dtype(design_matrix_df.dtypes["REAL"]) or any( + design_matrix_df["REAL"] < 0 + ): + raise ValueError("REAL column must only contain positive integers") + design_matrix_df = design_matrix_df.set_index( + "REAL", drop=True, verify_integrity=True + ) try: DesignMatrix._validate_design_matrix_header(design_matrix_df) except ValueError as err: @@ -119,6 +128,10 @@ def read_design_matrix( design_matrix_df.columns = pd.MultiIndex.from_product( [[DESIGN_MATRIX_GROUP], design_matrix_df.columns] ) + reals = design_matrix_df.index.tolist() + self.num_realizations = len(reals) + self.active_realizations = [x in reals for x in range(max(reals))] + self.design_matrix_df = design_matrix_df self.parameter_configuration = parameter_configuration diff --git a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py index d610a7f0fd3..145a3d43220 100644 --- a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py +++ b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py @@ -1,4 +1,6 @@ +import numpy as np import pandas as pd +import pytest from ert.config import DesignMatrix @@ -6,7 +8,73 @@ def test_reading_design_matrix(tmp_path): design_path = tmp_path / "design_matrix.xlsx" design_matrix_df = pd.DataFrame( - {"REAL": [0, 1, 2], "a": [1, 2, 3], "b": [0, 2, 0], "c": [3, 1, 3]} + { + "REAL": [0, 1, 2], + "a": [1, 2, 3], + "b": [0, 2, 0], + "c": [3, 1, 3], + } + ) + default_sheet_df = pd.DataFrame([["one", 1], ["b", 4], ["d", 6]]) + with pd.ExcelWriter(design_path) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultValues", header=False + ) + design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + design_matrix.read_design_matrix() + + +@pytest.mark.parametrize( + "real_column, error_msg", + [ + pytest.param([0, 1, 1], "Index has duplicate keys", id="duplicate entries"), + pytest.param( + [0, 1.1, 2], + "REAL column must only contain positive integers", + id="invalid float values", + ), + pytest.param( + [0, "a", 10], + "REAL column must only contain positive integers", + id="invalid types", + ), + ], +) +def test_reading_design_matrix_validate_reals(tmp_path, real_column, error_msg): + design_path = tmp_path / "design_matrix.xlsx" + design_matrix_df = pd.DataFrame( + { + "REAL": real_column, + "a": [1, 2, 3], + "b": [0, 2, 0], + "c": [3, 1, 3], + } + ) + default_sheet_df = pd.DataFrame() + with pd.ExcelWriter(design_path) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultValues", header=False + ) + design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + with pytest.raises(ValueError, match=error_msg): + design_matrix.read_design_matrix() + + +def test_reading_design_matrix_duplicate_columns(tmp_path): + design_path = tmp_path / "design_matrix.xlsx" + design_matrix_df = pd.DataFrame( + { + "REAL": [0, 1, -4], + "a": [1, 2, 3], + "b": [0, 2, 0], + "c": [3, 1, 3], + "0": ["a", 2, "c"], + } + ) + design_matrix_df = pd.DataFrame( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), columns=["a", "b", "a"] ) default_sheet_df = pd.DataFrame([["one", 1], ["b", 4], ["d", 6]]) with pd.ExcelWriter(design_path) as xl_write: