Skip to content

Commit

Permalink
Add box-plots to standard deviation plot
Browse files Browse the repository at this point in the history
- test for stddev plo

Co-authored by Feda Curic <[email protected]>
  • Loading branch information
dafeda authored and xjules committed Oct 2, 2024
1 parent 4321e49 commit 382b0ef
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 15 deletions.
78 changes: 63 additions & 15 deletions src/ert/gui/plottery/plots/std_dev.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -28,40 +28,88 @@ def plot(
if layer is not None:
vmin: float = np.inf
vmax: float = -np.inf
axes = []
images: List[npt.NDArray[np.float32]] = []
heatmaps = []
boxplot_axes = []

# Adjust height_ratios to reduce space between plots
figure.set_layout_engine("constrained")
gridspec = figure.add_gridspec(2, ensemble_count, hspace=0.2)

for i, ensemble in enumerate(plot_context.ensembles(), start=1):
ax = figure.add_subplot(1, ensemble_count, i)
axes.append(ax)
ax_heat = figure.add_subplot(gridspec[0, i - 1])
ax_box = figure.add_subplot(gridspec[1, i - 1])
data = std_dev_data[ensemble.name]
if data.size == 0:
ax.set_axis_off()
ax.text(
ax_heat.set_axis_off()
ax_box.set_axis_off()
ax_heat.text(
0.5,
0.5,
f"No data for {ensemble.experiment_name} : {ensemble.name}",
ha="center",
va="center",
)
else:
images.append(data)
vmin = min(vmin, float(np.min(data)))
vmax = max(vmax, float(np.max(data)))
ax.set_title(

im = ax_heat.imshow(data, cmap="viridis", aspect="equal")
heatmaps.append(im)

ax_box.boxplot(data.flatten(), vert=True, widths=0.5)
boxplot_axes.append(ax_box)

min_value = np.min(data)
mean_value = np.mean(data)
max_value = np.max(data)

annotation_text = f"Min: {min_value:.2f}\nMean: {mean_value:.2f}\nMax: {max_value:.2f}"
ax_box.annotate(
annotation_text,
xy=(1, 1), # Changed from (0, 1) to (1, 1)
xycoords="axes fraction",
ha="right", # Changed from 'left' to 'right'
va="top",
fontsize=8,
fontweight="bold",
bbox={
"facecolor": "white",
"edgecolor": "black",
"boxstyle": "round,pad=0.2",
},
)

ax_box.spines["top"].set_visible(False)
ax_box.spines["right"].set_visible(False)
ax_box.spines["bottom"].set_visible(False)
ax_box.spines["left"].set_visible(True)

ax_box.set_xticks([])
ax_box.set_xticklabels([])

ax_heat.set_ylabel("")
ax_box.set_ylabel(
"Standard Deviation", fontsize=8
) # Reduced font size

self._colorbar(im)

ax_heat.set_title(
f"{ensemble.experiment_name} : {ensemble.name} layer={layer}",
wrap=True,
fontsize=10, # Reduced font size
)

norm = plt.Normalize(vmin, vmax)
for ax, data in zip(axes, images):
if data is not None:
im = ax.imshow(data, norm=norm, cmap="viridis")
self._colorbar(im)
figure.tight_layout()
for im in heatmaps:
im.set_norm(norm)

padding = 0.05 * (vmax - vmin)
for ax_box in boxplot_axes:
ax_box.set_ylim(vmin - padding, vmax + padding)

@staticmethod
def _colorbar(mappable: Any) -> Any:
# https://joseph-long.com/writing/colorbars/
last_axes = plt.gca()
ax = mappable.axes
assert ax is not None
Expand Down
49 changes: 49 additions & 0 deletions tests/ert/unit_tests/gui/plottery/test_stddev_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from unittest.mock import Mock

import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.figure import Figure

from ert.gui.plottery import PlotConfig, PlotContext
from ert.gui.plottery.plots.std_dev import StdDevPlot
from ert.gui.tools.plot.plot_api import EnsembleObject


@pytest.fixture()
def plot_context(request):
context = Mock(spec=PlotContext)
context.ensembles.return_value = [
EnsembleObject("ensemble_1", "id", False, "experiment_1")
]
context.history_data = None
context.layer = 0
context.plotConfig.return_value = PlotConfig(title="StdDev Plot")
return context


def test_stddev_plot_shows_boxplot(plot_context: PlotContext):
rng = np.random.default_rng()
figure = Figure()
std_dev_data = rng.random((5, 5))
StdDevPlot().plot(
figure,
plot_context,
{},
{},
{"ensemble_1": std_dev_data},
)
ax = figure.axes
assert ax[0].get_title() == "experiment_1 : ensemble_1 layer=0"
assert ax[1].get_ylabel() == "Standard Deviation"
annotation = [
child for child in ax[1].get_children() if isinstance(child, plt.Annotation)
]
assert len(annotation) == 1
min_value = np.min(std_dev_data)
mean_value = np.mean(std_dev_data)
max_value = np.max(std_dev_data)
assert (
annotation[0].get_text()
== f"Min: {min_value:.2f}\nMean: {mean_value:.2f}\nMax: {max_value:.2f}"
)

0 comments on commit 382b0ef

Please sign in to comment.