diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 654e6ffc05..0038c65439 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -7,22 +7,48 @@ import torch import logging import numpy as np +from collections import deque from composer.core import Callback, State from composer.loggers import Logger, MosaicMLLogger -# from llmfoundry.utils.exceptions import LossSpikeError +from llmfoundry.utils.exceptions import LossSpikeError log = logging.getLogger(__name__) __all__ = ['KillLossSpike'] class KillLossSpike(Callback): + """ + A callback for detecting and handling loss spikes or persistently high training losses during model training. + + Monitors the training loss at the end of each batch and maintains a rolling window of recent losses. + If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either + log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry. + + Parameters: + log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a + LossSpikeError will be raised to stop training upon detecting a loss spike or persistently + high loss. + patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be + persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error). + outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an + outlier if it is outlier_multiplier times greater than the mean of losses in + the current window. Default is 2. + window_size (int): The size of the rolling window used to track recent losses. Default is 100. + loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value, + it is considered a diverging or unstable run. Default is 10. + + Raises: + LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is + raised to stop the run with an error message. + """ - def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10): + def __init__(self, log_only:bool, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10): + self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier self.window_size = window_size self.loss_cap = loss_cap self.outlier_counter = 0 - self.loss_window = [] + self.loss_window = deque(maxlen=self.window_size) def batch_end(self, state: State, logger: Logger) -> None: @@ -41,12 +67,12 @@ def batch_end(self, state: State, logger: Logger) -> None: log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}') if self.outlier_counter > self.patience: log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.') - # NOTE: Adding this info the TRAIN_UPDATED event is temporary to 1) collect data on spiky runs and 2) give users information about their run. - # This will be replaced with the hard error LossSpikeError. - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'}) - # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter) + if self.log_only: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'}) + else: + raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter) # Previous step loss was an outlier, current step loss is not. Reset outlier counter. elif self.outlier_counter > 0: @@ -56,14 +82,11 @@ def batch_end(self, state: State, logger: Logger) -> None: # Half of the running losses are greater than our "high loss" threshold, after the first window elif (state.timestamp.batch >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2): log.info(f'High losses >{self.loss_cap} detected.') - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'}) - # raise LossSpikeError() - - else: - log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...') + if self.log_only: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'}) + else: + raise LossSpikeError() self.loss_window.append(train_loss) - if len(self.loss_window) > self.window_size: - self.loss_window.pop(0)