diff --git a/training/CHANGELOG.md b/training/CHANGELOG.md index 19160c59..7a505e40 100644 --- a/training/CHANGELOG.md +++ b/training/CHANGELOG.md @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.2...HEAD) +### Added + +- Order parameters in loss-plot by pressure/model level if possible [#55](https://github.com/ecmwf/anemoi-core/pull/55) + ## [0.3.2 - Multiple Fixes, Checkpoint updates, Stretched-grid/LAM updates](https://github.com/ecmwf/anemoi-training/compare/0.3.1...0.3.2) - 2024-12-19 ### Fixed diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py index aebba10e..6140d8d3 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities import rank_zero_only from anemoi.models.layers.mapper import GraphEdgeMixin +from anemoi.training.diagnostics.plots import argsort_variablename_variablelevel from anemoi.training.diagnostics.plots import get_scatter_frame from anemoi.training.diagnostics.plots import init_plot_settings from anemoi.training.diagnostics.plots import plot_graph_edge_features @@ -739,7 +740,11 @@ def sort_and_color_by_parameter_group( def automatically_determine_group(name: str) -> str: # first prefix of parameter name is group name parts = name.split("_") - return parts[0] + if len(parts) == 1: + # if no underscore is present, return full name + return parts[0] + # else remove last part of name + return name[: -len(parts[-1]) - 1] # group parameters by their determined group name for > 15 parameters if len(self.parameter_names) <= 15: @@ -841,6 +846,11 @@ def _plot( parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) # reorder parameter_names by position self.parameter_names = [parameter_names[i] for i in 
np.argsort(parameter_positions)] + + # Sort the list using the custom key + argsort_indices = argsort_variablename_variablelevel(self.parameter_names) + self.parameter_names = [self.parameter_names[i] for i in argsort_indices] + if not isinstance(pl_module.loss, BaseWeightedLoss): LOGGER.warning( "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.", @@ -859,6 +869,7 @@ def _plot( loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group + loss = loss[argsort_indices] fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches) self._output_figure( diff --git a/training/src/anemoi/training/diagnostics/plots.py b/training/src/anemoi/training/diagnostics/plots.py index 0ce55cdd..50e17c10 100644 --- a/training/src/anemoi/training/diagnostics/plots.py +++ b/training/src/anemoi/training/diagnostics/plots.py @@ -58,6 +58,34 @@ def equirectangular_projection(latlons: np.array) -> np.array: return pc_lat, pc_lon +def argsort_variablename_variablelevel(data: list[str]) -> list[int]: + """Return the indices that sort parameter names by name and level. + + Sort parameter names by alpha part, then by numeric part at last + position (variable level) if available, then by the original string. + + Parameters + ---------- + data : list[str] + List of strings to sort. + + Returns + ------- + list[int] + Indices that put the input list in sorted order. 
+ """ + + def custom_sort_key(index: int) -> tuple: + s = data[index] # Access the element by index + parts = s.split("_") + alpha_part = s if len(parts) == 1 else s[: -len(parts[-1]) - 1] + numeric_part = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else float("inf") + return (alpha_part, numeric_part, s) + + # Generate argsort indices + return sorted(range(len(data)), key=custom_sort_key) + + def init_plot_settings() -> None: """Initialize matplotlib plot settings.""" small_font_size = 8 @@ -116,6 +144,7 @@ def plot_loss( """ # create plot # more space for legend + # TODO(who?): make figsize more flexible depending on the number of bars figsize = (8, 3) if legend_patches else (4, 3) fig, ax = plt.subplots(1, 1, figsize=figsize, layout=LAYOUT) # histogram plot