diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index 660f282267..fe84efa316 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -22,6 +22,7 @@ from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging from llmfoundry.callbacks.fdiff_callback import FDiffMetrics from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer +from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import ( MegaBlocksMoE_TokPerExpert, ) @@ -60,6 +61,7 @@ callbacks.register('loss_perp_v_len', func=LossPerpVsContextLengthLogger) callbacks.register('env_logging', func=EnvironmentLoggingCallback) callbacks.register('nan_monitor', func=NaNMonitor) +callbacks.register('kill_loss_spike', func=KillLossSpike) callbacks_with_config.register('async_eval', func=AsyncEval) callbacks_with_config.register('curriculum_learning', func=CurriculumLearning) @@ -76,4 +78,5 @@ 'AsyncEval', 'CurriculumLearning', 'LossPerpVsContextLengthLogger', + 'KillLossSpike', ] diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py new file mode 100644 index 0000000000..b0a92c85e5 --- /dev/null +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -0,0 +1,197 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Track training runs for loss spikes or persistently high training loss.""" +from __future__ import annotations + +import logging +from collections import deque + +import numpy as np +import torch +from composer.core import Callback, State, TimeUnit +from composer.loggers import Logger, MosaicMLLogger +from composer.utils import dist + +from llmfoundry.utils.exceptions import HighLossError, LossSpikeError +from llmfoundry.utils.warnings import experimental_class + +log = logging.getLogger(__name__) + +__all__ = ['KillLossSpike'] + 
+ +_MIN_WINDOW_SIZE = 100 +_MAX_LOSS_CAP = 10 +_WINDOW_FRACTION = 0.05 + + +@experimental_class('KillLossSpike') +class KillLossSpike(Callback): + """Detects and handles loss spikes or high losses during training. + + Monitors the training loss at the end of each batch and maintains a rolling window of recent losses. + If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either + log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry. + + Args: + log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a + LossSpikeError will be raised to stop training upon detecting a loss spike or persistently + high loss. Default is True. + patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be + persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error). + outlier_multiplier (float): The multiplier used to determine if a loss is an outlier. A loss is considered an + outlier if it is outlier_multiplier times greater than the mean of losses in + the current window. Default is 2. + window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps. + loss_cap (float): The maximum allowable loss. If the training loss consistently exceeds this value, + it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses, with a maximum of 10. + + Raises: + LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is + raised to stop the run with an error message. 
+ """ + + def __init__( + self, + log_only: bool = True, + patience: int = 4, + outlier_multiplier: float = 2, + window_size: int = _MIN_WINDOW_SIZE, + loss_cap: float = _MAX_LOSS_CAP, + ): + self._enabled = (dist.get_global_rank() == 0) + self.log_only = log_only + self.patience = patience + self.outlier_multiplier = outlier_multiplier + self.outlier_counter = 0 + self.window_size = window_size + self.loss_window = deque(maxlen=window_size) + self.loss_cap = loss_cap + self.window_size_set = (window_size != _MIN_WINDOW_SIZE) + self.loss_cap_set = (loss_cap != _MAX_LOSS_CAP) + + def _detect_loss_spike( + self, + train_loss: float, + running_loss_avg: float, + ) -> bool: + # Train loss is an outlier + if train_loss >= running_loss_avg * self.outlier_multiplier: + self.outlier_counter += 1 + log.info( + f'Potential loss spike detected. Iteration: {self.outlier_counter}', + ) + if self.outlier_counter > self.patience: + log.info( + f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.', + ) + return True + # Previous step loss was an outlier, current step loss is not. Reset outlier counter. + elif self.outlier_counter > 0: + log.info('Not a persistent loss spike. 
Resetting outlier counter.') + self.outlier_counter = 0 + return False + + def _detect_high_losses(self, current_step: int) -> bool: + if current_step < self.window_size * 2: + return False + + # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period + high_loss_count = sum( + 1 for loss in self.loss_window if loss > self.loss_cap + ) + is_high_loss = (high_loss_count >= self.window_size / 2) + + if is_high_loss: + log.info( + f'High losses detected: {high_loss_count}/{self.window_size} losses above {self.loss_cap}.', + ) + + return is_high_loss + + def _log_metadata(self, logger: Logger, key: str, message: str) -> None: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({ + key: message, + 'loss_window': list(self.loss_window), + }) + + def _handle_loss_spike( + self, + logger: Logger, + running_loss_avg: float, + ) -> None: + message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.' + self._log_metadata(logger, 'loss_spike', message) + if not self.log_only: + raise LossSpikeError( + outlier_multiplier=self.outlier_multiplier, + running_loss_avg=round(running_loss_avg), + outlier_counter=self.outlier_counter, + ) + + def _handle_high_losses(self, logger: Logger) -> None: + message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.' 
+ self._log_metadata(logger, 'high_loss', message) + if not self.log_only: + raise HighLossError( + loss_cap=self.loss_cap, + window_size=self.window_size, + ) + + def _set_window_size(self, state: State) -> None: + total_training_steps = 0 + current_step = int(state.timestamp.batch) + current_token = int(state.timestamp.token) + + if state.max_duration is not None: + if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None: + total_training_steps = state.dataloader_len * state.max_duration.value + elif state.max_duration.unit == TimeUnit.BATCH: + total_training_steps = state.max_duration.value + elif state.max_duration.unit == TimeUnit.TOKEN: + # This is an approximation of the total batches from the total tokens, assuming the ratio of tokens:batch is constant. + total_training_steps = current_step * ( + state.max_duration.value / current_token + ) + self.window_size = max( + self.window_size, + round(float(total_training_steps * _WINDOW_FRACTION)), + ) + self.loss_window = deque(self.loss_window, maxlen=self.window_size) + log.info(f'Window size set to: {self.window_size}') + self.window_size_set = True + + def batch_end(self, state: State, logger: Logger) -> None: + + if not isinstance(state.loss, torch.Tensor): + raise NotImplementedError('Multiple losses not supported yet') + train_loss = state.loss.item() + + # Only start early stopping once a full window of loss data + if len(self.loss_window) == self.window_size: + + current_step = int(state.timestamp.batch) + + # If window size has not yet been set either by user or during run, set window size to a fraction of the total training duration. Minimum 100 batches. + if not self.window_size_set: + self._set_window_size(state) + # Window size has been expanded -- keep adding losses until we reach the window size. 
+ if self.window_size > _MIN_WINDOW_SIZE: + self.loss_window.append(train_loss) + return + + # If user does not provide a loss cap, set loss cap to the maximum loss from the first loss window. Hard cap at loss=10. + if not self.loss_cap_set and current_step == self.window_size: + self.loss_cap = min(max(self.loss_window), self.loss_cap) + self.loss_cap_set = True + + running_loss_avg = float(np.mean(self.loss_window)) + + if self._detect_loss_spike(train_loss, running_loss_avg): + self._handle_loss_spike(logger, running_loss_avg) + elif self._detect_high_losses(current_step): + self._handle_high_losses(logger) + + self.loss_window.append(train_loss) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 206095f28b..73951ef19e 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -387,3 +387,42 @@ class RunTimeoutError(InternalError): def __init__(self, timeout: int) -> None: message = f'Run timed out after {timeout} seconds.' super().__init__(message, timeout=timeout) + + +class LossSpikeError(UserError): + """Error thrown if a severe loss spike occurs.""" + + def __init__( + self, + outlier_multiplier: float, + running_loss_avg: int, + outlier_counter: int, + ) -> None: + message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \ + the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \ + Please try submitting the run again with a lower learning rate.' + + super().__init__( + message, + outlier_multiplier=outlier_multiplier, + running_loss_avg=running_loss_avg, + outlier_counter=outlier_counter, + ) + + +class HighLossError(UserError): + """Error thrown if training loss plateaus or is unstable at a high level.""" + + def __init__( + self, + loss_cap: float, + window_size: int, + ) -> None: + message = f'Training stopped due to consistently high losses. 
The training loss exceeded the threshold of {loss_cap} \ + for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.' + + super().__init__( + message, + loss_cap=loss_cap, + window_size=window_size, + ) diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py new file mode 100644 index 0000000000..577dbd33cc --- /dev/null +++ b/tests/callbacks/test_kill_loss_spike_callback.py @@ -0,0 +1,105 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import unittest +from collections import deque +from unittest.mock import MagicMock, patch + +from composer.core.time import TimeUnit + +from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike +from llmfoundry.utils.exceptions import LossSpikeError + + +class TestKillLossSpike(unittest.TestCase): + + def __init__(self, *args: str, **kwargs: dict): + super(TestKillLossSpike, self).__init__(*args, **kwargs) + self.callback = KillLossSpike( + log_only=True, + patience=4, + outlier_multiplier=2, + ) + self.callback.window_size = 10 + self.callback.loss_cap = 10 + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_detect_loss_spike_no_spike(self, _): + self.callback.outlier_counter = 0 + train_loss = 4 + running_loss_avg = 2 + result = self.callback._detect_loss_spike(train_loss, running_loss_avg) + self.assertFalse(result) + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_detect_loss_spike_with_spike(self, _): + self.callback.outlier_counter = 4 # Simulating previous spikes + train_loss = 4 + running_loss_avg = 2 + result = self.callback._detect_loss_spike(train_loss, running_loss_avg) + self.assertTrue(result) + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_handle_loss_spike_logs_only_when_log_only_true(self, _): + logger = MagicMock() + running_loss_avg = 2 + 
self.callback.log_only = True + self.callback.outlier_counter = 5 + + try: + self.callback._handle_loss_spike(logger, running_loss_avg) + except LossSpikeError: + self.fail('LossSpikeError was raised unexpectedly') + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_handle_loss_spike_raises_error_log_only_false(self, _): + logger = MagicMock() + running_loss_avg = 2 + self.callback.log_only = False + self.callback.outlier_counter = 5 + + # LossSpikeError is raised + with self.assertRaises(LossSpikeError): + self.callback._handle_loss_spike(logger, running_loss_avg) + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_detect_high_losses_no_high_losses(self, _): + self.callback.loss_window = deque([2] * 10, maxlen=10) + current_step = 21 + result = self.callback._detect_high_losses(current_step) + self.assertFalse(result) + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_detect_high_losses_with_high_losses(self, _): + self.callback.loss_window = deque( + [9, 8, 7, 6, 5, 11, 12, 13, 14, 15], + maxlen=10, + ) # Simulate mix of losses in loss window + current_step = 21 + result = self.callback._detect_high_losses(current_step) + self.assertTrue(result) + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_set_window_size_from_token(self, _): + state = MagicMock() + state.max_duration.unit = TimeUnit.TOKEN + state.max_duration.value = 100000 + state.timestamp.batch = 100 + state.timestamp.token = 4000 + + self.callback._set_window_size(state) + + self.assertEqual(self.callback.window_size, 125) + self.assertTrue(self.callback.window_size_set) + + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_set_window_size_from_epoch(self, _): + state = MagicMock() + state.max_duration.unit = TimeUnit.EPOCH + state.dataloader_len = 1000 + state.max_duration.value = 3 + state.timestamp.batch = 100 + + self.callback._set_window_size(state) + + 
self.assertEqual(self.callback.window_size, 150) + self.assertTrue(self.callback.window_size_set) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 097bdf77fb..75c50511dd 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -35,6 +35,8 @@ def get_default_value(arg_type: Optional[type] = None): return 'string' elif arg_type == int: return 1 + elif arg_type == float: + return 1.0 elif arg_type == set[str]: return {'set'} elif arg_type == list[str]: