
Commit

Docstring, deque, log_only mode
joyce-chen-uni committed Aug 21, 2024
1 parent 45eaa0e commit e160e21
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -7,22 +7,48 @@
import torch
import logging
import numpy as np
from collections import deque
from composer.core import Callback, State
from composer.loggers import Logger, MosaicMLLogger
# from llmfoundry.utils.exceptions import LossSpikeError
from llmfoundry.utils.exceptions import LossSpikeError
log = logging.getLogger(__name__)

__all__ = ['KillLossSpike']

class KillLossSpike(Callback):
"""
A callback for detecting and handling loss spikes or persistently high training losses during model training.
Monitors the training loss at the end of each batch and maintains a rolling window of recent losses.
If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either
log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry.
Parameters:
log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a
LossSpikeError will be raised to stop training upon detecting a loss spike or persistently
high loss.
patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be
persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error).
outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an
outlier if it is outlier_multiplier times greater than the mean of losses in
the current window. Default is 2.
window_size (int): The size of the rolling window used to track recent losses. Default is 100.
loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
it is considered a diverging or unstable run. Default is 10.
Raises:
LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
raised to stop the run with an error message.
"""

def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
def __init__(self, log_only:bool, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
self.log_only = log_only
self.patience = patience
self.outlier_multiplier = outlier_multiplier
self.window_size = window_size
self.loss_cap = loss_cap
self.outlier_counter = 0
self.loss_window = []
self.loss_window = deque(maxlen=self.window_size)

def batch_end(self, state: State, logger: Logger) -> None:

@@ -41,12 +67,12 @@ def batch_end(self, state: State, logger: Logger) -> None:
log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
if self.outlier_counter > self.patience:
log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
# NOTE: Adding this info to the TRAIN_UPDATED event is temporary to 1) collect data on spiky runs and 2) give users information about their run.
# This will be replaced with the hard error LossSpikeError.
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
# raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
if self.log_only:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
else:
raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)

# Previous step loss was an outlier, current step loss is not. Reset outlier counter.
elif self.outlier_counter > 0:
@@ -56,14 +82,11 @@ def batch_end(self, state: State, logger: Logger) -> None:
# Half of the running losses are greater than our "high loss" threshold, after the first window
elif (state.timestamp.batch >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
log.info(f'High losses >{self.loss_cap} detected.')
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
# raise LossSpikeError()

else:
log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
if self.log_only:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
else:
raise LossSpikeError()

self.loss_window.append(train_loss)
if len(self.loss_window) > self.window_size:
self.loss_window.pop(0)
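
Note on the storage change above: switching the loss window from a plain list to collections.deque(maxlen=window_size) makes the window self-trimming on append, which is why the manual pop(0) bookkeeping at the end of the last hunk is dropped. A minimal, self-contained sketch of the equivalence (the window size of 3 is only for illustration):

from collections import deque

# Old pattern: bounded list maintained by hand.
window_as_list = []
for loss in [2.1, 2.3, 9.7, 2.0]:
    window_as_list.append(loss)
    if len(window_as_list) > 3:
        window_as_list.pop(0)  # O(n): shifts every remaining element

# New pattern: deque evicts the oldest entry automatically.
window_as_deque = deque(maxlen=3)
for loss in [2.1, 2.3, 9.7, 2.0]:
    window_as_deque.append(loss)  # O(1): oldest entry is dropped once full

print(list(window_as_deque) == window_as_list)  # True -> [2.3, 9.7, 2.0]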

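For reference, a minimal usage sketch of the callback after this change. It assumes llmfoundry at this commit is installed and that the callback is attached through Composer's standard callbacks argument; the keyword values below simply restate the defaults documented in the docstring, and the Trainer wiring is left as a comment because the model and dataloader are out of scope here.

from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

# log_only=True surfaces loss-spike / high-loss metadata through the MosaicML
# logger without interrupting training; log_only=False raises LossSpikeError.
loss_spike_cb = KillLossSpike(
    log_only=True,
    patience=4,            # tolerate 4 consecutive outlier losses before acting
    outlier_multiplier=2,  # outlier = loss above 2x the window mean
    window_size=100,       # rolling window of recent training losses
    loss_cap=10,           # threshold for "persistently high" losses
)

# Attach it like any other Composer callback, e.g.:
# trainer = Trainer(model=..., train_dataloader=..., callbacks=[loss_spike_cb])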