Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training reorder parameter names for plot #55

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions training/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.2...HEAD)

### Add

- Order parameters in loss-plot by pressure/model level if possible [#55](https://github.com/ecmwf/anemoi-core/pull/55)

## [0.3.2 - Multiple Fixes, Checkpoint updates, Stretched-grid/LAM updates](https://github.com/ecmwf/anemoi-training/compare/0.3.1...0.3.2) - 2024-12-19

### Fixed
Expand Down
30 changes: 29 additions & 1 deletion training/src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,11 @@ def sort_and_color_by_parameter_group(
def automatically_determine_group(name: str) -> str:
# first prefix of parameter name is group name
parts = name.split("_")
return parts[0]
if len(parts) == 1:
# if no underscore is present, return full name
return parts[0]
# else remove last part of name
return name[: -len(parts[-1]) - 1]

# group parameters by their determined group name for > 15 parameters
if len(self.parameter_names) <= 15:
Expand Down Expand Up @@ -824,6 +828,24 @@ def automatically_determine_group(name: str) -> str:
legend_patches,
)

def argsort_name_variablelevel(self) -> list[int]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth putting this in to a callback or loss utils? I can see this function being useful for other callbacks in the future

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right. Will do that.

"""Custom sort key to process the strings.

Sort parameter names by alpha part, then by numeric part at last
position (variable level), then by the original string.
"""
data = self.parameter_names

def custom_sort_key(index: int) -> tuple:
s = data[index] # Access the element by index
parts = s.split("_")
alpha_part = s if len(parts) == 1 else s[: -len(parts[-1]) - 1]
numeric_part = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else float("inf")
return (alpha_part, numeric_part, s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to return s here too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If alpha_part and numeric part are the same for two different variables, the entire variable name s is used for sorting.


# Generate argsort indices
return sorted(range(len(data)), key=custom_sort_key)

@rank_zero_only
def _plot(
self,
Expand All @@ -841,6 +863,11 @@ def _plot(
parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
# reorder parameter_names by position
self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]

# Sort the list using the custom key
argsort_indices = self.argsort_name_variablelevel()
self.parameter_names = [self.parameter_names[i] for i in argsort_indices]

if not isinstance(pl_module.loss, BaseWeightedLoss):
LOGGER.warning(
"Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.",
Expand All @@ -859,6 +886,7 @@ def _plot(
loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()

sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group
loss = loss[argsort_indices]
fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches)

self._output_figure(
Expand Down
1 change: 1 addition & 0 deletions training/src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def plot_loss(
"""
# create plot
# more space for legend
# TODO(who?): make figsize more flexible depending on the number of bars
figsize = (8, 3) if legend_patches else (4, 3)
fig, ax = plt.subplots(1, 1, figsize=figsize, layout=LAYOUT)
# histogram plot
Expand Down
Loading