From 72a1fc046220c9bebe054d2b91049fbb48a725ee Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Fri, 20 Sep 2024 15:35:34 +0200 Subject: [PATCH] Use polars logic in storage_info_widget --- src/ert/config/response_config.py | 5 + src/ert/config/summary_config.py | 9 +- .../manage_experiments/storage_info_widget.py | 182 +++++++++--------- .../manage_experiments/storage_widget.py | 2 +- 4 files changed, 110 insertions(+), 88 deletions(-) diff --git a/src/ert/config/response_config.py b/src/ert/config/response_config.py index 9d807d84435..e41bdfb3ff4 100644 --- a/src/ert/config/response_config.py +++ b/src/ert/config/response_config.py @@ -47,3 +47,8 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]: """Creates a config, given an ert config dict. A response config may depend on several config kws, such as REFCASE for summary.""" + + @classmethod + def display_column(cls, value: Any, column_name: str) -> str: + """Formats a value to a user-friendly displayable format.""" + return str(value) diff --git a/src/ert/config/summary_config.py b/src/ert/config/summary_config.py index 07b21b7f4c4..2cca97bbd3e 100644 --- a/src/ert/config/summary_config.py +++ b/src/ert/config/summary_config.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Optional, Set, Union from ._read_summary import read_summary from .ensemble_config import Refcase @@ -88,5 +88,12 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[SummaryConfig]: return None + @classmethod + def display_column(cls, value: Any, column_name: str) -> str: + if column_name == "time": + return value.strftime("%Y-%m-%d") + + return str(value) + responses_index.add_response_type(SummaryConfig) diff --git a/src/ert/gui/tools/manage_experiments/storage_info_widget.py b/src/ert/gui/tools/manage_experiments/storage_info_widget.py index c88b8be01c9..013877b5b59 100644 --- a/src/ert/gui/tools/manage_experiments/storage_info_widget.py +++ b/src/ert/gui/tools/manage_experiments/storage_info_widget.py @@ -1,9 +1,8 @@ import json from enum import IntEnum -from functools import reduce from typing import Optional -import numpy as np +import polars import seaborn as sns import yaml from matplotlib.backends.backend_qt5agg import FigureCanvas # type: ignore @@ -176,8 +175,8 @@ def _currentItemChanged( if not selected: return - observation_name = selected.data(1, Qt.ItemDataRole.DisplayRole) - if not observation_name: + observation_key = selected.data(1, Qt.ItemDataRole.DisplayRole) + if not observation_key: return observation_label = selected.data(0, Qt.ItemDataRole.DisplayRole) @@ -186,82 +185,91 @@ def _currentItemChanged( self._figure.clear() ax = self._figure.add_subplot(111) - ax.set_title(observation_name) + ax.set_title(observation_key) ax.grid(True) - observation_ds = observations_dict[observation_name] + response_type, obs_for_type = next( + ( + (response_type, _df) + for response_type, _df in observations_dict.items() + if observation_key in _df["observation_key"] + ), + (None, None), + ) + + assert response_type is not None + assert obs_for_type is not None + + response_config = self._ensemble.experiment.response_configuration[ + response_type + ] + x_axis_col = response_config.primary_key[-1] + + def _filter_on_observation_label(df: polars.DataFrame) -> polars.DataFrame: + # We add a column with the display name of the x axis column + # to correctly compare it to the observation_label + # (which is also such a display name) + return df.with_columns( + df[x_axis_col] + .map_elements( + lambda x: response_config.display_column(x, x_axis_col), + return_dtype=polars.String, + ) + .alias("temp") + ).filter(polars.col("temp").eq(observation_label))[ + [x for x in df.columns if x != "temp"] + ] - response_name = observation_ds.attrs["response"] + obs = obs_for_type.filter(polars.col("observation_key").eq(observation_key)) + obs = _filter_on_observation_label(obs) + + response_key = obs["response_key"].unique().to_list()[0] response_ds = self._ensemble.load_responses( - response_name, + response_key, tuple(self._ensemble.get_realization_list_with_responses()), ) - scaling_ds = self._ensemble.load_observation_scaling_factors() + scaling_df = self._ensemble.load_observation_scaling_factors() def _try_render_scaled_obs() -> None: - if scaling_ds is None: + if scaling_df is None: return None - _obs_df = observation_ds.to_dataframe().reset_index() - # Should store scaling by response type - # and use primary key for the response to - # create the index key. - index_cols = list( - set(observation_ds.to_dataframe().reset_index().columns) - - { - "observations", - "std", - } + index_col = polars.concat_str(response_config.primary_key, separator=", ") + joined = obs.with_columns(index_col.alias("_tmp_index")).join( + scaling_df, + how="left", + left_on=["observation_key", "_tmp_index"], + right_on=["obs_key", "index"], + )[["observations", "std", "scaling_factor"]] + + joined_small = joined[["observations", "std", "scaling_factor"]] + joined_small = joined_small.group_by(["observations", "std"]).agg( + [polars.col("scaling_factor").product()] ) - - # Just to ensure report step comes first as in _es_update - if "report_step" in index_cols: - index_cols = ["report_step", "index"] - - # for summary there is only "time" - index_key = ", ".join([str(_obs_df[x].values[0]) for x in index_cols]) - scaling_factors = ( - scaling_ds.sel(obs_key=observation_name, drop=True) - .sel(index=index_key, drop=True)["scaling_factor"] - .dropna(dim="input_group") - .values + joined_small = joined_small.with_columns( + (joined_small["std"] * joined_small["scaling_factor"]) ) - cumulative_scaling: float = ( - reduce(lambda x, y: x * y, scaling_factors) or 1.0 - ) - - original_std = observation_ds.get("std") - assert original_std ax.errorbar( x="Scaled observation", - y=observation_ds.get("observations"), # type: ignore - yerr=original_std * cumulative_scaling, + y=joined_small["observations"].to_list(), + yerr=joined_small["std"].to_list(), fmt=".", linewidth=1, capsize=4, color="black", ) - # check if the response is empty - if bool(response_ds.dims): - if response_name == "summary": - response_ds = response_ds.sel(name=str(observation_ds.name.data[0])) - - if "time" in observation_ds.coords: - observation_ds = observation_ds.sel(time=observation_label) - response_ds = response_ds.sel(time=observation_label) - elif "index" in observation_ds.coords: - observation_ds = observation_ds.sel(index=int(observation_label)) - response_ds = response_ds.drop(["index"]).sel( - index=int(observation_label) - ) + if not response_ds.is_empty(): + response_ds_for_label = _filter_on_observation_label(response_ds).rename( + {"values": "Responses"} + )[["response_key", "Responses"]] ax.errorbar( x="Observation", - y=observation_ds.get("observations"), # type: ignore - yerr=observation_ds.get("std"), + y=obs["observations"], + yerr=obs["std"], fmt=".", linewidth=1, capsize=4, @@ -269,20 +277,14 @@ def _try_render_scaled_obs() -> None: ) _try_render_scaled_obs() - response_ds = response_ds.rename_vars({"values": "Responses"}) - sns.boxplot(response_ds.to_dataframe(), ax=ax) - sns.stripplot(response_ds.to_dataframe(), ax=ax, size=4, color=".3") + sns.boxplot(response_ds_for_label.to_pandas(), ax=ax) + sns.stripplot(response_ds_for_label.to_pandas(), ax=ax, size=4, color=".3") else: - if "time" in observation_ds.coords: - observation_ds = observation_ds.sel(time=observation_label) - elif "index" in observation_ds.coords: - observation_ds = observation_ds.sel(index=int(observation_label)) - ax.errorbar( x="Observation", - y=observation_ds.get("observations"), # type: ignore - yerr=observation_ds.get("std"), + y=obs["observations"], + yerr=obs["std"], fmt=".", linewidth=1, capsize=4, @@ -310,32 +312,40 @@ def _currentTabChanged(self, index: int) -> None: assert self._ensemble is not None exp = self._ensemble.experiment - for obs_name, obs_ds in exp.observations.items(): - response_name = obs_ds.attrs["response"] - if response_name == "summary": - name = obs_ds.name.data[0] - else: - name = response_name - - match_list = self._observations_tree_widget.findItems( - name, Qt.MatchFlag.MatchExactly - ) - if len(match_list) == 0: - root = QTreeWidgetItem(self._observations_tree_widget, [name]) - else: - root = match_list[0] - if "time" in obs_ds.coords: - for t in obs_ds.time: + for response_type, obs_ds_for_type in exp.observations.items(): + for obs_key, response_key in ( + obs_ds_for_type.select(["observation_key", "response_key"]) + .unique() + .to_numpy() + ): + match_list = self._observations_tree_widget.findItems( + response_key, Qt.MatchFlag.MatchExactly + ) + if len(match_list) == 0: + root = QTreeWidgetItem( + self._observations_tree_widget, [response_key] + ) + else: + root = match_list[0] + + obs_ds = obs_ds_for_type.filter( + polars.col("observation_key").eq(obs_key) + ) + response_config = exp.response_configuration[response_type] + column_to_display = response_config.primary_key[-1] + for t in obs_ds[column_to_display].to_list(): QTreeWidgetItem( root, - [str(np.datetime_as_string(t.values, unit="D")), obs_name], + [ + response_config.display_column(t, column_to_display), + obs_key, + ], ) - elif "index" in obs_ds.coords: - for t in obs_ds.index: - QTreeWidgetItem(root, [str(t.data), obs_name]) - self._observations_tree_widget.sortItems(0, Qt.SortOrder.AscendingOrder) + self._observations_tree_widget.sortItems( + 0, Qt.SortOrder.AscendingOrder + ) for i in range(self._observations_tree_widget.topLevelItemCount()): item = self._observations_tree_widget.topLevelItem(i) diff --git a/src/ert/gui/tools/manage_experiments/storage_widget.py b/src/ert/gui/tools/manage_experiments/storage_widget.py index 1be6f6351c8..a13e9dcc45c 100644 --- a/src/ert/gui/tools/manage_experiments/storage_widget.py +++ b/src/ert/gui/tools/manage_experiments/storage_widget.py @@ -151,7 +151,7 @@ def _addItem(self) -> None: ensemble = self._notifier.storage.create_experiment( parameters=self._ert_config.ensemble_config.parameter_configuration, responses=self._ert_config.ensemble_config.response_configuration, - observations=self._ert_config.observations, + observations=self._ert_config.enkf_obs.datasets, name=create_experiment_dialog.experiment_name, ).create_ensemble( name=create_experiment_dialog.ensemble_name,