Skip to content

Commit

Permalink
Collapse all realization legend entries into one for each ensemble.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanSava committed Aug 13, 2024
1 parent 5b8b6cc commit 8729cbc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
4 changes: 2 additions & 2 deletions tests/plots/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def test_realizations_plot_representation():
realization_df, x_axis, assets.ERTSTYLE["ensemble-selector"]["color_wheel"][0]
)
assert len(plots) == 20
for plot in plots:
for idx, plot in enumerate(plots):
np.testing.assert_equal(x_axis, plot.repr.x)
np.testing.assert_equal(plot.repr.y, realization_df[plot.name].values)
np.testing.assert_equal(plot.repr.y, realization_df[idx].values)


def test_realizations_statistics_plot_representation():
Expand Down
13 changes: 9 additions & 4 deletions webviz_ert/controllers/multi_response_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def _get_realizations_plots(
x_axis=x_axis,
y_axis=realizations_df[realization].values,
text=f"Realization: {realization} Ensemble: {ensemble_name}",
name=realization,
name=ensemble_name,
legendgroup=ensemble_name,
showlegend=False if idx > 0 else True,
**_style,
)
realizations_data.append(plot)
Expand Down Expand Up @@ -81,7 +83,9 @@ def _get_realizations_statistics_plots(


def _get_observation_plots(
observation_df: pd.DataFrame, metadata: Optional[List[str]] = None
observation_df: pd.DataFrame,
metadata: Optional[List[str]] = None,
ensemble: str = "",
) -> PlotModel:
data = observation_df["values"]
stds = observation_df["std"]
Expand All @@ -97,7 +101,7 @@ def _get_observation_plots(
x_axis=x_axis,
y_axis=data,
text=attributes,
name="Observation",
name=f"Observation_{ensemble}",
error_y=dict(
type="data", # value of error bar given in data coordinates
array=stds.values,
Expand Down Expand Up @@ -135,7 +139,8 @@ def _create_response_plot(
)
if response.observations:
observations = [
_get_observation_plots(obs.data_df()) for obs in response.observations
_get_observation_plots(obs.data_df(), ensemble=ensemble_name)
for obs in response.observations
]
else:
observations = []
Expand Down
9 changes: 7 additions & 2 deletions webviz_ert/models/plot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,18 @@ class PlotModel:
def __init__(self, **kwargs: Any):
self._x_axis = kwargs["x_axis"]
self._y_axis = kwargs["y_axis"]
self._text = kwargs["text"] if "text" in kwargs else None
self._text = kwargs.get("text")
self._name = kwargs["name"]
self._mode = kwargs["mode"]
self._line = kwargs["line"]
self._marker = kwargs["marker"]
self._error_y = kwargs.get("error_y")
self._hoverlabel = kwargs.get("hoverlabel")
self._meta = kwargs["meta"] if "meta" in kwargs else None
self._meta = kwargs.get("meta")
self._xaxis = kwargs.get("xaxis")
self.selected = True
self.legendgroup = kwargs.get("legendgroup")
self.showlegend = kwargs.get("showlegend", True)

@property
def repr(self) -> Union[go.Scattergl, go.Scatter]:
Expand All @@ -224,7 +226,10 @@ def repr(self) -> Union[go.Scattergl, go.Scatter]:
connectgaps=True,
hoverlabel=self._hoverlabel,
meta=self._meta,
showlegend=self.showlegend,
)
if self.legendgroup:
repr_dict["legendgroup"] = self.legendgroup
if self._line:
repr_dict["line"] = self._line
if self._marker:
Expand Down

0 comments on commit 8729cbc

Please sign in to comment.