-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Training reorder parameter names for plot #55
base: develop
Are you sure you want to change the base?
Changes from 4 commits
e0bd75d
5045e7f
0f10431
636b5f0
940f6eb
b03331e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -739,7 +739,11 @@ def sort_and_color_by_parameter_group( | |
def automatically_determine_group(name: str) -> str: | ||
# first prefix of parameter name is group name | ||
parts = name.split("_") | ||
return parts[0] | ||
if len(parts) == 1: | ||
# if no underscore is present, return full name | ||
return parts[0] | ||
# else remove last part of name | ||
return name[: -len(parts[-1]) - 1] | ||
|
||
# group parameters by their determined group name for > 15 parameters | ||
if len(self.parameter_names) <= 15: | ||
|
@@ -824,6 +828,24 @@ def automatically_determine_group(name: str) -> str: | |
legend_patches, | ||
) | ||
|
||
def argsort_name_variablelevel(self) -> list[int]: | ||
"""Custom sort key to process the strings. | ||
|
||
Sort parameter names by alpha part, then by numeric part at last | ||
position (variable level), then by the original string. | ||
""" | ||
data = self.parameter_names | ||
|
||
def custom_sort_key(index: int) -> tuple: | ||
s = data[index] # Access the element by index | ||
parts = s.split("_") | ||
alpha_part = s if len(parts) == 1 else s[: -len(parts[-1]) - 1] | ||
numeric_part = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else float("inf") | ||
return (alpha_part, numeric_part, s) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need to return s here too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||
|
||
# Generate argsort indices | ||
return sorted(range(len(data)), key=custom_sort_key) | ||
|
||
@rank_zero_only | ||
def _plot( | ||
self, | ||
|
@@ -841,6 +863,11 @@ def _plot( | |
parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) | ||
# reorder parameter_names by position | ||
self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] | ||
|
||
# Sort the list using the custom key | ||
argsort_indices = self.argsort_name_variablelevel() | ||
self.parameter_names = [self.parameter_names[i] for i in argsort_indices] | ||
|
||
if not isinstance(pl_module.loss, BaseWeightedLoss): | ||
LOGGER.warning( | ||
"Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.", | ||
|
@@ -859,6 +886,7 @@ def _plot( | |
loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() | ||
|
||
sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group | ||
loss = loss[argsort_indices] | ||
fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches) | ||
|
||
self._output_figure( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be worth putting this in to a callback or loss utils? I can see this function being useful for other callbacks in the future
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right. Will do that.