Skip to content

Commit

Permalink
Use polars logic in storage_info_widget
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen authored and yngve-sk committed Sep 30, 2024
1 parent edcd7ac commit 72a1fc0
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 88 deletions.
5 changes: 5 additions & 0 deletions src/ert/config/response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]:
"""Creates a config, given an ert config dict.
A response config may depend on several config kws, such as REFCASE
for summary."""

@classmethod
def display_column(cls, value: Any, column_name: str) -> str:
    """Return ``value`` rendered as text suitable for showing to a user.

    This implementation does not inspect ``column_name``; every value is
    converted with ``str``.
    """
    rendered = str(value)
    return rendered
9 changes: 8 additions & 1 deletion src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Set, Union
from typing import TYPE_CHECKING, Any, Optional, Set, Union

from ._read_summary import read_summary
from .ensemble_config import Refcase
Expand Down Expand Up @@ -88,5 +88,12 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[SummaryConfig]:

return None

@classmethod
def display_column(cls, value: Any, column_name: str) -> str:
    """Return ``value`` rendered as text suitable for showing to a user.

    Values of the ``"time"`` column are formatted with
    ``strftime("%Y-%m-%d")``; values of any other column are converted
    with ``str``.
    """
    if column_name != "time":
        return str(value)

    # NOTE(review): assumes "time" values are datetime-like (expose
    # ``strftime``) -- confirm against the callers.
    return value.strftime("%Y-%m-%d")


responses_index.add_response_type(SummaryConfig)
182 changes: 96 additions & 86 deletions src/ert/gui/tools/manage_experiments/storage_info_widget.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
from enum import IntEnum
from functools import reduce
from typing import Optional

import numpy as np
import polars
import seaborn as sns
import yaml
from matplotlib.backends.backend_qt5agg import FigureCanvas # type: ignore
Expand Down Expand Up @@ -176,8 +175,8 @@ def _currentItemChanged(
if not selected:
return

observation_name = selected.data(1, Qt.ItemDataRole.DisplayRole)
if not observation_name:
observation_key = selected.data(1, Qt.ItemDataRole.DisplayRole)
if not observation_key:
return

observation_label = selected.data(0, Qt.ItemDataRole.DisplayRole)
Expand All @@ -186,103 +185,106 @@ def _currentItemChanged(

self._figure.clear()
ax = self._figure.add_subplot(111)
ax.set_title(observation_name)
ax.set_title(observation_key)
ax.grid(True)

observation_ds = observations_dict[observation_name]
response_type, obs_for_type = next(
(
(response_type, _df)
for response_type, _df in observations_dict.items()
if observation_key in _df["observation_key"]
),
(None, None),
)

assert response_type is not None
assert obs_for_type is not None

response_config = self._ensemble.experiment.response_configuration[
response_type
]
x_axis_col = response_config.primary_key[-1]

def _filter_on_observation_label(df: polars.DataFrame) -> polars.DataFrame:
    """Return the rows of ``df`` whose x axis value displays as the label.

    The x axis column is converted to its display form through
    ``response_config.display_column`` and stored in a helper column
    named "temp". That makes it comparable to ``observation_label``,
    which is itself such a display string. The helper column is left out
    of the returned frame.
    """
    # Columns to hand back: everything from the input except the helper.
    kept_columns = [name for name in df.columns if name != "temp"]

    display_values = df[x_axis_col].map_elements(
        lambda x: response_config.display_column(x, x_axis_col),
        return_dtype=polars.String,
    )
    with_display = df.with_columns(display_values.alias("temp"))
    matching_rows = with_display.filter(
        polars.col("temp").eq(observation_label)
    )
    return matching_rows[kept_columns]

response_name = observation_ds.attrs["response"]
obs = obs_for_type.filter(polars.col("observation_key").eq(observation_key))
obs = _filter_on_observation_label(obs)

response_key = obs["response_key"].unique().to_list()[0]
response_ds = self._ensemble.load_responses(
response_name,
response_key,
tuple(self._ensemble.get_realization_list_with_responses()),
)

scaling_ds = self._ensemble.load_observation_scaling_factors()
scaling_df = self._ensemble.load_observation_scaling_factors()

def _try_render_scaled_obs() -> None:
if scaling_ds is None:
if scaling_df is None:
return None

_obs_df = observation_ds.to_dataframe().reset_index()
# Should store scaling by response type
# and use primary key for the response to
# create the index key.
index_cols = list(
set(observation_ds.to_dataframe().reset_index().columns)
- {
"observations",
"std",
}
index_col = polars.concat_str(response_config.primary_key, separator=", ")
joined = obs.with_columns(index_col.alias("_tmp_index")).join(
scaling_df,
how="left",
left_on=["observation_key", "_tmp_index"],
right_on=["obs_key", "index"],
)[["observations", "std", "scaling_factor"]]

joined_small = joined[["observations", "std", "scaling_factor"]]
joined_small = joined_small.group_by(["observations", "std"]).agg(
[polars.col("scaling_factor").product()]
)

# Just to ensure report step comes first as in _es_update
if "report_step" in index_cols:
index_cols = ["report_step", "index"]

# for summary there is only "time"
index_key = ", ".join([str(_obs_df[x].values[0]) for x in index_cols])
scaling_factors = (
scaling_ds.sel(obs_key=observation_name, drop=True)
.sel(index=index_key, drop=True)["scaling_factor"]
.dropna(dim="input_group")
.values
joined_small = joined_small.with_columns(
(joined_small["std"] * joined_small["scaling_factor"])
)

cumulative_scaling: float = (
reduce(lambda x, y: x * y, scaling_factors) or 1.0
)

original_std = observation_ds.get("std")
assert original_std
ax.errorbar(
x="Scaled observation",
y=observation_ds.get("observations"), # type: ignore
yerr=original_std * cumulative_scaling,
y=joined_small["observations"].to_list(),
yerr=joined_small["std"].to_list(),
fmt=".",
linewidth=1,
capsize=4,
color="black",
)

# check if the response is empty
if bool(response_ds.dims):
if response_name == "summary":
response_ds = response_ds.sel(name=str(observation_ds.name.data[0]))

if "time" in observation_ds.coords:
observation_ds = observation_ds.sel(time=observation_label)
response_ds = response_ds.sel(time=observation_label)
elif "index" in observation_ds.coords:
observation_ds = observation_ds.sel(index=int(observation_label))
response_ds = response_ds.drop(["index"]).sel(
index=int(observation_label)
)
if not response_ds.is_empty():
response_ds_for_label = _filter_on_observation_label(response_ds).rename(
{"values": "Responses"}
)[["response_key", "Responses"]]

ax.errorbar(
x="Observation",
y=observation_ds.get("observations"), # type: ignore
yerr=observation_ds.get("std"),
y=obs["observations"],
yerr=obs["std"],
fmt=".",
linewidth=1,
capsize=4,
color="black",
)
_try_render_scaled_obs()

response_ds = response_ds.rename_vars({"values": "Responses"})
sns.boxplot(response_ds.to_dataframe(), ax=ax)
sns.stripplot(response_ds.to_dataframe(), ax=ax, size=4, color=".3")
sns.boxplot(response_ds_for_label.to_pandas(), ax=ax)
sns.stripplot(response_ds_for_label.to_pandas(), ax=ax, size=4, color=".3")

else:
if "time" in observation_ds.coords:
observation_ds = observation_ds.sel(time=observation_label)
elif "index" in observation_ds.coords:
observation_ds = observation_ds.sel(index=int(observation_label))

ax.errorbar(
x="Observation",
y=observation_ds.get("observations"), # type: ignore
yerr=observation_ds.get("std"),
y=obs["observations"],
yerr=obs["std"],
fmt=".",
linewidth=1,
capsize=4,
Expand Down Expand Up @@ -310,32 +312,40 @@ def _currentTabChanged(self, index: int) -> None:

assert self._ensemble is not None
exp = self._ensemble.experiment
for obs_name, obs_ds in exp.observations.items():
response_name = obs_ds.attrs["response"]
if response_name == "summary":
name = obs_ds.name.data[0]
else:
name = response_name

match_list = self._observations_tree_widget.findItems(
name, Qt.MatchFlag.MatchExactly
)
if len(match_list) == 0:
root = QTreeWidgetItem(self._observations_tree_widget, [name])
else:
root = match_list[0]

if "time" in obs_ds.coords:
for t in obs_ds.time:
for response_type, obs_ds_for_type in exp.observations.items():
for obs_key, response_key in (
obs_ds_for_type.select(["observation_key", "response_key"])
.unique()
.to_numpy()
):
match_list = self._observations_tree_widget.findItems(
response_key, Qt.MatchFlag.MatchExactly
)
if len(match_list) == 0:
root = QTreeWidgetItem(
self._observations_tree_widget, [response_key]
)
else:
root = match_list[0]

obs_ds = obs_ds_for_type.filter(
polars.col("observation_key").eq(obs_key)
)
response_config = exp.response_configuration[response_type]
column_to_display = response_config.primary_key[-1]
for t in obs_ds[column_to_display].to_list():
QTreeWidgetItem(
root,
[str(np.datetime_as_string(t.values, unit="D")), obs_name],
[
response_config.display_column(t, column_to_display),
obs_key,
],
)
elif "index" in obs_ds.coords:
for t in obs_ds.index:
QTreeWidgetItem(root, [str(t.data), obs_name])

self._observations_tree_widget.sortItems(0, Qt.SortOrder.AscendingOrder)
self._observations_tree_widget.sortItems(
0, Qt.SortOrder.AscendingOrder
)

for i in range(self._observations_tree_widget.topLevelItemCount()):
item = self._observations_tree_widget.topLevelItem(i)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/gui/tools/manage_experiments/storage_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _addItem(self) -> None:
ensemble = self._notifier.storage.create_experiment(
parameters=self._ert_config.ensemble_config.parameter_configuration,
responses=self._ert_config.ensemble_config.response_configuration,
observations=self._ert_config.observations,
observations=self._ert_config.enkf_obs.datasets,
name=create_experiment_dialog.experiment_name,
).create_ensemble(
name=create_experiment_dialog.ensemble_name,
Expand Down

0 comments on commit 72a1fc0

Please sign in to comment.