diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 90bc66b6f4..4917cf7ef4 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -64,7 +64,7 @@ def __init__( self.loss_window = deque(maxlen=self.window_size) self.loss_cap = MAX_LOSS_CAP - def detect_loss_spike(self, train_loss: float, running_loss_avg: float): + def _detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier if train_loss >= running_loss_avg * self.outlier_multiplier: self.outlier_counter += 1 @@ -82,7 +82,7 @@ def detect_loss_spike(self, train_loss: float, running_loss_avg: float): self.outlier_counter = 0 return False - def detect_high_losses(self, current_step: int): + def _detect_high_losses(self, current_step: int): # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period if (current_step >= self.window_size * 2) and ( sum(1 for loss in self.loss_window if loss > self.loss_cap) >= @@ -93,18 +93,20 @@ def detect_high_losses(self, current_step: int): ) return True return False - - def handle_loss_spike( - self, logger: Logger, running_loss_avg: float - ) -> None: + + def _log_metadata(self, logger: Logger, key: str, message: str) -> None: for destination in logger.destinations: if isinstance(destination, MosaicMLLogger): destination.log_metadata({ - 'loss_spike': - f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.', - 'loss_window': - list(self.loss_window), + key: message, + 'loss_window': list(self.loss_window), }) + + def _handle_loss_spike( + self, logger: Logger, running_loss_avg: float + ) -> None: + message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.' + self._log_metadata(logger, 'loss_spike', message) if not self.log_only: raise LossSpikeError( outlier_multiplier=self.outlier_multiplier, @@ -112,15 +114,9 @@ def handle_loss_spike( outlier_counter=self.outlier_counter, ) - def handle_high_losses(self, logger: Logger) -> None: - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({ - 'high_loss': - f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.', - 'loss_window': - list(self.loss_window), - }) + def _handle_high_losses(self, logger: Logger) -> None: + message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.' + self._log_metadata(logger, 'high_loss', message) if not self.log_only: raise HighLossError( loss_cap=self.loss_cap,