Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training reorder parameter names for plot #55

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions training/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.2...HEAD)

### Added

- Order parameters in loss-plot by pressure/model level if possible [#55](https://github.com/ecmwf/anemoi-core/pull/55)

## [0.3.2 - Multiple Fixes, Checkpoint updates, Stretched-grid/LAM updates](https://github.com/ecmwf/anemoi-training/compare/0.3.1...0.3.2) - 2024-12-19

### Fixed
Expand Down
13 changes: 12 additions & 1 deletion training/src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytorch_lightning.utilities import rank_zero_only

from anemoi.models.layers.mapper import GraphEdgeMixin
from anemoi.training.diagnostics.plots import argsort_variablename_variablelevel
from anemoi.training.diagnostics.plots import get_scatter_frame
from anemoi.training.diagnostics.plots import init_plot_settings
from anemoi.training.diagnostics.plots import plot_graph_edge_features
Expand Down Expand Up @@ -739,7 +740,11 @@ def sort_and_color_by_parameter_group(
def automatically_determine_group(name: str) -> str:
    """Derive a parameter's group name from its name.

    The group is the parameter name with its trailing ``_<suffix>``
    stripped (e.g. ``t_850`` -> ``t``); a name without an underscore
    is its own group.
    """
    stem, sep, _suffix = name.rpartition("_")
    # rpartition yields an empty separator when no underscore exists
    return stem if sep else name

# group parameters by their determined group name for > 15 parameters
if len(self.parameter_names) <= 15:
Expand Down Expand Up @@ -841,6 +846,11 @@ def _plot(
parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
# reorder parameter_names by position
self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]

# Sort the list using the custom key
argsort_indices = argsort_variablename_variablelevel(self.parameter_names)
self.parameter_names = [self.parameter_names[i] for i in argsort_indices]

if not isinstance(pl_module.loss, BaseWeightedLoss):
LOGGER.warning(
"Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.",
Expand All @@ -859,6 +869,7 @@ def _plot(
loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()

sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group
loss = loss[argsort_indices]
fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches)

self._output_figure(
Expand Down
29 changes: 29 additions & 0 deletions training/src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ def equirectangular_projection(latlons: np.array) -> np.array:
return pc_lat, pc_lon


def argsort_variablename_variablelevel(data: list[str]) -> list[int]:
    """Return indices that sort parameter names by name, then by level.

    Each name is keyed by its alphabetic stem (everything before the
    last underscore), then by its numeric level (the last underscore
    part, when it is all digits; non-numeric suffixes sort after all
    numeric levels), and finally by the full original string.

    Parameters
    ----------
    data : list[str]
        List of strings to sort.

    Returns
    -------
    list[int]
        Sorted indices of the input list.
    """

    def sort_key(idx: int) -> tuple:
        name = data[idx]
        stem, sep, suffix = name.rpartition("_")
        # No underscore: the whole name is the stem, with no level
        alpha = stem if sep else name
        level = int(suffix) if sep and suffix.isdigit() else float("inf")
        return (alpha, level, name)

    return sorted(range(len(data)), key=sort_key)


def init_plot_settings() -> None:
"""Initialize matplotlib plot settings."""
small_font_size = 8
Expand Down Expand Up @@ -116,6 +144,7 @@ def plot_loss(
"""
# create plot
# more space for legend
# TODO(who?): make figsize more flexible depending on the number of bars
figsize = (8, 3) if legend_patches else (4, 3)
fig, ax = plt.subplots(1, 1, figsize=figsize, layout=LAYOUT)
# histogram plot
Expand Down
Loading