From 85161819b86b914d291516e245762a9d8135459a Mon Sep 17 00:00:00 2001 From: joyce-chen-uni Date: Wed, 28 Aug 2024 21:28:49 -0700 Subject: [PATCH] Align logged attributes for errors and run metadata in kill_loss_spike_callback.py (#1494) --- .../callbacks/kill_loss_spike_callback.py | 35 ++++++++++++++----- llmfoundry/utils/exceptions.py | 8 +++-- tests/utils/test_exceptions.py | 2 ++ 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index b0a92c85e5..7d2a493fa5 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -109,11 +109,11 @@ def _detect_high_losses(self, current_step: int) -> bool: return is_high_loss - def _log_metadata(self, logger: Logger, key: str, message: str) -> None: + def _log_metadata(self, logger: Logger, key: str, value: dict) -> None: for destination in logger.destinations: if isinstance(destination, MosaicMLLogger): destination.log_metadata({ - key: message, + key: value, 'loss_window': list(self.loss_window), }) @@ -122,22 +122,39 @@ def _handle_loss_spike( logger: Logger, running_loss_avg: float, ) -> None: - message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.' - self._log_metadata(logger, 'loss_spike', message) - if not self.log_only: + if self.log_only: + self._log_metadata( + logger, + 'loss_spike', + { + 'outlier_multiplier': self.outlier_multiplier, + 'running_loss_avg': running_loss_avg, + 'outlier_counter': self.outlier_counter, + }, + ) + else: raise LossSpikeError( outlier_multiplier=self.outlier_multiplier, - running_loss_avg=round(running_loss_avg), + running_loss_avg=running_loss_avg, outlier_counter=self.outlier_counter, + loss_window=list(self.loss_window), ) def _handle_high_losses(self, logger: Logger) -> None: - message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.' - self._log_metadata(logger, 'high_loss', message) - if not self.log_only: + if self.log_only: + self._log_metadata( + logger, + 'high_loss', + { + 'loss_cap': self.loss_cap, + 'window_size': self.window_size, + }, + ) + else: raise HighLossError( loss_cap=self.loss_cap, window_size=self.window_size, + loss_window=list(self.loss_window), ) def _set_window_size(self, state: State) -> None: diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 73951ef19e..c6cb6401b0 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -395,11 +395,12 @@ class LossSpikeError(UserError): def __init__( self, outlier_multiplier: float, - running_loss_avg: int, + running_loss_avg: float, outlier_counter: int, + loss_window: list[float], ) -> None: message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \ - the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \ + the running average loss (approx. {round(running_loss_avg, 1)}) over {outlier_counter} consecutive training steps. \ Please try submitting the run again with a lower learning rate.' super().__init__( @@ -407,6 +408,7 @@ def __init__( outlier_multiplier=outlier_multiplier, running_loss_avg=running_loss_avg, outlier_counter=outlier_counter, + loss_window=loss_window, ) @@ -417,6 +419,7 @@ def __init__( self, loss_cap: float, window_size: int, + loss_window: list[float], ) -> None: message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \ for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.' @@ -425,4 +428,5 @@ def __init__( message, loss_cap=loss_cap, window_size=window_size, + loss_window=loss_window, ) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 75c50511dd..8bfc7287ab 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -37,6 +37,8 @@ def get_default_value(arg_type: Optional[type] = None): return 1 elif arg_type == float: return 1.0 + elif arg_type == list[float]: + return [1.0] elif arg_type == set[str]: return {'set'} elif arg_type == list[str]: