Detect loss spikes and high losses during training #1473

Merged
merged 85 commits into from
Aug 28, 2024
Changes from 76 commits
Commits
85 commits
8e216a7
Kill run on loss spike callback + init
joyce-chen-uni Aug 15, 2024
a1a294b
Import
joyce-chen-uni Aug 15, 2024
601e64a
Add logger as arg
joyce-chen-uni Aug 16, 2024
95520c1
Attribute
joyce-chen-uni Aug 16, 2024
783d02d
Logging for debugging
joyce-chen-uni Aug 16, 2024
0b2eba4
Only check potential loss spike if loss window sufficient size
joyce-chen-uni Aug 16, 2024
e78e9b4
Need to track whether previous step was a potential spike
joyce-chen-uni Aug 17, 2024
18256a5
Bug fix loss window append + simplify
joyce-chen-uni Aug 19, 2024
d300726
Bug fix loss window append + simplify
joyce-chen-uni Aug 19, 2024
44ad527
add spike yaml temporarily
joyce-chen-uni Aug 19, 2024
d0523ac
Low parameters for testing
joyce-chen-uni Aug 19, 2024
eb5fdba
Delete logger
joyce-chen-uni Aug 19, 2024
ad5c5e9
Custom error for loss spike
joyce-chen-uni Aug 19, 2024
2a135d6
Detailed error
joyce-chen-uni Aug 19, 2024
06b958e
Reduce logging
joyce-chen-uni Aug 19, 2024
c9cfad2
Delete yaml
joyce-chen-uni Aug 19, 2024
7b4d31e
add a condition for stopping if generally very high losses
joyce-chen-uni Aug 20, 2024
12e878e
Test logging
joyce-chen-uni Aug 20, 2024
ee5ea13
Just test metadata logging
joyce-chen-uni Aug 20, 2024
f24ef7e
Move window slide to the end so that cur loss not factored into runni…
joyce-chen-uni Aug 20, 2024
6779b3d
log/fail depending on wall clock
joyce-chen-uni Aug 20, 2024
a496028
No loggin
joyce-chen-uni Aug 20, 2024
ad88a43
Try to ensure checkpoint saved before killing run
joyce-chen-uni Aug 21, 2024
8eb7824
test
joyce-chen-uni Aug 21, 2024
20261bd
Log metadata dict
joyce-chen-uni Aug 21, 2024
a76996f
testing
joyce-chen-uni Aug 21, 2024
e646b2a
Log all potential spikes to run metadata
joyce-chen-uni Aug 21, 2024
ebd8ff5
Test logging
joyce-chen-uni Aug 21, 2024
53f4117
log to be sure
joyce-chen-uni Aug 21, 2024
decf9c5
Final
joyce-chen-uni Aug 21, 2024
9db4814
Explanatory comment
joyce-chen-uni Aug 21, 2024
afa21b9
Merge branch 'main' into main
joyce-chen-uni Aug 21, 2024
4b732fa
Increase loss cap
joyce-chen-uni Aug 21, 2024
88a5417
Remove advisory
joyce-chen-uni Aug 21, 2024
45eaa0e
Remove hardcoded const
joyce-chen-uni Aug 21, 2024
e160e21
Docstring, deque, log_only mode
joyce-chen-uni Aug 21, 2024
39fdbfb
edit docstring
joyce-chen-uni Aug 21, 2024
8cb2f95
Specific different error for high losses
joyce-chen-uni Aug 22, 2024
a539866
Decompose detection functions
joyce-chen-uni Aug 22, 2024
5da701c
First pass unit tests
joyce-chen-uni Aug 22, 2024
7fe9c07
Working unit tests
joyce-chen-uni Aug 22, 2024
3721ed0
Nits
joyce-chen-uni Aug 22, 2024
e062db0
Type issues
joyce-chen-uni Aug 22, 2024
8ebec3f
Init callback
joyce-chen-uni Aug 22, 2024
df874b5
init
joyce-chen-uni Aug 22, 2024
30fceb0
Missing args
joyce-chen-uni Aug 22, 2024
0b6483a
Wording change
joyce-chen-uni Aug 22, 2024
f8fe2d8
lint
joyce-chen-uni Aug 22, 2024
b08b1dc
docstring
joyce-chen-uni Aug 22, 2024
b606ca8
Default log_only to true
joyce-chen-uni Aug 22, 2024
378beec
Address comments
joyce-chen-uni Aug 22, 2024
fd08f80
try fix type issue
joyce-chen-uni Aug 22, 2024
e759f5a
New exception for high loss
joyce-chen-uni Aug 22, 2024
89250e1
Handle float arg
joyce-chen-uni Aug 22, 2024
1154deb
Fixes exceptions, add exp class
joyce-chen-uni Aug 22, 2024
03826b5
Add suggestion back to run event message
joyce-chen-uni Aug 23, 2024
753793e
Log the loss window when threshold is met
joyce-chen-uni Aug 23, 2024
b8f2634
Lint
joyce-chen-uni Aug 23, 2024
74d2cb4
Set loss cap to max of first loss window, + only do callback on rank …
joyce-chen-uni Aug 23, 2024
3d74b99
Specify window size as a fraction of training duration steps
joyce-chen-uni Aug 23, 2024
10f9ac7
Formatting stuff
joyce-chen-uni Aug 23, 2024
4167a74
Formatting
joyce-chen-uni Aug 23, 2024
bf220f4
fit start?
joyce-chen-uni Aug 24, 2024
8ec812c
Logging for window size setting check
joyce-chen-uni Aug 24, 2024
93b0fe8
Add back window_size and loss_cap as optl params
joyce-chen-uni Aug 24, 2024
2214175
Test that error is not raised when log_only set to true
joyce-chen-uni Aug 25, 2024
57800e6
Update docstring
joyce-chen-uni Aug 25, 2024
c59e5be
Round fractional window size
joyce-chen-uni Aug 25, 2024
1aeb46f
Window size and loss cap not specifiable
joyce-chen-uni Aug 25, 2024
1b5a757
Adjust tests no window size / loss cap inputs
joyce-chen-uni Aug 25, 2024
e3c00b3
init vals for window size, loss window, loss cap
joyce-chen-uni Aug 25, 2024
04d20fa
decompose spike/high loss handling
joyce-chen-uni Aug 26, 2024
4658a8f
decompose logging and rename helpers
joyce-chen-uni Aug 26, 2024
55dd5bb
Test error raised/not raised depending on log_only
joyce-chen-uni Aug 26, 2024
5b044fa
Cleanup window frac
joyce-chen-uni Aug 26, 2024
95bed00
Clean up fit_start & callback rank 0 only
joyce-chen-uni Aug 26, 2024
50a6b5e
Try moving self._enabled setting
joyce-chen-uni Aug 26, 2024
9f37281
Private consts
joyce-chen-uni Aug 26, 2024
acd46f4
Add absolute loss cap and loss window as arguments
joyce-chen-uni Aug 26, 2024
782316b
Take out optional type
joyce-chen-uni Aug 26, 2024
187c861
Calculate the loss window once hit min loss window size
joyce-chen-uni Aug 27, 2024
47a461c
Test window size setting
joyce-chen-uni Aug 27, 2024
fee774b
Keep old deque
joyce-chen-uni Aug 27, 2024
3e2ecc6
Cleanup
joyce-chen-uni Aug 27, 2024
4005a5c
Remove running loss avg log
joyce-chen-uni Aug 27, 2024
3 changes: 3 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -22,6 +22,7 @@
from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import (
    MegaBlocksMoE_TokPerExpert,
)
@@ -60,6 +61,7 @@
callbacks.register('loss_perp_v_len', func=LossPerpVsContextLengthLogger)
callbacks.register('env_logging', func=EnvironmentLoggingCallback)
callbacks.register('nan_monitor', func=NaNMonitor)
callbacks.register('kill_loss_spike', func=KillLossSpike)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -76,4 +78,5 @@
    'AsyncEval',
    'CurriculumLearning',
    'LossPerpVsContextLengthLogger',
    'KillLossSpike',
]
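
With the callback imported and registered under the name 'kill_loss_spike', it can be attached like any other registered callback. A minimal sketch of constructing it directly (the keyword arguments follow the new __init__ signature; wiring it into a Composer Trainer's callbacks list is assumed and not shown in this diff):

from llmfoundry.callbacks import KillLossSpike

# Log-only monitoring: warn on spikes or high losses without stopping the run.
loss_monitor = KillLossSpike(
    log_only=True,          # set False to raise LossSpikeError / HighLossError instead
    patience=4,             # consecutive outlier losses tolerated before acting
    outlier_multiplier=2,   # a loss counts as an outlier at 2x the running average
)
# Pass `loss_monitor` in the Trainer's callbacks, or select it by its registry
# name ('kill_loss_spike') in a training config.
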
185 changes: 185 additions & 0 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -0,0 +1,185 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Track training runs for loss spikes or persistently high training loss."""
from __future__ import annotations

import logging
from collections import deque

import numpy as np
import torch
from composer.core import Callback, State, TimeUnit
from composer.loggers import Logger, MosaicMLLogger
from composer.utils import dist

from llmfoundry.utils.exceptions import HighLossError, LossSpikeError
from llmfoundry.utils.warnings import experimental_class

log = logging.getLogger(__name__)

__all__ = ['KillLossSpike']

MIN_WINDOW_SIZE = 100
MAX_LOSS_CAP = 10
WINDOW_FRACTION = 0.05


@experimental_class('KillLossSpike')
class KillLossSpike(Callback):
"""Detects and handles loss spikes or high losses during training.

Monitors the training loss at the end of each batch and maintains a rolling window of recent losses.
If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either
log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry.

Args:
log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a
joyce-chen-uni marked this conversation as resolved.
Show resolved Hide resolved
LossSpikeError will be raised to stop training upon detecting a loss spike or persistently
high loss. Default is True.
patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be
persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error).
outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an
outlier if it is outlier_multiplier times greater than the mean of losses in
the current window. Default is 2.
window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps.
loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses, with a maximum of 10.

Raises:
LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
raised to stop the run with an error message.
"""

    def __init__(
        self,
        log_only: bool = True,
        patience: int = 4,
        outlier_multiplier: float = 2,
    ):
        self._enabled = True
        self.log_only = log_only
        self.patience = patience
        self.outlier_multiplier = outlier_multiplier
        self.outlier_counter = 0
        self.window_size = MIN_WINDOW_SIZE
        self.loss_window = deque(maxlen=self.window_size)
        self.loss_cap = MAX_LOSS_CAP

    def _detect_loss_spike(
        self,
        train_loss: float,
        running_loss_avg: float,
    ) -> bool:
        # Train loss is an outlier
        if train_loss >= running_loss_avg * self.outlier_multiplier:
            self.outlier_counter += 1
            log.info(
                f'Potential loss spike detected. Iteration: {self.outlier_counter}',
            )
            if self.outlier_counter > self.patience:
                log.info(
                    f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.',
                )
                return True
        # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
        elif self.outlier_counter > 0:
            log.info('Not a persistent loss spike. Resetting outlier counter.')
            self.outlier_counter = 0
        return False

    def _detect_high_losses(self, current_step: int) -> bool:
        if current_step < self.window_size * 2:
            return False

        # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period
        high_loss_count = sum(
            1 for loss in self.loss_window if loss > self.loss_cap
        )
        is_high_loss = high_loss_count >= self.window_size / 2

        if is_high_loss:
            log.info(
                f'High losses detected: {high_loss_count}/{self.window_size} losses above {self.loss_cap}.',
            )

        return is_high_loss

    def _log_metadata(self, logger: Logger, key: str, message: str) -> None:
        for destination in logger.destinations:
            if isinstance(destination, MosaicMLLogger):
                destination.log_metadata({
                    key: message,
                    'loss_window': list(self.loss_window),
                })

    def _handle_loss_spike(
        self,
        logger: Logger,
        running_loss_avg: float,
    ) -> None:
        message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.'
        self._log_metadata(logger, 'loss_spike', message)
        if not self.log_only:
            raise LossSpikeError(
                outlier_multiplier=self.outlier_multiplier,
                running_loss_avg=round(running_loss_avg),
                outlier_counter=self.outlier_counter,
            )

    def _handle_high_losses(self, logger: Logger) -> None:
        message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.'
        self._log_metadata(logger, 'high_loss', message)
        if not self.log_only:
            raise HighLossError(
                loss_cap=self.loss_cap,
                window_size=self.window_size,
            )

    def fit_start(self, state: State, logger: Logger) -> None:
        # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
        if state.max_duration is not None:
            if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
                total_training_steps = state.dataloader_len * state.max_duration.value
            elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
                total_training_steps = state.max_duration.value
            self.window_size = max(
                self.window_size,
                round(float(total_training_steps * WINDOW_FRACTION)),
            )
            self.loss_window = deque(maxlen=self.window_size)
        self._enabled = (dist.get_global_rank() == 0)

    def batch_end(self, state: State, logger: Logger) -> None:
        if not isinstance(state.loss, torch.Tensor):
            raise NotImplementedError('Multiple losses not supported yet')
        train_loss = state.loss.item()

        log.info(f'Window size: {self.window_size}')

        # Only start early stopping once we have a full window of loss data.
        if len(self.loss_window) == self.window_size:
            current_step = int(state.timestamp.batch)
            # Only applies if max_duration is set in tokens. If the current batch count is less than
            # MIN_WINDOW_SIZE, raise the window size back to MIN_WINDOW_SIZE and continue.
            if current_step < MIN_WINDOW_SIZE:
                self.window_size = MIN_WINDOW_SIZE
                self.loss_window.append(train_loss)
                return

            # Set the loss cap to the maximum loss from the first loss window
            if current_step == self.window_size:
                self.loss_cap = max(max(self.loss_window), self.loss_cap)

            running_loss_avg = float(np.mean(self.loss_window))
            log.info(f'Running loss average: {running_loss_avg}')

            if self._detect_loss_spike(train_loss, running_loss_avg):
                self._handle_loss_spike(logger, running_loss_avg)

            elif self._detect_high_losses(current_step):
                self._handle_high_losses(logger)

        self.loss_window.append(train_loss)
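
To make the two detection rules concrete, here is a small, self-contained sketch that drives the private helpers directly, mirroring the unit tests below. Attribute and method names come from the callback above; the numbers are illustrative only:

from collections import deque

from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

cb = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2)
cb.window_size = 10
cb.loss_cap = 10
cb.loss_window = deque([2.0] * 10, maxlen=10)

# Spike rule: 5.0 >= outlier_multiplier * running average (2 * 2.0), so the outlier
# counter increments, but nothing fires until it exceeds `patience`.
print(cb._detect_loss_spike(train_loss=5.0, running_loss_avg=2.0))  # False (counter = 1)

cb.outlier_counter = 4                  # simulate 4 prior consecutive outliers
print(cb._detect_loss_spike(5.0, 2.0))  # True (counter = 5 > patience)

# High-loss rule: at least half the window must exceed loss_cap, and the run must be
# past the initial buffer of 2 * window_size steps.
cb.loss_window = deque([12.0] * 6 + [3.0] * 4, maxlen=10)
print(cb._detect_high_losses(current_step=25))  # True (6/10 losses above 10)
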
39 changes: 39 additions & 0 deletions llmfoundry/utils/exceptions.py
@@ -387,3 +387,42 @@ class RunTimeoutError(InternalError):
    def __init__(self, timeout: int) -> None:
        message = f'Run timed out after {timeout} seconds.'
        super().__init__(message, timeout=timeout)


class LossSpikeError(UserError):
    """Error thrown if a severe loss spike occurs."""

    def __init__(
        self,
        outlier_multiplier: float,
        running_loss_avg: int,
        outlier_counter: int,
    ) -> None:
        message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \
the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
Please try submitting the run again with a lower learning rate.'

        super().__init__(
            message,
            outlier_multiplier=outlier_multiplier,
            running_loss_avg=running_loss_avg,
            outlier_counter=outlier_counter,
        )


class HighLossError(UserError):
    """Error thrown if training loss plateaus or is unstable at a high level."""

    def __init__(
        self,
        loss_cap: float,
        window_size: int,
    ) -> None:
        message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \
for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.'

        super().__init__(
            message,
            loss_cap=loss_cap,
            window_size=window_size,
        )
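
Both new exceptions subclass UserError, so a stopped run surfaces an actionable message rather than being retried. A minimal sketch of the messages they produce (the numeric values are illustrative only):

from llmfoundry.utils.exceptions import HighLossError, LossSpikeError

try:
    # Stand-in for a training step that tripped the spike detector with log_only=False.
    raise LossSpikeError(outlier_multiplier=2, running_loss_avg=7, outlier_counter=5)
except LossSpikeError as e:
    print(f'Run stopped without retry: {e}')

try:
    raise HighLossError(loss_cap=10, window_size=100)
except HighLossError as e:
    print(f'Run stopped without retry: {e}')
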
77 changes: 77 additions & 0 deletions tests/callbacks/test_kill_loss_spike_callback.py
@@ -0,0 +1,77 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import unittest
from collections import deque
from unittest.mock import MagicMock, patch

from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
from llmfoundry.utils.exceptions import LossSpikeError


class TestKillLossSpike(unittest.TestCase):

    def __init__(self, *args: str, **kwargs: dict):
        super(TestKillLossSpike, self).__init__(*args, **kwargs)
        self.callback = KillLossSpike(
            log_only=True,
            patience=4,
            outlier_multiplier=2,
        )
        self.callback.window_size = 10
        self.callback.loss_cap = 10

    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_detect_loss_spike_no_spike(self, _):
        self.callback.outlier_counter = 0
        train_loss = 4
        running_loss_avg = 2
        result = self.callback._detect_loss_spike(train_loss, running_loss_avg)
        self.assertFalse(result)

    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_detect_loss_spike_with_spike(self, _):
        self.callback.outlier_counter = 4  # Simulating previous spikes
        train_loss = 4
        running_loss_avg = 2
        result = self.callback._detect_loss_spike(train_loss, running_loss_avg)
        self.assertTrue(result)

    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_handle_loss_spike_logs_only_when_log_only_true(self, _):
        logger = MagicMock()
        running_loss_avg = 2
        self.callback.log_only = True
        self.callback.outlier_counter = 5

        try:
            self.callback._handle_loss_spike(logger, running_loss_avg)
        except LossSpikeError:
            self.fail('LossSpikeError was raised unexpectedly')

    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_handle_loss_spike_raises_error_log_only_false(self, _):
        logger = MagicMock()
        running_loss_avg = 2
        self.callback.log_only = False
        self.callback.outlier_counter = 5

        # LossSpikeError is raised
        with self.assertRaises(LossSpikeError):
            self.callback._handle_loss_spike(logger, running_loss_avg)

    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_detect_high_losses_no_high_losses(self, _):
        self.callback.loss_window = deque([2] * 10, maxlen=10)
        current_step = 21
        result = self.callback._detect_high_losses(current_step)
        self.assertFalse(result)

    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_detect_high_losses_with_high_losses(self, _):
        self.callback.loss_window = deque(
            [9, 8, 7, 6, 5, 11, 12, 13, 14, 15],
            maxlen=10,
        )  # Simulate mix of losses in loss window
        current_step = 21
        result = self.callback._detect_high_losses(current_step)
        self.assertTrue(result)
2 changes: 2 additions & 0 deletions tests/utils/test_exceptions.py
@@ -35,6 +35,8 @@ def get_default_value(arg_type: Optional[type] = None):
        return 'string'
    elif arg_type == int:
        return 1
    elif arg_type == float:
        return 1.0
    elif arg_type == set[str]:
        return {'set'}
    elif arg_type == list[str]: