From e0bd75dd699379410aea052bb01341f3f36af2c4 Mon Sep 17 00:00:00 2001
From: sahahner
Date: Fri, 22 Nov 2024 16:52:46 +0000
Subject: [PATCH 1/5] Reorder parameter names for plot

---
 .../training/diagnostics/callbacks/plot.py | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py
index b13d8727..aa63cc2a 100644
--- a/src/anemoi/training/diagnostics/callbacks/plot.py
+++ b/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -788,6 +788,24 @@ def automatically_determine_group(name: str) -> str:
             legend_patches,
         )
 
+    def argsort_name_pressurelevel(self) -> list[int]:
+        """Custom sort key to process the strings.
+
+        Sort parameter names by alpha part, then by numeric part at second
+        position (pressure level), then by the original string.
+        """
+        data = self.parameter_names
+
+        def custom_sort_key(index: int) -> tuple:
+            s = data[index]  # Access the element by index
+            parts = s.split("_")
+            alpha_part = parts[0]
+            numeric_part = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else float("inf")
+            return (alpha_part, numeric_part, s)
+
+        # Generate argsort indices
+        return sorted(range(len(data)), key=custom_sort_key)
+
     @rank_zero_only
     def _plot(
         self,
@@ -805,6 +823,11 @@ def _plot(
         parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
         # reorder parameter_names by position
         self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]
+
+        # Sort the list using the custom key
+        argsort_indices = self.argsort_name_pressurelevel()
+        self.parameter_names = [self.parameter_names[i] for i in argsort_indices]
+
         if not isinstance(pl_module.loss, BaseWeightedLoss):
             logging.warning(
                 "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.",
@@ -823,6 +846,7 @@ def _plot(
         loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()
 
         sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group
+        loss = loss[argsort_indices]
         fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches)
 
         self._output_figure(

From 0f10431b9ecfdf06cb64a23d40507e2d8e5cbb50 Mon Sep 17 00:00:00 2001
From: sahahner
Date: Mon, 30 Dec 2024 16:31:11 +0000
Subject: [PATCH 2/5] Use last part of variable name as possible variable
 level for sorting

---
 .../training/diagnostics/callbacks/plot.py | 18 +++++++++++-------
 .../src/anemoi/training/diagnostics/plots.py | 1 +
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py
index 32a71f2b..fee21faa 100644
--- a/training/src/anemoi/training/diagnostics/callbacks/plot.py
+++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -739,7 +739,11 @@ def sort_and_color_by_parameter_group(
         def automatically_determine_group(name: str) -> str:
             # first prefix of parameter name is group name
             parts = name.split("_")
-            return parts[0]
+            if len(parts) == 1:
+                # if no underscore is present, return full name
+                return parts[0]
+            # else remove last part of name
+            return name[: -len(parts[-1]) - 1]
 
         # group parameters by their determined group name for > 15 parameters
         if len(self.parameter_names) <= 15:
@@ -824,19 +828,19 @@ def automatically_determine_group(name: str) -> str:
             legend_patches,
         )
 
-    def argsort_name_pressurelevel(self) -> list[int]:
+    def argsort_name_variablelevel(self) -> list[int]:
         """Custom sort key to process the strings.
 
-        Sort parameter names by alpha part, then by numeric part at second
-        position (pressure level), then by the original string.
+        Sort parameter names by alpha part, then by numeric part at last
+        position (variable level), then by the original string.
         """
         data = self.parameter_names
 
         def custom_sort_key(index: int) -> tuple:
             s = data[index]  # Access the element by index
             parts = s.split("_")
-            alpha_part = parts[0]
-            numeric_part = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else float("inf")
+            alpha_part = s if len(parts) == 1 else s[: -len(parts[-1]) - 1]
+            numeric_part = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else float("inf")
             return (alpha_part, numeric_part, s)
 
         # Generate argsort indices
diff --git a/training/src/anemoi/training/diagnostics/plots.py b/training/src/anemoi/training/diagnostics/plots.py
index 0ce55cdd..8556ca82 100644
--- a/training/src/anemoi/training/diagnostics/plots.py
+++ b/training/src/anemoi/training/diagnostics/plots.py
@@ -116,6 +116,7 @@ def plot_loss(
     """
     # create plot
     # more space for legend
+    # TODO(who?): make figsize more flexible depending on the number of bars
    figsize = (8, 3) if legend_patches else (4, 3)
     fig, ax = plt.subplots(1, 1, figsize=figsize, layout=LAYOUT)
     # histogram plot

From 636b5f009bf40eea1f1eb5ec787fe3e8b0c12e4a Mon Sep 17 00:00:00 2001
From: sahahner
Date: Mon, 30 Dec 2024 16:38:47 +0000
Subject: [PATCH 3/5] Update changelog

---
 training/CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/training/CHANGELOG.md b/training/CHANGELOG.md
index 19160c59..7a505e40 100644
--- a/training/CHANGELOG.md
+++ b/training/CHANGELOG.md
@@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you!
 
 ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.2...HEAD)
 
+### Add
+
+- Order parameters in loss-plot by pressure/model level if possible [#55](https://github.com/ecmwf/anemoi-core/pull/55)
+
 ## [0.3.2 - Multiple Fixes, Checkpoint updates, Stretched-grid/LAM updates](https://github.com/ecmwf/anemoi-training/compare/0.3.1...0.3.2) - 2024-12-19
 
 ### Fixed

From 940f6ebf284cff649eb48726fe97773e23fa55a2 Mon Sep 17 00:00:00 2001
From: sahahner
Date: Tue, 31 Dec 2024 11:15:12 +0000
Subject: [PATCH 4/5] Move argsort of variable names into plots.py to make it
 available for other plot functions
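
This makes the sorting helper importable from anemoi.training.diagnostics.plots.
As a rough illustration (not part of this diff; the parameter names below are
invented), another plot callback could reuse it like this:

    from anemoi.training.diagnostics.plots import argsort_name_variablelevel

    parameter_names = ["t_850", "2t", "q_700", "t_500", "10u"]
    order = argsort_name_variablelevel(parameter_names)
    sorted_names = [parameter_names[i] for i in order]
    # -> ['10u', '2t', 'q_700', 't_500', 't_850']

PATCH 1/5 already reorders the loss values with the same indices
(loss = loss[argsort_indices]), so bars and labels stay aligned.
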
---
 .../training/diagnostics/callbacks/plot.py | 21 ++------------
 .../src/anemoi/training/diagnostics/plots.py | 28 +++++++++++++++++++
 2 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py
index fee21faa..4dace095 100644
--- a/training/src/anemoi/training/diagnostics/callbacks/plot.py
+++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -33,6 +33,7 @@
 from pytorch_lightning.utilities import rank_zero_only
 
 from anemoi.models.layers.mapper import GraphEdgeMixin
+from anemoi.training.diagnostics.plots import argsort_name_variablelevel
 from anemoi.training.diagnostics.plots import get_scatter_frame
 from anemoi.training.diagnostics.plots import init_plot_settings
 from anemoi.training.diagnostics.plots import plot_graph_edge_features
@@ -828,24 +829,6 @@ def automatically_determine_group(name: str) -> str:
             legend_patches,
         )
 
-    def argsort_name_variablelevel(self) -> list[int]:
-        """Custom sort key to process the strings.
-
-        Sort parameter names by alpha part, then by numeric part at last
-        position (variable level), then by the original string.
-        """
-        data = self.parameter_names
-
-        def custom_sort_key(index: int) -> tuple:
-            s = data[index]  # Access the element by index
-            parts = s.split("_")
-            alpha_part = s if len(parts) == 1 else s[: -len(parts[-1]) - 1]
-            numeric_part = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else float("inf")
-            return (alpha_part, numeric_part, s)
-
-        # Generate argsort indices
-        return sorted(range(len(data)), key=custom_sort_key)
-
     @rank_zero_only
     def _plot(
         self,
@@ -865,7 +848,7 @@ def _plot(
         self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]
 
         # Sort the list using the custom key
-        argsort_indices = self.argsort_name_variablelevel()
+        argsort_indices = argsort_name_variablelevel(self.parameter_names)
         self.parameter_names = [self.parameter_names[i] for i in argsort_indices]
 
         if not isinstance(pl_module.loss, BaseWeightedLoss):
diff --git a/training/src/anemoi/training/diagnostics/plots.py b/training/src/anemoi/training/diagnostics/plots.py
index 8556ca82..d97cdf3f 100644
--- a/training/src/anemoi/training/diagnostics/plots.py
+++ b/training/src/anemoi/training/diagnostics/plots.py
@@ -58,6 +58,34 @@ def equirectangular_projection(latlons: np.array) -> np.array:
     return pc_lat, pc_lon
 
 
+def argsort_name_variablelevel(data: list[str]) -> list[int]:
+    """Custom sort key to process the strings.
+
+    Sort parameter names by alpha part, then by numeric part at last
+    position (variable level) if available, then by the original string.
+
+    Parameters
+    ----------
+    data : list[str]
+        List of strings to sort.
+
+    Returns
+    -------
+    list[int]
+        Sorted indices of the input list.
+ """ + + def custom_sort_key(index: int) -> tuple: + s = data[index] # Access the element by index + parts = s.split("_") + alpha_part = s if len(parts) == 1 else s[: -len(parts[-1]) - 1] + numeric_part = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else float("inf") + return (alpha_part, numeric_part, s) + + # Generate argsort indices + return sorted(range(len(data)), key=custom_sort_key) + + def init_plot_settings() -> None: """Initialize matplotlib plot settings.""" small_font_size = 8 From b03331e927e5c1a6c2b52d5bb8c5ea76959dee04 Mon Sep 17 00:00:00 2001 From: sahahner Date: Tue, 31 Dec 2024 11:16:49 +0000 Subject: [PATCH 5/5] change function name to argsort_variablename_variablelevel --- training/src/anemoi/training/diagnostics/callbacks/plot.py | 4 ++-- training/src/anemoi/training/diagnostics/plots.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py index 4dace095..6140d8d3 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py @@ -33,7 +33,7 @@ from pytorch_lightning.utilities import rank_zero_only from anemoi.models.layers.mapper import GraphEdgeMixin -from anemoi.training.diagnostics.plots import argsort_name_variablelevel +from anemoi.training.diagnostics.plots import argsort_variablename_variablelevel from anemoi.training.diagnostics.plots import get_scatter_frame from anemoi.training.diagnostics.plots import init_plot_settings from anemoi.training.diagnostics.plots import plot_graph_edge_features @@ -848,7 +848,7 @@ def _plot( self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] # Sort the list using the custom key - argsort_indices = argsort_name_variablelevel(self.parameter_names) + argsort_indices = argsort_variablename_variablelevel(self.parameter_names) self.parameter_names = [self.parameter_names[i] for i in argsort_indices] if not isinstance(pl_module.loss, BaseWeightedLoss): diff --git a/training/src/anemoi/training/diagnostics/plots.py b/training/src/anemoi/training/diagnostics/plots.py index d97cdf3f..50e17c10 100644 --- a/training/src/anemoi/training/diagnostics/plots.py +++ b/training/src/anemoi/training/diagnostics/plots.py @@ -58,7 +58,7 @@ def equirectangular_projection(latlons: np.array) -> np.array: return pc_lat, pc_lon -def argsort_name_variablelevel(data: list[str]) -> list[int]: +def argsort_variablename_variablelevel(data: list[str]) -> list[int]: """Custom sort key to process the strings. Sort parameter names by alpha part, then by numeric part at last
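
Taken together, the series leaves the helper in plots.py under its final name,
argsort_variablename_variablelevel. A minimal usage sketch (the parameter names
and loss values are illustrative only, and an installed anemoi.training is
assumed):

    import numpy as np

    from anemoi.training.diagnostics.plots import argsort_variablename_variablelevel

    names = ["t_850", "z_500", "2t", "t_500", "z_850", "10u"]
    loss = np.array([0.12, 0.30, 0.05, 0.20, 0.41, 0.08])

    order = argsort_variablename_variablelevel(names)
    names = [names[i] for i in order]  # ['10u', '2t', 't_500', 't_850', 'z_500', 'z_850']
    loss = loss[order]  # loss values follow their parameter names

Levels sort numerically within each variable name; names without a numeric
suffix fall back to plain string order.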