From 8e216a7c7922d6a8f77005775c514dde35999bb1 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 15 Aug 2024 14:38:30 -0700
Subject: [PATCH 01/83] Kill run on loss spike callback + init

---
 llmfoundry/callbacks/__init__.py              |  3 ++
 .../callbacks/kill_loss_spike_callback.py     | 44 +++++++++++++++++++
 2 files changed, 47 insertions(+)
 create mode 100644 llmfoundry/callbacks/kill_loss_spike_callback.py

diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 496e905e13..c15f2925fe 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -35,6 +35,7 @@
 from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback
 from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
 from llmfoundry.registry import callbacks, callbacks_with_config
+from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
 
 callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor)
 callbacks.register('lr_monitor', func=LRMonitor)
@@ -55,6 +56,7 @@
 callbacks.register('eval_output_logging', func=EvalOutputLogging)
 callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)
 callbacks.register('run_timeout', func=RunTimeoutCallback)
+callbacks.register('kill_loss_spike', func=KillLossSpike)
 callbacks.register('loss_perp_v_len', func=LossPerpVsContextLengthLogger)
 
@@ -73,4 +75,5 @@
     'AsyncEval',
     'CurriculumLearning',
     'LossPerpVsContextLengthLogger',
+    'KillLossSpike',
 ]

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
new file mode 100644
index 0000000000..d81badafdc
--- /dev/null
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -0,0 +1,44 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Monitor rate of change of loss."""
+from __future__ import annotations
+
+import torch
+import numpy as np
+from composer.core import Callback, State
+from utils.exceptions import UserError
+
+__all__ = ['KillLossSpike']
+
+class KillLossSpike(Callback):
+
+    def __init__(self, patience:int=10, outlier_multiplier:int=2, window_size:int=100):
+        self.patience = patience
+        self.outlier_multiplier = outlier_multiplier
+        self.window_size = window_size
+        self.iterations = 0
+        self.early_stop = False
+        self.loss_window = []
+
+    def batch_end(self, state: State):
+        if not isinstance(state.loss, torch.Tensor):
+            raise NotImplementedError('Multiple losses not supported yet')
+        train_loss = state.loss.item()
+
+        self.loss_window.append(train_loss)
+        if len(self.loss_window) > self.window_size:
+            self.loss_window.pop(0)
+        # Only start early stopping once a full window of loss data
+        if len(self.loss_window) == self.window_size:
+            running_loss_avg = np.mean(self.loss_window)
+
+        # If train loss exceeds the running average
+        if train_loss > running_loss_avg * self.outlier_threshold:
+            self.iterations += 1
+            if self.iterations > self.patience:
+                self.early_stop = True
+                # Some kind of user error message
+                raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
+        else:
+            self.iterations = 0
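The heuristic this first commit encodes is easier to read outside the callback plumbing: a batch loss counts as an outlier when it exceeds the running window mean by outlier_multiplier times, and only a run of more than patience consecutive outliers kills the job. A minimal standalone sketch of the same rule — the function name and driver loop are illustrative, not part of the patch; note that, as in this commit, the current loss is appended before the mean is taken, so a spike slightly inflates its own baseline:

    from collections import deque

    def should_kill(losses, window_size=100, outlier_multiplier=2, patience=10):
        """Return True once more than `patience` consecutive losses exceed the window mean."""
        window = deque(maxlen=window_size)
        consecutive = 0
        for loss in losses:
            window.append(loss)  # the spike contributes to its own baseline, as in the commit
            if len(window) == window_size:
                running_avg = sum(window) / window_size
                if loss > running_avg * outlier_multiplier:
                    consecutive += 1
                    if consecutive > patience:
                        return True
                else:
                    consecutive = 0
        return False

    # 100 calm batches, then a sustained spike: the 11th consecutive outlier trips patience=10.
    print(should_kill([1.0] * 100 + [10.0] * 12))  # True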
From a1a294bbba7d7fc53df37c97ee68e2297aa88597 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 15 Aug 2024 16:19:00 -0700
Subject: [PATCH 02/83] Import

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index d81badafdc..1c5593bf73 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -7,7 +7,7 @@
 import torch
 import numpy as np
 from composer.core import Callback, State
-from utils.exceptions import UserError
+from llmfoundry.utils.exceptions import UserError
 
 __all__ = ['KillLossSpike']

From 601e64a060b01bfd39dd025ba539a69b0c62e8b1 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 15 Aug 2024 17:13:50 -0700
Subject: [PATCH 03/83] Add logger as arg

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 1c5593bf73..d50d6874b5 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -7,6 +7,7 @@
 import torch
 import numpy as np
 from composer.core import Callback, State
+from composer.loggers import Logger
 from llmfoundry.utils.exceptions import UserError
 
 __all__ = ['KillLossSpike']
@@ -21,7 +22,7 @@ def __init__(self, patience:int=10, outlier_multiplier:int=2, window_size:int=10
         self.early_stop = False
         self.loss_window = []
 
-    def batch_end(self, state: State):
+    def batch_end(self, state: State, _: Logger) -> None:
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()

From 95520c13ccd090ea484df3902c807ab709641d09 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 15 Aug 2024 17:23:23 -0700
Subject: [PATCH 04/83] Attribute

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index d50d6874b5..6912fa7460 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -19,6 +19,7 @@ def __init__(self, patience:int=10, outlier_multiplier:int=2, window_size:int=10
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
         self.iterations = 0
+        self.running_loss_avg = 0
         self.early_stop = False
         self.loss_window = []
@@ -32,10 +33,10 @@ def batch_end(self, state: State, _: Logger) -> None:
             self.loss_window.pop(0)
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
-            running_loss_avg = np.mean(self.loss_window)
+            self.running_loss_avg = np.mean(self.loss_window)
 
         # If train loss exceeds the running average
-        if train_loss > running_loss_avg * self.outlier_threshold:
+        if train_loss > self.running_loss_avg * self.outlier_multiplier:
             self.iterations += 1
             if self.iterations > self.patience:
                 self.early_stop = True
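With the registry entry from the first commit, YAML configs can enable the callback under the kill_loss_spike key; when building a Composer Trainer directly it can also be passed in explicitly. A rough usage sketch — model and train_dataloader are placeholders for a real ComposerModel and dataloader, not values from this PR:

    from composer import Trainer

    from llmfoundry.callbacks import KillLossSpike

    trainer = Trainer(
        model=model,                        # placeholder: any ComposerModel
        train_dataloader=train_dataloader,  # placeholder: your training dataloader
        max_duration='1ep',
        callbacks=[KillLossSpike(patience=10, outlier_multiplier=2, window_size=100)],
    )
    trainer.fit()

Composer dispatches the batch_end event with (state, logger), which is why PATCH 03 grows the signature to two parameters, and PATCH 04 turns running_loss_avg into an attribute (and fixes the outlier_threshold name) because the spike check can run before a full window has ever been computed.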
From 783d02d64d2f148d226c858a2ab8db58e20d8829 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Fri, 16 Aug 2024 11:25:39 -0700
Subject: [PATCH 05/83] Logging for debugging

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 6912fa7460..1c64dcb1da 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -5,16 +5,18 @@
 from __future__ import annotations
 
 import torch
+import logging
 import numpy as np
 from composer.core import Callback, State
 from composer.loggers import Logger
 from llmfoundry.utils.exceptions import UserError
 
+log = logging.getLogger(__name__)
 __all__ = ['KillLossSpike']
 
 class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=10, outlier_multiplier:int=2, window_size:int=100):
+    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -23,7 +25,7 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=10
         self.early_stop = False
         self.loss_window = []
 
-    def batch_end(self, state: State, _: Logger) -> None:
+    def batch_end(self, state: State, logger: Logger) -> None:
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -34,13 +36,16 @@ def batch_end(self, state: State, logger: Logger) -> None:
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
             self.running_loss_avg = np.mean(self.loss_window)
+            log.info(f'Running loss average: {self.running_loss_avg}')
 
         # If train loss exceeds the running average
         if train_loss > self.running_loss_avg * self.outlier_multiplier:
             self.iterations += 1
+            log.info(f'Potential loss spike detected. Iteration: {self.iterations}')
             if self.iterations > self.patience:
                 self.early_stop = True
                 # Some kind of user error message
                 raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
         else:
+            log.info(f'Not a persistent loss spike.')
             self.iterations = 0

From 0b2eba492459bad1c5beaaa3d45a6b51f66af609 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Fri, 16 Aug 2024 16:09:12 -0700
Subject: [PATCH 06/83] Only check potential loss spike if loss window
 sufficient size

---
 .../callbacks/kill_loss_spike_callback.py | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 1c64dcb1da..89c3fd63e3 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100):
+    def __init__(self, patience:int=2, outlier_multiplier:int=2, window_size:int=100):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -38,14 +38,14 @@ def batch_end(self, state: State, logger: Logger) -> None:
             self.running_loss_avg = np.mean(self.loss_window)
             log.info(f'Running loss average: {self.running_loss_avg}')
 
-        # If train loss exceeds the running average
-        if train_loss > self.running_loss_avg * self.outlier_multiplier:
-            self.iterations += 1
-            log.info(f'Potential loss spike detected. Iteration: {self.iterations}')
-            if self.iterations > self.patience:
-                self.early_stop = True
-                # Some kind of user error message
-                raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
-        else:
-            log.info(f'Not a persistent loss spike.')
-            self.iterations = 0
+            # If train loss exceeds the running average
+            if train_loss > self.running_loss_avg * self.outlier_multiplier:
+                self.iterations += 1
+                log.info(f'Potential loss spike detected. Iteration: {self.iterations}')
+                if self.iterations > self.patience:
+                    self.early_stop = True
+                    # Some kind of user error message
+                    raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
+            else:
+                log.info(f'Not a persistent loss spike.')
+                self.iterations = 0

From e78e9b49f32187821fbf8c394dc10088f8b21807 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Fri, 16 Aug 2024 17:12:02 -0700
Subject: [PATCH 07/83] Need to track whether previous step was a potential
 spike

---
 .../callbacks/kill_loss_spike_callback.py | 37 +++++++++++--------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 89c3fd63e3..a10ba849e7 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,23 +16,21 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=2, outlier_multiplier:int=2, window_size:int=100):
-        self.patience = patience
-        self.outlier_multiplier = outlier_multiplier
-        self.window_size = window_size
-        self.iterations = 0
-        self.running_loss_avg = 0
-        self.early_stop = False
-        self.loss_window = []
+    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100):
+        self.patience = patience
+        self.outlier_multiplier = outlier_multiplier
+        self.window_size = window_size
+        self.iterations = 0
+        self.running_loss_avg = 0
+        self.early_stop = False
+        self.prev_step_spike = False
+        self.loss_window = []
 
     def batch_end(self, state: State, logger: Logger) -> None:
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
-        self.loss_window.append(train_loss)
-        if len(self.loss_window) > self.window_size:
-            self.loss_window.pop(0)
 
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
             self.running_loss_avg = np.mean(self.loss_window)
@@ -40,12 +38,21 @@ def batch_end(self, state: State, logger: Logger) -> None:
 
             # If train loss exceeds the running average
             if train_loss > self.running_loss_avg * self.outlier_multiplier:
-                self.iterations += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.iterations}')
+
+                if self.prev_was_spike:
+                    self.iterations += 1
                 if self.iterations > self.patience:
                     self.early_stop = True
                     # Some kind of user error message
                     raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
+
+                self.prev_step_spike = True
+
-            else:
+            elif self.prev_step_spike:
                 log.info(f'Not a persistent loss spike.')
                 self.iterations = 0
+                self.prev_step_spike = False
+
+        self.loss_window.append(train_loss)
+        self.loss_window.pop(0)

From 18256a545358417fcc2a9aa9893001d2bc7fc65d Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Sun, 18 Aug 2024 21:08:31 -0700
Subject: [PATCH 08/83] Bug fix loss window append + simplify

---
 .../callbacks/kill_loss_spike_callback.py | 40 +++++++++----------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index a10ba849e7..278f2be7d1 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -20,39 +20,35 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
-        self.iterations = 0
-        self.running_loss_avg = 0
-        self.early_stop = False
-        self.prev_step_spike = False
+        self.outlier_counter = 0
         self.loss_window = []
 
     def batch_end(self, state: State, logger: Logger) -> None:
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
+        self.loss_window.append(train_loss)
+        if len(self.loss_window) > self.window_size:
+            self.loss_window.pop(0)
 
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
-            self.running_loss_avg = np.mean(self.loss_window)
-            log.info(f'Running loss average: {self.running_loss_avg}')
+            running_loss_avg = np.mean(self.loss_window)
+            log.info(f'Running loss average: {running_loss_avg}')
 
             # If train loss exceeds the running average
-            if train_loss > self.running_loss_avg * self.outlier_multiplier:
-                log.info(f'Potential loss spike detected. Iteration: {self.iterations}')
-
-                if self.prev_was_spike:
-                    self.iterations += 1
-                if self.iterations > self.patience:
-                    self.early_stop = True
+            if train_loss > running_loss_avg * self.outlier_multiplier:
+                self.outlier_counter += 1
+                log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
+                if self.outlier_counter > self.patience:
                     # Some kind of user error message
                     raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
-
-                self.prev_step_spike = True
-
-            elif self.prev_step_spike:
-                log.info(f'Not a persistent loss spike.')
-                self.iterations = 0
-                self.prev_step_spike = False
+            elif self.outlier_counter > 0:
+                log.info(f'Not a persistent loss spike.')
+                self.outlier_counter = 0
 
-        self.loss_window.append(train_loss)
-        self.loss_window.pop(0)
+            else:
+                log.info('No loss spike detected.')
+        else:
+            log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
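One design note on the window bookkeeping this commit settles on: a plain list with pop(0) shifts the whole list on every batch. At a window of 100 the cost is negligible, but collections.deque with maxlen states the fixed-size-window intent directly and drops both the length check and the pop. A sketch of the equivalent bookkeeping, offered as an alternative rather than what the PR does here:

    from collections import deque

    import numpy as np

    window_size = 100
    loss_window = deque(maxlen=window_size)  # appending past maxlen evicts the oldest loss

    def observe(train_loss: float):
        """Record a loss; return the running average once a full window has accumulated."""
        loss_window.append(train_loss)
        if len(loss_window) == window_size:
            return float(np.mean(loss_window))
        return None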
From d3007262163c2005e42eb44734a7e47413926c72 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Sun, 18 Aug 2024 21:08:31 -0700
Subject: [PATCH 09/83] Bug fix loss window append + simplify

---
 .../callbacks/kill_loss_spike_callback.py | 43 +++++++++----------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index a10ba849e7..9edc43a37a 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -20,39 +20,38 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
-        self.iterations = 0
-        self.running_loss_avg = 0
-        self.early_stop = False
-        self.prev_step_spike = False
+        self.outlier_counter = 0
         self.loss_window = []
 
     def batch_end(self, state: State, logger: Logger) -> None:
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
+        self.loss_window.append(train_loss)
+        if len(self.loss_window) > self.window_size:
+            self.loss_window.pop(0)
 
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
-            self.running_loss_avg = np.mean(self.loss_window)
-            log.info(f'Running loss average: {self.running_loss_avg}')
+            running_loss_avg = np.mean(self.loss_window)
+            log.info(f'Running loss average: {running_loss_avg}')
 
             # If train loss exceeds the running average
-            if train_loss > self.running_loss_avg * self.outlier_multiplier:
-                log.info(f'Potential loss spike detected. Iteration: {self.iterations}')
-
-                if self.prev_was_spike:
-                    self.iterations += 1
-                if self.iterations > self.patience:
-                    self.early_stop = True
+            if train_loss > running_loss_avg * self.outlier_multiplier:
+                self.outlier_counter += 1
+                log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
+                if self.outlier_counter > self.patience:
                     # Some kind of user error message
-                    raise UserError('Training stopped due to loss spike. Please try submitting the run again with a lower learning rate.')
-
-                self.prev_step_spike = True
-
-            elif self.prev_step_spike:
-                log.info(f'Not a persistent loss spike.')
-                self.iterations = 0
-                self.prev_step_spike = False
+                    raise UserError(f'Training stopped due to a loss spike over {self.outlier_counter} consecutive training steps. \
+                        Please try submitting the run again with a lower learning rate.')
+
+            # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
+            elif self.outlier_counter > 0:
+                log.info(f'Not a persistent loss spike.')
+                self.outlier_counter = 0
 
-        self.loss_window.append(train_loss)
-        self.loss_window.pop(0)
+            else:
+                log.info('No loss spike detected.')
+
+        else:
+            log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
From d0523ac577b49dc05ea9b2e46bfd345889e63f6c Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 19 Aug 2024 08:55:11 -0700
Subject: [PATCH 10/83] Low parameters for testing

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 9edc43a37a..6c69c6595b 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100):
+    def __init__(self, patience:int=2, outlier_multiplier:int=1.2, window_size:int=50):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -36,7 +36,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
             running_loss_avg = np.mean(self.loss_window)
             log.info(f'Running loss average: {running_loss_avg}')
 
-            # If train loss exceeds the running average
+            # If train loss is an outlier
             if train_loss > running_loss_avg * self.outlier_multiplier:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')

From eb5fdba153fa1a8380927e1340f630c57c4458ad Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 19 Aug 2024 09:39:19 -0700
Subject: [PATCH 11/83] Delete logger

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 6c69c6595b..b45bc1239b 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=2, outlier_multiplier:int=1.2, window_size:int=50):
+    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -24,6 +24,8 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=10
         self.loss_window = []
 
     def batch_end(self, state: State, logger: Logger) -> None:
+        del logger
+
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()

From ad5c5e9466e8d99cd472476b3745d09a81beb049 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 19 Aug 2024 10:38:43 -0700
Subject: [PATCH 12/83] Custom error for loss spike

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 9 ++++-----
 llmfoundry/utils/exceptions.py                   | 8 ++++++++
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index b45bc1239b..753f14e11b 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -9,7 +9,7 @@
 import numpy as np
 from composer.core import Callback, State
 from composer.loggers import Logger
-from llmfoundry.utils.exceptions import UserError
+from llmfoundry.utils.exceptions import LossSpikeError
 
 log = logging.getLogger(__name__)
 __all__ = ['KillLossSpike']
@@ -44,16 +44,15 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
                     # Some kind of user error message
-                    raise UserError(f'Training stopped due to a loss spike over {self.outlier_counter} consecutive training steps. \
-                        Please try submitting the run again with a lower learning rate.')
+                    raise LossSpikeError(self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
-                log.info(f'Not a persistent loss spike.')
+                log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
             else:
-                log.info('No loss spike detected.')
+                log.info('No loss spike detected. Average of recent losses: {running_loss_avg}.')
 
         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')

diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index b2e5cc06e8..747b404e46 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -379,3 +379,11 @@ class RunTimeoutError(InternalError):
     def __init__(self, timeout: int) -> None:
         message = f'Run timed out after {timeout} seconds.'
         super().__init__(message, timeout=timeout)
+
+class LossSpikeError(UserError):
+    """Error thrown when a severe loss spike occurs."""
+
+    def __init__(self, outlier_counter: int) -> None:
+        message = f'Training stopped due to a loss spike over {outlier_counter} consecutive training steps. \
+            Please try submitting the run again with a lower learning rate.'
+        super().__init__(message)
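Since batch_end reads nothing from the trainer state but state.loss at this point, the new error path can be exercised without a training loop. A minimal sketch using a stand-in state object — the SimpleNamespace double and the driver below are assumptions for illustration, not part of the PR or its tests:

    from types import SimpleNamespace

    import torch

    from llmfoundry.callbacks import KillLossSpike
    from llmfoundry.utils.exceptions import LossSpikeError

    callback = KillLossSpike(patience=3, outlier_multiplier=2, window_size=100)

    def step(loss: float):
        fake_state = SimpleNamespace(loss=torch.tensor(loss))  # only .loss is consumed
        callback.batch_end(fake_state, logger=None)  # logger is unused here (del logger)

    try:
        for _ in range(100):  # fill the window with a calm baseline
            step(1.0)
        for _ in range(5):    # with patience=3, the 4th consecutive outlier raises
            step(10.0)
    except LossSpikeError as err:
        print(err)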
\ + def __init__(self, outlier_multiplier: int, running_loss_avg: int, outlier_counter: int) -> None: + message = f'Training stopped due to a loss spike. The training loss was {outlier_multiplier} times greater than the \ + running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \ Please try submitting the run again with a lower learning rate.' - super().__init__(message) + super().__init__(message, outlier_multiplier=outlier_multiplier, running_loss_avg=running_loss_avg, outlier_counter=outlier_counter) From 06b958e169dac7226d87b244690612ee005caf5f Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Mon, 19 Aug 2024 13:03:07 -0700 Subject: [PATCH 14/83] Reduce logging --- llmfoundry/callbacks/kill_loss_spike_callback.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 155391bea3..9f9ff93e14 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -50,9 +50,6 @@ def batch_end(self, state: State, logger: Logger) -> None: elif self.outlier_counter > 0: log.info(f'Not a persistent loss spike. Resetting outlier counter.') self.outlier_counter = 0 - - else: - log.info(f'No loss spike detected. Average of recent losses: {running_loss_avg}.') else: log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...') From c9cfad2be8700a88506997a6b91ba504cd28b737 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Mon, 19 Aug 2024 15:55:31 -0700 Subject: [PATCH 15/83] Delete yaml --- mcli/mcli-test-kill-spike.yaml | 173 --------------------------------- 1 file changed, 173 deletions(-) delete mode 100644 mcli/mcli-test-kill-spike.yaml diff --git a/mcli/mcli-test-kill-spike.yaml b/mcli/mcli-test-kill-spike.yaml deleted file mode 100644 index 6211f95f2e..0000000000 --- a/mcli/mcli-test-kill-spike.yaml +++ /dev/null @@ -1,173 +0,0 @@ -integrations: -- integration_type: git_repo - git_repo: joyce-chen-uni/llm-foundry - git_branch: main - # git_commit: # OR use your commit hash - pip_install: .[gpu] - ssh_clone: false # Should be true if using a private repo - -command: | - cd llm-foundry/scripts - composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.1_cu121-latest -name: llama3-finetune-kill-spike - -compute: - # Note: Finetuning the 70b model requires at least 16x80GB GPUs - gpus: 8 # Number of GPUs to use - ## These configurations are optional - cluster: r1z1 # Name of the cluster to use for this run - # gpu_type: a100_80gb # Type of GPU to use. 
-
-# The below is injected as a YAML file: /mnt/config/parameters.yaml
-parameters:
-  max_seq_len: 8192
-
-  # Run Name
-  run_name: # If left blank, will be read from env var $RUN_NAME
-
-  max_split_size_mb: 512
-
-  # Model
-  model:
-    name: hf_causal_lm
-    init_device: mixed
-    pretrained_model_name_or_path: meta-llama/Meta-Llama-3-8B
-    pretrained: true
-    # Note: you must have set the HF_TOKEN environment variable and have access to the llama2 models
-    use_auth_token: true
-    use_flash_attention_2: true
-
-  # Tokenizer
-  tokenizer:
-    name: meta-llama/Meta-Llama-3-8B
-    kwargs:
-      model_max_length: 8192
-
-  # Dataloaders
-  train_loader:
-    name: finetuning
-    dataset:
-      hf_name: mosaicml/dolly_hhrlhf
-      split: train
-      max_seq_len: 8192
-      allow_pad_trimming: false
-      decoder_only_format: true
-      shuffle: true
-      # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with
-      # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion
-      # # of the dataset.
-      # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...`
-      # # to profile this run's optimal packing_ratio as it depends on GPU count,
-      # # batch size, sequence length
-      # packing_ratio: auto
-    drop_last: true
-    num_workers: 8
-    pin_memory: false
-    prefetch_factor: 2
-    persistent_workers: true
-    timeout: 0
-
-  eval_loader:
-    name: finetuning
-    dataset:
-      hf_name: mosaicml/dolly_hhrlhf
-      split: test
-      max_seq_len: ${max_seq_len}
-      allow_pad_trimming: false
-      decoder_only_format: true
-      # packing_ratio:
-      shuffle: false
-    drop_last: true
-    num_workers: 8
-    pin_memory: false
-    prefetch_factor: 2
-    persistent_workers: true
-    timeout: 0
-
-  # Optimization
-  scheduler:
-    name: linear_decay_with_warmup
-    t_warmup: 100ba
-    alpha_f: 0.1
-
-  # Note: You may want to change learning rate, betas, weight decay
-  optimizer:
-    name: decoupled_lionw
-    lr: 1.0e-3
-    betas:
-    - 0.9
-    - 0.95
-    weight_decay: 0.0
-
-  algorithms:
-    gradient_clipping:
-      clipping_type: norm
-      clipping_threshold: 1.0
-
-  max_duration: 1ep
-  eval_first: false
-  eval_interval: 1ep
-  eval_subset_num_batches: -1
-  global_train_batch_size: 64
-
-  # System
-  seed: 17
-  device_eval_batch_size: 8
-  device_train_microbatch_size: auto
-  precision: amp_bf16
-
-  # FSDP
-  fsdp_config:
-    sharding_strategy: FULL_SHARD
-    mixed_precision: PURE
-    activation_checkpointing: true
-    activation_checkpointing_reentrant: false
-    activation_cpu_offload: false
-    limit_all_gathers: trued
-
-  # Logging
-  progress_bar: false
-  log_to_console: true
-  console_log_interval: 1ba
-
-  callbacks:
-    speed_monitor:
-      window_size: 10
-    lr_monitor: {}
-    memory_monitor: {}
-    runtime_estimator: {}
-    kill_loss_spike: {}
-
-  load_weights_only: true # Only load the weights, not the optimizer state, LR schedule, etc
-
-  loggers:
-    mlflow:
-      tracking_uri: databricks
-      experiment_name: /Users/joyce.chen@databricks.com/kill-spike
-      model_registry_uri: databricks-uc
-      model_registry_prefix: main.joyce
-      ignore_metrics:
-      - memory/*
-      - throughput/*
-      - trainer/*
-      - time/val
-      - time/train
-      - time/total
-      - time/token_in_epoch
-      - time/sample*
-      - time/batch_in_epoch
-      - loss/train/lbl
-      - loss/train/loss
-      - metrics/train/LanguageCrossEntropy
-      - time/batch
-
-# Checkpoint to local filesystem or remote object store
-# save_interval: 2000ba
-# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
-# save_folder: ./{run_name}/checkpoints
-# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
-
-# Load from local filesystem or remote object store
-# load_path: ./gpt-1b/checkpoints/latest-rank{rank}.pt
-# load_path: s3://my-bucket/my-folder/gpt-1b/checkpoints/latest-rank{rank}.pt

From 7b4d31e2741cd2f01a394a976459e0dde29a8253 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 19 Aug 2024 17:15:05 -0700
Subject: [PATCH 16/83] add a condition for stopping if generally very high
 losses

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 9f9ff93e14..663c936190 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,10 +16,11 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100):
+    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
+        self.loss_cap = loss_cap
         self.outlier_counter = 0
         self.loss_window = []
 
@@ -43,13 +44,18 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
-                    # Some kind of user error message
-                    raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                    # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                    log.info(f'RUN STOPPED. Loss spike >{self.outlier_multiplier}*{running_loss_avg} detected for {self.outlier_counter} steps.')
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
                 log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
+            # Half of the running losses are greater than our "high loss" threshold
+            elif sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2:
+                # raise LossSpikeError()
+                log.info(f'RUN STOPPED. High losses >{self.loss_cap} detected.')
+
         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
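The new high-loss condition is a simple majority test over the window. Written out with made-up numbers (not from any real run):

    loss_cap = 10
    loss_window = [11.2, 9.8, 12.5, 10.3, 8.9, 13.1]  # illustrative recent losses

    n_high = sum(1 for loss in loss_window if loss > loss_cap)
    if n_high >= len(loss_window) / 2:  # at least half the window is above the cap
        print(f'{n_high}/{len(loss_window)} losses above {loss_cap}: persistently high loss')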
From 12e878e1b738e9b946e122081ca330620d253558 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 14:33:21 -0700
Subject: [PATCH 17/83] Test logging

---
 .../callbacks/kill_loss_spike_callback.py | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 663c936190..1b1d7ebe7d 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -8,7 +8,7 @@
 import logging
 import numpy as np
 from composer.core import Callback, State
-from composer.loggers import Logger
+from composer.loggers import Logger, MosaicMLLogger
 from llmfoundry.utils.exceptions import LossSpikeError
 
 log = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
+    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=20, loss_cap:int=10):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -25,7 +25,8 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=20
         self.loss_window = []
 
     def batch_end(self, state: State, logger: Logger) -> None:
-        del logger
+
+        train_time = logger.get_metric('time/train')
 
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -44,8 +45,13 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
-                    # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
-                    log.info(f'RUN STOPPED. Loss spike >{self.outlier_multiplier}*{running_loss_avg} detected for {self.outlier_counter} steps.')
+                    if train_time > 0.1:
+                        log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                        for destination in logger.destinations:
+                            if isinstance(destination, MosaicMLLogger):
+                                destination.log_metadata('LossSpike', f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                    else:
+                        raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
                 log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
             # Half of the running losses are greater than our "high loss" threshold
             elif sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2:
-                # raise LossSpikeError()
-                log.info(f'RUN STOPPED. High losses >{self.loss_cap} detected.')
+                if train_time > 1:
+                    log.info(f'High losses >{self.loss_cap} detected.')
+                    for destination in logger.destinations:
+                        if isinstance(destination, MosaicMLLogger):
+                            destination.log_metadata('PersistentHighLoss', f'High losses >{self.loss_cap} detected.')
+                else:
+                    raise LossSpikeError()
 
         else:
            log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')

From ee5ea13f9a6690096c424236e78015e49415fc85 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 14:58:00 -0700
Subject: [PATCH 18/83] Just test metadata logging

---
 .../callbacks/kill_loss_spike_callback.py | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 1b1d7ebe7d..76bc2b29de 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -26,7 +26,7 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=20
 
     def batch_end(self, state: State, logger: Logger) -> None:
 
-        train_time = logger.get_metric('time/train')
+        # train_time = logger.get_metric('time/train')
 
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -45,13 +45,13 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
-                    if train_time > 0.1:
-                        log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                    # if train_time > 0.1:
+                    # log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata('LossSpike', f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
-                    else:
-                        raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                            destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
+                    # else:
+                    # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
                 log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
             # Half of the running losses are greater than our "high loss" threshold
             elif sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2:
-                if train_time > 1:
-                    log.info(f'High losses >{self.loss_cap} detected.')
+                # if train_time > 1:
+                # log.info(f'High losses >{self.loss_cap} detected.')
                 for destination in logger.destinations:
                     if isinstance(destination, MosaicMLLogger):
-                        destination.log_metadata('PersistentHighLoss', f'High losses >{self.loss_cap} detected.')
-                    else:
-                        raise LossSpikeError()
+                        destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
+                    # else:
+                    # raise LossSpikeError()
 
         else:
            log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
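These two commits repeat the same scan over logger.destinations; the intent each time is "find the MosaicML platform logger, if one is attached, and put a message in the run metadata". A small helper in that spirit — my extraction for illustration, the PR keeps the loop inline:

    from composer.loggers import Logger, MosaicMLLogger

    def log_run_metadata(logger: Logger, metadata: dict) -> None:
        """Attach metadata to the MosaicML platform run, when that logger is configured."""
        for destination in logger.destinations:
            if isinstance(destination, MosaicMLLogger):
                destination.log_metadata(metadata)

Note that Composer's MosaicMLLogger.log_metadata takes a single dict; the two-argument positional calls in the hunks above do not match that signature, which is what the later "Log metadata dict" commit (PATCH 24) corrects.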
From f24ef7e6d3a45578d151edd836b033462f770f4a Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 15:18:39 -0700
Subject: [PATCH 19/83] Move window slide to the end so that cur loss not
 factored into running loss

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 76bc2b29de..ca306f6da3 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -31,9 +31,6 @@ def batch_end(self, state: State, logger: Logger) -> None:
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
-        self.loss_window.append(train_loss)
-        if len(self.loss_window) > self.window_size:
-            self.loss_window.pop(0)
 
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
@@ -51,7 +48,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
                     if isinstance(destination, MosaicMLLogger):
                         destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
                     # else:
-                    # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                    raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
+
+        self.loss_window.append(train_loss)
+        if len(self.loss_window) > self.window_size:
+            self.loss_window.pop(0)

From 6779b3d07ff1e275fa5c9dec9648012929335a54 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 15:51:45 -0700
Subject: [PATCH 20/83] log/fail depending on wall clock

---
 .../callbacks/kill_loss_spike_callback.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index ca306f6da3..b9d5ac0155 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -26,7 +26,7 @@ def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=20
 
     def batch_end(self, state: State, logger: Logger) -> None:
 
-        # train_time = logger.get_metric('time/train')
+        train_time = state.timestamp.total_wct.total_seconds()
 
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -42,12 +42,13 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
-                    # if train_time > 0.1:
-                    # log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                    # TODO: Make this a full hour after testing
+                    if train_time > 360:
+                        log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
                             destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
-                    raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                    else:
+                        raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
                 log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
             # Half of the running losses are greater than our "high loss" threshold
             elif sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2:
-                # if train_time > 1:
-                # log.info(f'High losses >{self.loss_cap} detected.')
+                if train_time > 360:
+                    log.info(f'High losses >{self.loss_cap} detected.')
                 for destination in logger.destinations:
                     if isinstance(destination, MosaicMLLogger):
                         destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
-                # else:
-                # raise LossSpikeError()
+                else:
+                    raise LossSpikeError()
 
         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')

From a4960285ff05b5bd961c6fc928b390e7ff52555f Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 16:11:18 -0700
Subject: [PATCH 21/83] No logging

---
 .../callbacks/kill_loss_spike_callback.py | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index b9d5ac0155..ddc3149492 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=3, outlier_multiplier:int=2, window_size:int=20, loss_cap:int=10):
+    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -26,7 +26,7 @@ def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100
 
     def batch_end(self, state: State, logger: Logger) -> None:
 
-        train_time = state.timestamp.total_wct.total_seconds()
+        # train_time = state.timestamp.total_wct.total_seconds()
 
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -43,12 +43,12 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
                     # TODO: Make this a full hour after testing
-                    if train_time > 360:
-                        log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
-                        for destination in logger.destinations:
-                            if isinstance(destination, MosaicMLLogger):
-                                destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
-                    else:
+                    # if train_time > 3600:
+                    #     log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                    #     for destination in logger.destinations:
+                    #         if isinstance(destination, MosaicMLLogger):
+                    #             destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
+                    # else:
                     raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
                 log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
             # Half of the running losses are greater than our "high loss" threshold
             elif sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2:
-                if train_time > 360:
-                    log.info(f'High losses >{self.loss_cap} detected.')
-                    for destination in logger.destinations:
-                        if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
-                else:
+                # if train_time > 3600:
+                #     log.info(f'High losses >{self.loss_cap} detected.')
+                #     for destination in logger.destinations:
+                #         if isinstance(destination, MosaicMLLogger):
+                #             destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
+                # else:
                 raise LossSpikeError()

From ad88a435e9489066695e29926357e65d79ab82c6 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 17:20:31 -0700
Subject: [PATCH 22/83] Try to ensure checkpoint saved before killing run

---
 .../callbacks/kill_loss_spike_callback.py | 31 +++++++++----------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index ddc3149492..d2030eebb9 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -26,7 +26,7 @@ def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100
 
     def batch_end(self, state: State, logger: Logger) -> None:
 
-        # train_time = state.timestamp.total_wct.total_seconds()
+        train_time = state.timestamp.total_wct.total_seconds()
 
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -42,13 +42,12 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
-                    # TODO: Make this a full hour after testing
-                    # if train_time > 3600:
-                    #     log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
-                    #     for destination in logger.destinations:
-                    #         if isinstance(destination, MosaicMLLogger):
-                    #             destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
-                    # else:
+                    if train_time < 5000:
+                        log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                        for destination in logger.destinations:
+                            if isinstance(destination, MosaicMLLogger):
+                                destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
+                    else:
+                        raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
                 log.info(f'Not a persistent loss spike. Resetting outlier counter.')
                 self.outlier_counter = 0
 
-            # Half of the running losses are greater than our "high loss" threshold
-            elif sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2:
-                # if train_time > 3600:
-                #     log.info(f'High losses >{self.loss_cap} detected.')
-                #     for destination in logger.destinations:
-                #         if isinstance(destination, MosaicMLLogger):
-                #             destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
-                # else:
+            # Half of the running losses are greater than our "high loss" threshold, after the first window
+            elif state.timestamp.batch >= 200 and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
+                if train_time < 5000:
+                    log.info(f'High losses >{self.loss_cap} detected.')
+                    for destination in logger.destinations:
+                        if isinstance(destination, MosaicMLLogger):
+                            destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
+                else:
                     raise LossSpikeError()
 
         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')

From 8eb78242ad9e7059076bca5d2a2f12e8a1cca607 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 20 Aug 2024 17:39:19 -0700
Subject: [PATCH 23/83] test

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index d2030eebb9..2ace1cbe85 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
+    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=50, loss_cap:int=10):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size

From 20261bda426982428c584f4c1b5074f8e4f586b7 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 09:14:18 -0700
Subject: [PATCH 24/83] Log metadata dict

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 2ace1cbe85..a260e8080c 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -46,7 +46,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
                         log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
                         for destination in logger.destinations:
                             if isinstance(destination, MosaicMLLogger):
-                                destination.log_metadata('Loss Spike', f'Loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.')
+                                destination.log_metadata({'Loss Spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'})
                     else:
                         raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
@@ -61,7 +61,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
                     log.info(f'High losses >{self.loss_cap} detected.')
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata('High Loss', f'Persistently high losses >{self.loss_cap} detected. Try lowering the learning rate.')
+                            destination.log_metadata({'High Loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
                 else:
                     raise LossSpikeError()

From a76996feecd1e0449733ae5794a8b9de76dc5b09 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 09:52:09 -0700
Subject: [PATCH 25/83] testing

---
 .../callbacks/kill_loss_spike_callback.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index a260e8080c..288ffba18a 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):
 
-    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=50, loss_cap:int=10):
+    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=20, loss_cap:int=10):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -26,7 +26,7 @@ def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=20
 
     def batch_end(self, state: State, logger: Logger) -> None:
 
-        train_time = state.timestamp.total_wct.total_seconds()
+        # train_time = state.timestamp.total_wct.total_seconds()
 
         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -42,13 +42,13 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 self.outlier_counter += 1
                 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
                 if self.outlier_counter > self.patience:
-                    if train_time < 5000:
-                        log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                    # if train_time < 5000:
+                    log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata({'Loss Spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'})
-                    else:
-                        raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                            destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'})
+                    # else:
+                    # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
 
             # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
             elif self.outlier_counter > 0:
@@ -57,13 +57,13 @@ def batch_end(self, state: State, logger: Logger) -> None:
             # Half of the running losses are greater than our "high loss" threshold, after the first window
             elif state.timestamp.batch >= 200 and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
-                if train_time < 5000:
-                    log.info(f'High losses >{self.loss_cap} detected.')
+                # if train_time < 5000:
+                log.info(f'High losses >{self.loss_cap} detected.')
                 for destination in logger.destinations:
                     if isinstance(destination, MosaicMLLogger):
-                        destination.log_metadata({'High Loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
-                else:
-                    raise LossSpikeError()
+                        destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
+                # else:
+                # raise LossSpikeError()
 
         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
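From PATCH 17 through PATCH 25 the run-killing raise keeps toggling against metadata-only logging by commenting code in and out; that choice reads like it wants to be a parameter. One possible consolidation, sketched with a hypothetical log_only flag and a pre-built error object, neither of which is in the PR at this point:

    from composer.loggers import Logger, MosaicMLLogger

    def handle_loss_event(logger: Logger, key: str, message: str, log_only: bool, error: Exception) -> None:
        """Hypothetical helper: always surface the diagnosis, only kill the run when allowed."""
        for destination in logger.destinations:
            if isinstance(destination, MosaicMLLogger):
                destination.log_metadata({key: message})
        if not log_only:
            raise error

For example, handle_loss_event(logger, 'loss_spike', msg, log_only=train_time < 5000, error=LossSpikeError(...)) would mirror the wall-clock gating tried in PATCH 22 without the commented-out branches.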
Collecting loss data...') From e646b2a879207e84b9617fea52f57335d0e48268 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Wed, 21 Aug 2024 11:26:41 -0700 Subject: [PATCH 26/83] Log all potential spikes to run metadata --- .../callbacks/kill_loss_spike_callback.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 288ffba18a..d3021927f3 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -"""Monitor rate of change of loss.""" +"""Track training runs for loss spikes or persistently high training loss.""" from __future__ import annotations import torch @@ -9,14 +9,14 @@ import numpy as np from composer.core import Callback, State from composer.loggers import Logger, MosaicMLLogger -from llmfoundry.utils.exceptions import LossSpikeError +# from llmfoundry.utils.exceptions import LossSpikeError log = logging.getLogger(__name__) __all__ = ['KillLossSpike'] class KillLossSpike(Callback): - def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=20, loss_cap:int=10): + def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=5): self.patience = patience self.outlier_multiplier = outlier_multiplier self.window_size = window_size @@ -25,8 +25,6 @@ def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=20, self.loss_window = [] def batch_end(self, state: State, logger: Logger) -> None: - - # train_time = state.timestamp.total_wct.total_seconds() if not isinstance(state.loss, torch.Tensor): raise NotImplementedError('Multiple losses not supported yet') @@ -42,13 +40,11 @@ def batch_end(self, state: State, logger: Logger) -> None: self.outlier_counter += 1 log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}') if self.outlier_counter > self.patience: - # if train_time < 5000: - log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.') - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'}) - # else: - # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter) + log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.') + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'}) + # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter) # Previous step loss was an outlier, current step loss is not. Reset outlier counter. 
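[Editor's note — illustration, not part of the patch series.] These commits settle on a pattern worth calling out: fan run metadata out only to the MosaicML logger among a Composer Logger's destinations. A minimal sketch of that pattern, assuming only the Composer APIs already used in the diffs above:

    from composer.core import Callback, State
    from composer.loggers import Logger, MosaicMLLogger

    class MetadataNote(Callback):
        # Only MosaicMLLogger exposes log_metadata; other destinations
        # (console, file, etc.) are skipped by the isinstance check.
        def batch_end(self, state: State, logger: Logger) -> None:
            for destination in logger.destinations:
                if isinstance(destination, MosaicMLLogger):
                    destination.log_metadata({'note': f'batch {int(state.timestamp.batch)} complete'})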
From e646b2a879207e84b9617fea52f57335d0e48268 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 11:26:41 -0700
Subject: [PATCH 26/83] Log all potential spikes to run metadata

---
 .../callbacks/kill_loss_spike_callback.py     | 32 ++++++++-----------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 288ffba18a..d3021927f3 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-"""Monitor rate of change of loss."""
+"""Track training runs for loss spikes or persistently high training loss."""
 from __future__ import annotations

 import torch
 import logging
 import numpy as np
 from composer.core import Callback, State
 from composer.loggers import Logger, MosaicMLLogger
-from llmfoundry.utils.exceptions import LossSpikeError
+# from llmfoundry.utils.exceptions import LossSpikeError

 log = logging.getLogger(__name__)

 __all__ = ['KillLossSpike']

 class KillLossSpike(Callback):

-    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=20, loss_cap:int=10):
+    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=5):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
@@ -25,8 +25,6 @@ def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=5):
         self.loss_window = []

     def batch_end(self, state: State, logger: Logger) -> None:
-
-        # train_time = state.timestamp.total_wct.total_seconds()

         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -42,11 +40,11 @@ def batch_end(self, state: State, logger: Logger) -> None:
             self.outlier_counter += 1
             log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
             if self.outlier_counter > self.patience:
-                # if train_time < 5000:
-                log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
-                for destination in logger.destinations:
-                    if isinstance(destination, MosaicMLLogger):
-                        destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'})
-                # else:
-                #     raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                for destination in logger.destinations:
+                    if isinstance(destination, MosaicMLLogger):
+                        destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
+                # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)

         # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
         elif self.outlier_counter > 0:
@@ -57,13 +53,11 @@ def batch_end(self, state: State, logger: Logger) -> None:

         # Half of the running losses are greater than our "high loss" threshold, after the first window
         elif state.timestamp.batch >= 200 and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
-            # if train_time < 5000:
-            log.info(f'High losses >{self.loss_cap} detected.')
-            for destination in logger.destinations:
-                if isinstance(destination, MosaicMLLogger):
-                    destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
-            # else:
-            #     raise LossSpikeError()
+            log.info(f'High losses >{self.loss_cap} detected.')
+            for destination in logger.destinations:
+                if isinstance(destination, MosaicMLLogger):
+                    destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
+            # raise LossSpikeError()

         else:
             log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')

From ebd8ff53bf74f6039a827461966b145fc17af23a Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 11:57:25 -0700
Subject: [PATCH 27/83] Test logging

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index d3021927f3..383839e96a 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -30,6 +30,11 @@ def batch_end(self, state: State, logger: Logger) -> None:
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()

+        for destination in logger.destinations:
+            if isinstance(destination, MosaicMLLogger):
+                destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.',
+                                          'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
+
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
             running_loss_avg = np.mean(self.loss_window)

From 53f41172917714480dd44337e4df5339a60fa8ca Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 12:04:53 -0700
Subject: [PATCH 28/83] log to be sure

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 383839e96a..eec5012395 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -34,6 +34,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
             if isinstance(destination, MosaicMLLogger):
                 destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.',
                                           'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
+        log.info(f'Logging metadata for loss spike and high loss.')

         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:

From decf9c509474a203941125a21f796504acdee887 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 12:33:45 -0700
Subject: [PATCH 29/83] Final

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index eec5012395..d3021927f3 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -30,12 +30,6 @@ def batch_end(self, state: State, logger: Logger) -> None:
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()

-        for destination in logger.destinations:
-            if isinstance(destination, MosaicMLLogger):
-                destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.',
-                                          'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
-        log.info(f'Logging metadata for loss spike and high loss.')
-
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
             running_loss_avg = np.mean(self.loss_window)

From 9db4814faef77f0a4751b7ebdba06ad6084d6252 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 13:07:40 -0700
Subject: [PATCH 30/83] Explanatory comment

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index d3021927f3..5037f8d4bd 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -41,6 +41,8 @@ def batch_end(self, state: State, logger: Logger) -> None:
             log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
             if self.outlier_counter > self.patience:
                 log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                # NOTE: Adding this info to the TRAIN_UPDATED event is temporary to 1) collect data on spiky runs and 2) give users information about their run.
+                # This will be replaced with the hard error LossSpikeError.
                 for destination in logger.destinations:
                     if isinstance(destination, MosaicMLLogger):
                         destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'})

From 4b732fa7bc19a7ae3a145098d149c3035314f957 Mon Sep 17 00:00:00 2001
From: joyce-chen-uni
Date: Wed, 21 Aug 2024 14:09:03 -0700
Subject: [PATCH 31/83] Increase loss cap

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 5037f8d4bd..cccdf87132 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -16,7 +16,7 @@ class KillLossSpike(Callback):

-    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=5):
+    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
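[Editor's note — illustration, not part of the patch series.] One way to sanity-check the loss_cap default raised above (my gloss, not the author's stated rationale): for a language model, the cross-entropy of a uniform prediction over the vocabulary is ln(vocab_size), so a loss pinned above ~10 means the model is doing little better than chance on a typical 32k-token vocabulary.

    import math

    # Cross-entropy (nats) of a uniform distribution over the vocabulary;
    # a run stuck above this has learned essentially nothing.
    for vocab_size in (32_000, 50_257, 100_000):
        print(vocab_size, round(math.log(vocab_size), 2))
    # 32000 10.37, 50257 10.83, 100000 11.51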
From 88a5417d57e20ed61a7a6696059a3ec591c43f78 Mon Sep 17 00:00:00 2001
From: joyce-chen-uni
Date: Wed, 21 Aug 2024 14:22:22 -0700
Subject: [PATCH 32/83] Remove advisory

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index cccdf87132..987926cbcd 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -45,7 +45,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
                 # This will be replaced with the hard error LossSpikeError.
                 for destination in logger.destinations:
                     if isinstance(destination, MosaicMLLogger):
-                        destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Try lowering the learning rate.'})
+                        destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
                 # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
@@ -58,7 +58,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
             log.info(f'High losses >{self.loss_cap} detected.')
             for destination in logger.destinations:
                 if isinstance(destination, MosaicMLLogger):
-                    destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Try lowering the learning rate.'})
+                    destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
             # raise LossSpikeError()

From 45eaa0ee8c1f3aedcba87ae4e4c5d447a948eaee Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 14:32:46 -0700
Subject: [PATCH 33/83] Remove hardcoded const

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 987926cbcd..654e6ffc05 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -54,7 +54,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
             self.outlier_counter = 0

         # Half of the running losses are greater than our "high loss" threshold, after the first window
-        elif state.timestamp.batch >= 200 and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
+        elif (state.timestamp.batch >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
             log.info(f'High losses >{self.loss_cap} detected.')
             for destination in logger.destinations:
                 if isinstance(destination, MosaicMLLogger):

From e160e216925d9b365aada70fb0e94d062d6d8b81 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 15:45:44 -0700
Subject: [PATCH 34/83] Docstring, deque, log_only mode

---
 .../callbacks/kill_loss_spike_callback.py     | 59 +++++++++++++------
 1 file changed, 41 insertions(+), 18 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 654e6ffc05..0038c65439 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -7,22 +7,48 @@
 import torch
 import logging
 import numpy as np
+from collections import deque
 from composer.core import Callback, State
 from composer.loggers import Logger, MosaicMLLogger
-# from llmfoundry.utils.exceptions import LossSpikeError
+from llmfoundry.utils.exceptions import LossSpikeError

 log = logging.getLogger(__name__)

 __all__ = ['KillLossSpike']

 class KillLossSpike(Callback):
+    """
+    A callback for detecting and handling loss spikes or persistently high training losses during model training.
+
+    Monitors the training loss at the end of each batch and maintains a rolling window of recent losses.
+    If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either
+    log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry.
+
+    Parameters:
+        log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a
+                         LossSpikeError will be raised to stop training upon detecting a loss spike or persistently
+                         high loss.
+        patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be
+                        persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error).
+        outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an
+                                  outlier if it is outlier_multiplier times greater than the mean of losses in
+                                  the current window. Default is 2.
+        window_size (int): The size of the rolling window used to track recent losses. Default is 100.
+        loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
+                        it is considered a diverging or unstable run. Default is 10.
+
+    Raises:
+        LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
+                        raised to stop the run with an error message.
+    """

-    def __init__(self, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
+    def __init__(self, log_only:bool, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
+        self.log_only = log_only
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
         self.loss_cap = loss_cap
         self.outlier_counter = 0
-        self.loss_window = []
+        self.loss_window = deque(maxlen=self.window_size)

     def batch_end(self, state: State, logger: Logger) -> None:
@@ -41,12 +67,12 @@ def batch_end(self, state: State, logger: Logger) -> None:
             log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
             if self.outlier_counter > self.patience:
                 log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
-                # NOTE: Adding this info to the TRAIN_UPDATED event is temporary to 1) collect data on spiky runs and 2) give users information about their run.
-                # This will be replaced with the hard error LossSpikeError.
-                for destination in logger.destinations:
-                    if isinstance(destination, MosaicMLLogger):
-                        destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
-                # raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                if self.log_only:
+                    for destination in logger.destinations:
+                        if isinstance(destination, MosaicMLLogger):
+                            destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
+                else:
+                    raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)

         # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
         elif self.outlier_counter > 0:
@@ -56,14 +82,11 @@ def batch_end(self, state: State, logger: Logger) -> None:
         # Half of the running losses are greater than our "high loss" threshold, after the first window
         elif (state.timestamp.batch >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
             log.info(f'High losses >{self.loss_cap} detected.')
-            for destination in logger.destinations:
-                if isinstance(destination, MosaicMLLogger):
-                    destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
-            # raise LossSpikeError()
-
-        else:
-            log.info(f'Full loss window size not reached ({len(self.loss_window)} < {self.window_size}). Collecting loss data...')
+            if self.log_only:
+                for destination in logger.destinations:
+                    if isinstance(destination, MosaicMLLogger):
+                        destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
+            else:
+                raise LossSpikeError()

         self.loss_window.append(train_loss)
-        if len(self.loss_window) > self.window_size:
-            self.loss_window.pop(0)

From 39fdbfb40900ee37f9cb08e3577e1e667d8ab118 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 16:02:09 -0700
Subject: [PATCH 35/83] edit docstring

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 0038c65439..b93506dd43 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -17,13 +17,13 @@ class KillLossSpike(Callback):
     """
-    A callback for detecting and handling loss spikes or persistently high training losses during model training.
+    This callback detects and handles loss spikes or persistently high training losses during model training.

     Monitors the training loss at the end of each batch and maintains a rolling window of recent losses.
     If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either
     log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry.

-    Parameters:
+    Args:
         log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a
                          LossSpikeError will be raised to stop training upon detecting a loss spike or persistently
                          high loss.

From 8cb2f954065e13dff31452774acb20e304a3edcf Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 17:04:12 -0700
Subject: [PATCH 36/83] Specific different error for high losses

---
 llmfoundry/callbacks/kill_loss_spike_callback.py |  2 +-
 llmfoundry/utils/exceptions.py                   | 16 +++++++++++-----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index b93506dd43..8c7ef952a6 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -87,6 +87,6 @@ def batch_end(self, state: State, logger: Logger) -> None:
                     if isinstance(destination, MosaicMLLogger):
                         destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
             else:
-                raise LossSpikeError()
+                raise LossSpikeError(self.loss_cap, self.window_size)

         self.loss_window.append(train_loss)

diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 97287f8c7a..7535cd547b 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -391,8 +391,14 @@
 class LossSpikeError(UserError):
     """Error thrown a severe loss spike occurs."""

-    def __init__(self, outlier_multiplier: int, running_loss_avg: int, outlier_counter: int) -> None:
-        message = f'Training stopped due to a loss spike. The training loss was {outlier_multiplier} times greater than the \
-            running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
-            Please try submitting the run again with a lower learning rate.'
-        super().__init__(message, outlier_multiplier=outlier_multiplier, running_loss_avg=running_loss_avg, outlier_counter=outlier_counter)
+    def __init__(self, outlier_multiplier: Optional[int], running_loss_avg: Optional[int], outlier_counter: Optional[int], loss_cap: Optional[int], window_size: Optional[int]) -> None:
+        if outlier_multiplier and running_loss_avg and outlier_counter:
+            message = f'Training stopped due to a loss spike. The training loss was {outlier_multiplier} times greater than the \
+                running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
+                Please try submitting the run again with a lower learning rate.'
+        elif loss_cap and window_size:
+            message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \
+                for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.'
+        else:
+            message = 'Training stopped due to a loss spike or consistently high losses. Please try submitting the run again with a lower learning rate.'
+        super().__init__(message, outlier_multiplier=outlier_multiplier, running_loss_avg=running_loss_avg, outlier_counter=outlier_counter, loss_cap=loss_cap, window_size=window_size)
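[Editor's note — illustration, not part of the patch series.] The deque(maxlen=...) introduced above replaces the manual append/pop(0) bookkeeping: appending past maxlen silently evicts the oldest entry in O(1). A tiny self-contained demonstration:

    from collections import deque

    window = deque(maxlen=3)
    for loss in (2.0, 2.1, 1.9, 8.5):
        window.append(loss)  # appending past maxlen evicts the oldest entry
    print(window)            # deque([2.1, 1.9, 8.5], maxlen=3)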
From a539866090b29f05cbe1c7a02677da41fe84a1b7 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Wed, 21 Aug 2024 17:26:40 -0700
Subject: [PATCH 37/83] Decompose detection functions

---
 .../callbacks/kill_loss_spike_callback.py     | 50 +++++++++++--------
 1 file changed, 30 insertions(+), 20 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 8c7ef952a6..41a43f7d32 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -50,6 +50,28 @@ def __init__(self, log_only:bool, patience:int=4, outlier_multiplier:int=2, wind
         self.outlier_counter = 0
         self.loss_window = deque(maxlen=self.window_size)

+    def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
+        # Train loss is an outlier
+        if train_loss > running_loss_avg * self.outlier_multiplier:
+            self.outlier_counter += 1
+            log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
+            if self.outlier_counter > self.patience:
+                log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                return True
+        # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
+        elif self.outlier_counter > 0:
+            log.info(f'Not a persistent loss spike. Resetting outlier counter.')
+            self.outlier_counter = 0
+        return False
+
+    def detect_high_losses(self, current_step: int):
+        # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period
+        if (current_step >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
+            log.info(f'High losses (train loss consistently greater than {self.loss_cap}) detected.')
+            return True
+        return False
+
+
     def batch_end(self, state: State, logger: Logger) -> None:

         if not isinstance(state.loss, torch.Tensor):
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()
@@ -61,27 +83,15 @@ def batch_end(self, state: State, logger: Logger) -> None:
             running_loss_avg = np.mean(self.loss_window)
             log.info(f'Running loss average: {running_loss_avg}')

-            # If train loss is an outlier
-            if train_loss > running_loss_avg * self.outlier_multiplier:
-                self.outlier_counter += 1
-                log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
-                if self.outlier_counter > self.patience:
-                    log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
-                    if self.log_only:
-                        for destination in logger.destinations:
-                            if isinstance(destination, MosaicMLLogger):
-                                destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
-                    else:
-                        raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
-
-            # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
-            elif self.outlier_counter > 0:
-                log.info(f'Not a persistent loss spike. Resetting outlier counter.')
-                self.outlier_counter = 0
+            if self.detect_loss_spike(train_loss, running_loss_avg):
+                if self.log_only:
+                    for destination in logger.destinations:
+                        if isinstance(destination, MosaicMLLogger):
+                            destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
+                else:
+                    raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)

-            # Half of the running losses are greater than our "high loss" threshold, after the first window
-            elif (state.timestamp.batch >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
-                log.info(f'High losses >{self.loss_cap} detected.')
+            elif self.detect_high_losses(state.timestamp.batch):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
                             destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
                 else:
                     raise LossSpikeError(self.loss_cap, self.window_size)

         self.loss_window.append(train_loss)

From 5da701c230e128bb1a882e5a00f5cd3bef41587e Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 09:32:51 -0700
Subject: [PATCH 38/83] First pass unit tests

---
 .../test_kill_loss_spike_callback.py          | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 tests/callbacks/test_kill_loss_spike_callback.py

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
new file mode 100644
index 0000000000..c7001ff14f
--- /dev/null
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -0,0 +1,43 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+import unittest
+from unittest.mock import patch
+from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
+from collections import deque
+
+class TestKillLossSpike(unittest.TestCase):
+    def setUp(self):
+        self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10) # type: ignore
+        self.callback.loss_window = deque([2] * 10, maxlen=10)
+
+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_detect_loss_spike_no_spike(self):
+        self.callback.outlier_counter = 0
+        train_loss = 4
+        running_loss_avg = 2
+        result = self.callback.detect_loss_spike(train_loss, running_loss_avg)
+        self.assertFalse(result)
+
+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_detect_loss_spike_with_spike(self):
+        self.callback.outlier_counter = 4  # Simulating previous spikes
+        train_loss = 4
+        running_loss_avg = 2
+        result = self.callback.detect_loss_spike(train_loss, running_loss_avg)
+        self.assertTrue(result)
+
+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_detect_high_losses_no_high_losses(self):
+        current_step = 21
+        result = self.callback.detect_high_losses(current_step)
+        self.assertFalse(result)
+
+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_detect_high_losses_with_high_losses(self):
+        self.callback.loss_window = deque([11] * 10, maxlen=10)  # Simulate high losses in loss window
+        current_step = 21
+        result = self.callback.detect_high_losses(current_step)
+        self.assertTrue(result)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file

From 7fe9c070d3edd84b8673864d36b4a24020e94361 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 10:14:58 -0700
Subject: [PATCH 39/83] Working unit tests

---
 tests/callbacks/test_kill_loss_spike_callback.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index c7001ff14f..f817d31be0 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -8,10 +8,9 @@ class TestKillLossSpike(unittest.TestCase):
     def setUp(self):
         self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10) # type: ignore
-        self.callback.loss_window = deque([2] * 10, maxlen=10)

     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
-    def test_detect_loss_spike_no_spike(self):
+    def test_detect_loss_spike_no_spike(self, _):
         self.callback.outlier_counter = 0
         train_loss = 4
         running_loss_avg = 2
@@ -19,21 +18,22 @@ class TestKillLossSpike(unittest.TestCase):
         self.assertFalse(result)

     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
-    def test_detect_loss_spike_with_spike(self):
+    def test_detect_loss_spike_with_spike(self, _):
         self.callback.outlier_counter = 4  # Simulating previous spikes
-        train_loss = 4
+        train_loss = 5
         running_loss_avg = 2
         result = self.callback.detect_loss_spike(train_loss, running_loss_avg)
         self.assertTrue(result)

     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
-    def test_detect_high_losses_no_high_losses(self):
+    def test_detect_high_losses_no_high_losses(self, _):
+        self.callback.loss_window = deque([2] * 10, maxlen=10)
         current_step = 21
         result = self.callback.detect_high_losses(current_step)
         self.assertFalse(result)

     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
-    def test_detect_high_losses_with_high_losses(self):
+    def test_detect_high_losses_with_high_losses(self, _):
         self.callback.loss_window = deque([11] * 10, maxlen=10)  # Simulate high losses in loss window
         current_step = 21
         result = self.callback.detect_high_losses(current_step)

From 3721ed07bad40140ccb9ce34d6c69805e3348f56 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 10:25:48 -0700
Subject: [PATCH 40/83] Nits

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 3 +--
 tests/callbacks/test_kill_loss_spike_callback.py | 6 +++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 41a43f7d32..b6b2ed7d64 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -52,7 +52,7 @@ class KillLossSpike(Callback):
     def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
         # Train loss is an outlier
-        if train_loss > running_loss_avg * self.outlier_multiplier:
+        if train_loss >= running_loss_avg * self.outlier_multiplier:
             self.outlier_counter += 1
             log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
             if self.outlier_counter > self.patience:
@@ -71,7 +71,6 @@ def detect_high_losses(self, current_step: int):
             return True
         return False

-
     def batch_end(self, state: State, logger: Logger) -> None:

         if not isinstance(state.loss, torch.Tensor):

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index f817d31be0..ca2f409f00 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -20,7 +20,7 @@ class TestKillLossSpike(unittest.TestCase):
     def test_detect_loss_spike_with_spike(self, _):
         self.callback.outlier_counter = 4  # Simulating previous spikes
-        train_loss = 5
+        train_loss = 4
         running_loss_avg = 2
         result = self.callback.detect_loss_spike(train_loss, running_loss_avg)
         self.assertTrue(result)
@@ -34,10 +34,10 @@ class TestKillLossSpike(unittest.TestCase):
     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
     def test_detect_high_losses_with_high_losses(self, _):
-        self.callback.loss_window = deque([11] * 10, maxlen=10)  # Simulate high losses in loss window
+        self.callback.loss_window = deque([9, 8, 7, 6, 5, 11, 12, 13, 14, 15], maxlen=10)  # Simulate mix of losses in loss window
         current_step = 21
         result = self.callback.detect_high_losses(current_step)
         self.assertTrue(result)

 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()

From e062db0d80bf6064720fc17667f711a7ad7e2c9b Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 11:10:45 -0700
Subject: [PATCH 41/83] Type issues

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index b6b2ed7d64..941104e461 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -79,10 +79,10 @@ def batch_end(self, state: State, logger: Logger) -> None:
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
-            running_loss_avg = np.mean(self.loss_window)
+            running_loss_avg = float(np.mean(self.loss_window))
             log.info(f'Running loss average: {running_loss_avg}')

-            if self.detect_loss_spike(train_loss, running_loss_avg):
+            if self.detect_loss_spike(self, train_loss, running_loss_avg):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
                             destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
                 else:
                     raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)

-            elif self.detect_high_losses(state.timestamp.batch):
+            elif self.detect_high_losses(self, int(state.timestamp.batch)):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):

From 8ebec3f77c74e19b8b61cc65055f93835fcce127 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 11:17:00 -0700
Subject: [PATCH 42/83] Init callback

---
 tests/callbacks/test_kill_loss_spike_callback.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index ca2f409f00..bacff00bf1 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -6,8 +6,12 @@ from collections import deque

 class TestKillLossSpike(unittest.TestCase):
+    def __init__(self, *args:tuple, **kwargs:dict):
+        super(TestKillLossSpike, self).__init__(*args, **kwargs)
+        self.callback = None
+
     def setUp(self):
-        self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10) # type: ignore
+        self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10)

From df874b56ded34219f63dac8aae310b79051a7b8c Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 11:33:39 -0700
Subject: [PATCH 43/83] init

---
 tests/callbacks/test_kill_loss_spike_callback.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index bacff00bf1..482b21ce5d 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -8,9 +8,6 @@ class TestKillLossSpike(unittest.TestCase):
     def __init__(self, *args:tuple, **kwargs:dict):
         super(TestKillLossSpike, self).__init__(*args, **kwargs)
-        self.callback = None
-
-    def setUp(self):
         self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10)

From 30fceb03a51e692da010f0a01ffd0db377a97de7 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 11:39:12 -0700
Subject: [PATCH 44/83] Missing args

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 8 ++++----
 llmfoundry/utils/exceptions.py                   | 2 +-
 tests/callbacks/test_kill_loss_spike_callback.py | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 941104e461..988696d8fc 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -82,20 +82,20 @@ def batch_end(self, state: State, logger: Logger) -> None:
             running_loss_avg = float(np.mean(self.loss_window))
             log.info(f'Running loss average: {running_loss_avg}')

-            if self.detect_loss_spike(self, train_loss, running_loss_avg):
+            if self.detect_loss_spike(train_loss, running_loss_avg):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
                             destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
                 else:
-                    raise LossSpikeError(self.outlier_multiplier, round(running_loss_avg), self.outlier_counter)
+                    raise LossSpikeError(outlier_multiplier=self.outlier_multiplier, running_loss_avg=round(running_loss_avg), outlier_counter=self.outlier_counter)

-            elif self.detect_high_losses(self, int(state.timestamp.batch)):
+            elif self.detect_high_losses(int(state.timestamp.batch)):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
                             destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
                 else:
-                    raise LossSpikeError(self.loss_cap, self.window_size)
+                    raise LossSpikeError(loss_cap=self.loss_cap, window_size=self.window_size)

         self.loss_window.append(train_loss)

diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 7535cd547b..a300ecf506 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -391,7 +391,7 @@
 class LossSpikeError(UserError):
     """Error thrown a severe loss spike occurs."""

-    def __init__(self, outlier_multiplier: Optional[int], running_loss_avg: Optional[int], outlier_counter: Optional[int], loss_cap: Optional[int], window_size: Optional[int]) -> None:
+    def __init__(self, outlier_multiplier: Optional[int] = None, running_loss_avg: Optional[int] = None, outlier_counter: Optional[int] = None, loss_cap: Optional[int] = None, window_size: Optional[int] = None) -> None:
         if outlier_multiplier and running_loss_avg and outlier_counter:

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index 482b21ce5d..a6f24aa146 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -6,7 +6,7 @@ from collections import deque

 class TestKillLossSpike(unittest.TestCase):
-    def __init__(self, *args:tuple, **kwargs:dict):
+    def __init__(self, *args:str, **kwargs:dict):
         super(TestKillLossSpike, self).__init__(*args, **kwargs)
         self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10)
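[Editor's note — illustration, not part of the patch series.] The tests above rely on unittest.mock.patch to silence the module-level logger: the decorator swaps out the named attribute for the test's duration and injects the resulting MagicMock as an extra positional argument, which these tests deliberately name `_`. A minimal sketch of the mechanism (`mymodule` is a hypothetical module, not part of the PR):

    import unittest
    from unittest.mock import patch

    import mymodule  # hypothetical module exposing a module-level `log`

    class Example(unittest.TestCase):
        # @patch replaces mymodule.log for this test only and passes the
        # MagicMock in as the last positional argument.
        @patch('mymodule.log')
        def test_logging_is_silenced(self, mock_log):
            mymodule.do_work()            # hypothetical function that logs
            mock_log.info.assert_called()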
From 0b6483ac9e997f3751f7aed54d8b69819e46c506 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 12:06:08 -0700
Subject: [PATCH 45/83] Wording change

---
 llmfoundry/utils/exceptions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index a300ecf506..032045119e 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -393,8 +393,8 @@ class LossSpikeError(UserError):
     def __init__(self, outlier_multiplier: Optional[int] = None, running_loss_avg: Optional[int] = None, outlier_counter: Optional[int] = None, loss_cap: Optional[int] = None, window_size: Optional[int] = None) -> None:
         if outlier_multiplier and running_loss_avg and outlier_counter:
-            message = f'Training stopped due to a loss spike. The training loss was {outlier_multiplier} times greater than the \
-                running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
+            message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \
+                the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
                 Please try submitting the run again with a lower learning rate.'
         elif loss_cap and window_size:
             message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \

From f8fe2d8ac44c0d1fb001eec93fe88f87272c71b3 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 12:59:26 -0700
Subject: [PATCH 46/83] lint

---
 .../callbacks/kill_loss_spike_callback.py     | 64 ++++++++++++++-----
 llmfoundry/utils/exceptions.py                | 21 +++++-
 .../test_kill_loss_spike_callback.py          | 21 ++++--
 3 files changed, 84 insertions(+), 22 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 988696d8fc..f0f6c63ff9 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -4,20 +4,23 @@
 """Track training runs for loss spikes or persistently high training loss."""
 from __future__ import annotations

-import torch
 import logging
-import numpy as np
 from collections import deque
+
+import numpy as np
+import torch
 from composer.core import Callback, State
 from composer.loggers import Logger, MosaicMLLogger
+
 from llmfoundry.utils.exceptions import LossSpikeError
+
 log = logging.getLogger(__name__)

 __all__ = ['KillLossSpike']

+
 class KillLossSpike(Callback):
-    """
-    This callback detects and handles loss spikes or persistently high training losses during model training.
+    """Detects and handles loss spikes or persistently high training losses during model training.

     Monitors the training loss at the end of each batch and maintains a rolling window of recent losses.
     If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either
     log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry.

     Args:
         log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a
                          LossSpikeError will be raised to stop training upon detecting a loss spike or persistently
                          high loss.
         patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be
                         persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error).
         outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an
                                   outlier if it is outlier_multiplier times greater than the mean of losses in
                                   the current window. Default is 2.
         window_size (int): The size of the rolling window used to track recent losses. Default is 100.
         loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
                         it is considered a diverging or unstable run. Default is 10.

     Raises:
         LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
                         raised to stop the run with an error message.
     """

-    def __init__(self, log_only:bool, patience:int=4, outlier_multiplier:int=2, window_size:int=100, loss_cap:int=10):
+    def __init__(
+        self,
+        log_only: bool,
+        patience: int = 4,
+        outlier_multiplier: int = 2,
+        window_size: int = 100,
+        loss_cap: int = 10,
+    ):
         self.log_only = log_only
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.window_size = window_size
         self.loss_cap = loss_cap
         self.outlier_counter = 0
         self.loss_window = deque(maxlen=self.window_size)

     def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
         # Train loss is an outlier
         if train_loss >= running_loss_avg * self.outlier_multiplier:
             self.outlier_counter += 1
-            log.info(f'Potential loss spike detected. Iteration: {self.outlier_counter}')
+            log.info(
+                f'Potential loss spike detected. Iteration: {self.outlier_counter}',
+            )
             if self.outlier_counter > self.patience:
-                log.info(f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.')
+                log.info(
+                    f'Loss spike detected for {self.outlier_counter} steps. Try lowering the learning rate.',
+                )
                 return True
         # Previous step loss was an outlier, current step loss is not. Reset outlier counter.
         elif self.outlier_counter > 0:
             log.info(f'Not a persistent loss spike. Resetting outlier counter.')
             self.outlier_counter = 0
         return False
-
+
     def detect_high_losses(self, current_step: int):
         # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period
-        if (current_step >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2):
-            log.info(f'High losses (train loss consistently greater than {self.loss_cap}) detected.')
+        if (current_step >= self.window_size * 2) and (
+            sum(1 for loss in self.loss_window if loss > self.loss_cap) >=
+            self.window_size / 2
+        ):
+            log.info(
+                f'High losses (train loss consistently greater than {self.loss_cap}) detected.',
+            )
             return True
         return False

@@ -102,16 +105,29 @@ def batch_end(self, state: State, logger: Logger) -> None:
             if self.detect_loss_spike(train_loss, running_loss_avg):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata({'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps.'})
+                            destination.log_metadata({
+                                'loss_spike':
+                                    f'Training loss spike detected for {self.outlier_counter} consecutive steps.',
+                            })
                 else:
-                    raise LossSpikeError(outlier_multiplier=self.outlier_multiplier, running_loss_avg=round(running_loss_avg), outlier_counter=self.outlier_counter)
+                    raise LossSpikeError(
+                        outlier_multiplier=self.outlier_multiplier,
+                        running_loss_avg=round(running_loss_avg),
+                        outlier_counter=self.outlier_counter,
+                    )

             elif self.detect_high_losses(int(state.timestamp.batch)):
                 if self.log_only:
                     for destination in logger.destinations:
                         if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata({'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected.'})
+                            destination.log_metadata({
+                                'high_loss':
+                                    f'Persistently high (>{self.loss_cap}) training losses detected.',
+                            })
                 else:
-                    raise LossSpikeError(loss_cap=self.loss_cap, window_size=self.window_size)
+                    raise LossSpikeError(
+                        loss_cap=self.loss_cap,
+                        window_size=self.window_size,
+                    )

         self.loss_window.append(train_loss)

diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 032045119e..4f061d6b34 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -388,17 +388,34 @@ def __init__(self, timeout: int) -> None:
         message = f'Run timed out after {timeout} seconds.'
         super().__init__(message, timeout=timeout)

+
 class LossSpikeError(UserError):
     """Error thrown a severe loss spike occurs."""

-    def __init__(self, outlier_multiplier: Optional[int] = None, running_loss_avg: Optional[int] = None, outlier_counter: Optional[int] = None, loss_cap: Optional[int] = None, window_size: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        outlier_multiplier: Optional[int] = None,
+        running_loss_avg: Optional[int] = None,
+        outlier_counter: Optional[int] = None,
+        loss_cap: Optional[int] = None,
+        window_size: Optional[int] = None,
+    ) -> None:
         if outlier_multiplier and running_loss_avg and outlier_counter:
             message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \
                 the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
                 Please try submitting the run again with a lower learning rate.'
+
         elif loss_cap and window_size:
             message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \
                 for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.'
+
         else:
             message = 'Training stopped due to a loss spike or consistently high losses. Please try submitting the run again with a lower learning rate.'
-        super().__init__(message, outlier_multiplier=outlier_multiplier, running_loss_avg=running_loss_avg, outlier_counter=outlier_counter, loss_cap=loss_cap, window_size=window_size)
+        super().__init__(
+            message,
+            outlier_multiplier=outlier_multiplier,
+            running_loss_avg=running_loss_avg,
+            outlier_counter=outlier_counter,
+            loss_cap=loss_cap,
+            window_size=window_size,
+        )

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index a6f24aa146..34164f73b5 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -1,14 +1,23 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 import unittest
+from collections import deque
 from unittest.mock import patch
+
 from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
-from collections import deque
+

 class TestKillLossSpike(unittest.TestCase):
-    def __init__(self, *args:str, **kwargs:dict):
+
+    def __init__(self, *args: str, **kwargs: dict):
         super(TestKillLossSpike, self).__init__(*args, **kwargs)
-        self.callback = KillLossSpike(log_only=True, patience=4, outlier_multiplier=2, window_size=10, loss_cap=10)
+        self.callback = KillLossSpike(
+            log_only=True,
+            patience=4,
+            outlier_multiplier=2,
+            window_size=10,
+            loss_cap=10,
+        )

     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
     def test_detect_loss_spike_no_spike(self, _):
@@ -35,10 +44,14 @@ def test_detect_high_losses_no_high_losses(self, _):

     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
     def test_detect_high_losses_with_high_losses(self, _):
-        self.callback.loss_window = deque([9, 8, 7, 6, 5, 11, 12, 13, 14, 15], maxlen=10)  # Simulate mix of losses in loss window
+        self.callback.loss_window = deque(
+            [9, 8, 7, 6, 5, 11, 12, 13, 14, 15],
+            maxlen=10,
+        )  # Simulate mix of losses in loss window
         current_step = 21
         result = self.callback.detect_high_losses(current_step)
         self.assertTrue(result)
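[Editor's note — illustration, not part of the patch series.] After this cleanup, every field of LossSpikeError is optional and the call sites pass keyword arguments, so the two failure modes stay unambiguous at the point of raising. A small usage sketch, assuming only the exception as defined in the diff above:

    from llmfoundry.utils.exceptions import LossSpikeError

    # Spike path: the three spike-related fields select the spike message.
    try:
        raise LossSpikeError(outlier_multiplier=2, running_loss_avg=3, outlier_counter=5)
    except LossSpikeError as err:
        print(err)  # 'Training stopped due to a loss spike. ...'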
From b08b1dce3bf61185ab00e9ede4907908a91feb86 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 13:01:39 -0700
Subject: [PATCH 47/83] docstring

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index f0f6c63ff9..5c943d9ad9 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -20,7 +20,7 @@
 class KillLossSpike(Callback):
-    """Detects and handles loss spikes or persistently high training losses during model training.
+    """Detects and handles loss spikes or high losses during training.

     Monitors the training loss at the end of each batch and maintains a rolling window of recent losses.
     If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either

From b606ca85c1590f78caef437d67b312e3f0562225 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 13:06:32 -0700
Subject: [PATCH 48/83] Default log_only to true

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 5c943d9ad9..db1125293d 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -29,7 +29,7 @@ class KillLossSpike(Callback):
     Args:
         log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a
                          LossSpikeError will be raised to stop training upon detecting a loss spike or persistently
-                         high loss.
+                         high loss. Default is True.
         patience (int): The number of consecutive outlier losses tolerated before considering the training loss to be
                         persistently high. Default is 4 (so 5 consecutive outlier losses will trigger an error).
@@ -46,7 +46,7 @@ class KillLossSpike(Callback):
     def __init__(
         self,
-        log_only: bool,
+        log_only: bool = True,
         patience: int = 4,
         outlier_multiplier: int = 2,
         window_size: int = 100,

From 378beeca2e1bd685939b84b7662b8d48d06ec083 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 13:12:00 -0700
Subject: [PATCH 49/83] Address comments

---
 llmfoundry/callbacks/__init__.py                 |  2 +-
 .../callbacks/kill_loss_spike_callback.py        | 30 +++++++++----------
 .../test_kill_loss_spike_callback.py             |  4 ---
 3 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 4818617704..fe84efa316 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -22,6 +22,7 @@
 from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging
 from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
 from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
+from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
 from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import (
     MegaBlocksMoE_TokPerExpert,
 )
@@ -37,7 +38,6 @@
 from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback
 from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
 from llmfoundry.registry import callbacks, callbacks_with_config
-from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

 callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor)
 callbacks.register('lr_monitor', func=LRMonitor)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index db1125293d..8874129c54 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -102,14 +102,13 @@ def batch_end(self, state: State, logger: Logger) -> None:
             if self.detect_loss_spike(train_loss, running_loss_avg):
-                if self.log_only:
-                    for destination in logger.destinations:
-                        if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata({
-                                'loss_spike':
-                                    f'Training loss spike detected for {self.outlier_counter} consecutive steps.',
-                            })
-                else:
+                for destination in logger.destinations:
+                    if isinstance(destination, MosaicMLLogger):
+                        destination.log_metadata({
+                            'loss_spike':
+                                f'Training loss spike detected for {self.outlier_counter} consecutive steps.',
+                        })
+                if not self.log_only:
                     raise LossSpikeError(
                         outlier_multiplier=self.outlier_multiplier,
                         running_loss_avg=round(running_loss_avg),
                         outlier_counter=self.outlier_counter,
                     )

             elif self.detect_high_losses(int(state.timestamp.batch)):
-                if self.log_only:
-                    for destination in logger.destinations:
-                        if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata({
-                                'high_loss':
-                                    f'Persistently high (>{self.loss_cap}) training losses detected.',
-                            })
-                else:
+                for destination in logger.destinations:
+                    if isinstance(destination, MosaicMLLogger):
+                        destination.log_metadata({
+                            'high_loss':
+                                f'Persistently high (>{self.loss_cap}) training losses detected.',
+                        })
+                if not self.log_only:
                     raise LossSpikeError(
                         loss_cap=self.loss_cap,
                         window_size=self.window_size,
                     )

         self.loss_window.append(train_loss)

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index 34164f73b5..d492dc4c8f 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -51,7 +51,3 @@ def test_detect_high_losses_with_high_losses(self, _):
         current_step = 21
         result = self.callback.detect_high_losses(current_step)
         self.assertTrue(result)
-
-
-if __name__ == '__main__':
-    unittest.main()

From fd08f80430bcb480024f34db804608b2b98bfdcf Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Thu, 22 Aug 2024 14:19:47 -0700
Subject: [PATCH 50/83] try fix type issue

---
 tests/utils/test_exceptions.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py
index 097bdf77fb..54b6d35bc6 100644
--- a/tests/utils/test_exceptions.py
+++ b/tests/utils/test_exceptions.py
@@ -45,6 +45,8 @@ def get_default_value(arg_type: Optional[type] = None):
             return bool
         elif arg_type == list[dict[str, Any]]:
             return [{'key': 'value'}]
+        elif arg_type == Optional[int]:
+            return 1
         raise ValueError(f'Unsupported arg type: {arg_type}')

     required_args.pop('self', None)
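[Editor's note — illustration, not part of the patch series.] With log_only defaulting to True, the callback only annotates the run; setting log_only=False turns both detectors into hard failures. A sketch of wiring it into a Composer Trainer; `model` and `train_dataloader` are placeholders assumed to be defined elsewhere, and only standard Composer/llm-foundry entry points are used:

    from composer import Trainer
    from llmfoundry.callbacks import KillLossSpike

    # log_only=False makes a detected spike or persistently high loss
    # raise and kill the run instead of just logging run metadata.
    callback = KillLossSpike(log_only=False, patience=4, outlier_multiplier=2,
                             window_size=100, loss_cap=10)
    trainer = Trainer(
        model=model,                      # placeholder ComposerModel
        train_dataloader=train_dataloader,  # placeholder dataloader
        callbacks=[callback],
        max_duration='1ep',
    )
    trainer.fit()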
self.log_only: - raise LossSpikeError( + raise HighLossError( loss_cap=self.loss_cap, window_size=self.window_size, ) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 4f061d6b34..a0e4d2c751 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -394,28 +394,41 @@ class LossSpikeError(UserError): def __init__( self, - outlier_multiplier: Optional[int] = None, - running_loss_avg: Optional[int] = None, - outlier_counter: Optional[int] = None, - loss_cap: Optional[int] = None, - window_size: Optional[int] = None, + outlier_multiplier: float, + running_loss_avg: int, + outlier_counter: int, ) -> None: if outlier_multiplier and running_loss_avg and outlier_counter: message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \ the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \ Please try submitting the run again with a lower learning rate.' - elif loss_cap and window_size: - message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \ - for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.' - else: - message = 'Training stopped due to a loss spike or consistently high losses. Please try submitting the run again with a lower learning rate.' + message = 'Training stopped due to a loss spike. Please try submitting the run again with a lower learning rate.' super().__init__( message, outlier_multiplier=outlier_multiplier, running_loss_avg=running_loss_avg, outlier_counter=outlier_counter, + ) + + +class HighLossError(UserError): + """Error thrown training loss plateaus or is unstable at a high level.""" + + def __init__( + self, + loss_cap: float, + window_size: int, + ) -> None: + if loss_cap and window_size: + message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \ + for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.' + + else: + message = 'Training stopped due to consistently high losses. Please try submitting the run again with a lower learning rate.' 
+ super().__init__( + message, loss_cap=loss_cap, window_size=window_size, ) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 54b6d35bc6..097bdf77fb 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -45,8 +45,6 @@ def get_default_value(arg_type: Optional[type] = None): return bool elif arg_type == list[dict[str, Any]]: return [{'key': 'value'}] - elif arg_type == Optional[int]: - return 1 raise ValueError(f'Unsupported arg type: {arg_type}') required_args.pop('self', None) From 89250e1d40d33b4756d6e83722fb577fd691dc6f Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Thu, 22 Aug 2024 15:15:20 -0700 Subject: [PATCH 52/83] Handle float arg --- tests/utils/test_exceptions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 097bdf77fb..75c50511dd 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -35,6 +35,8 @@ def get_default_value(arg_type: Optional[type] = None): return 'string' elif arg_type == int: return 1 + elif arg_type == float: + return 1.0 elif arg_type == set[str]: return {'set'} elif arg_type == list[str]: From 1154deb4515034eb3ee9dfb5bff16d783d3f0aab Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Thu, 22 Aug 2024 16:31:30 -0700 Subject: [PATCH 53/83] Fixes exceptions, add exp class --- .../callbacks/kill_loss_spike_callback.py | 2 ++ llmfoundry/utils/exceptions.py | 18 ++++++------------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index ec40d53d37..222dd85914 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -13,12 +13,14 @@ from composer.loggers import Logger, MosaicMLLogger from llmfoundry.utils.exceptions import HighLossError, LossSpikeError +from llmfoundry.utils.warnings import experimental_class log = logging.getLogger(__name__) __all__ = ['KillLossSpike'] +@experimental_class('KillLossSpike') class KillLossSpike(Callback): """Detects and handles loss spikes or high losses during training. diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index a0e4d2c751..73951ef19e 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -390,7 +390,7 @@ def __init__(self, timeout: int) -> None: class LossSpikeError(UserError): - """Error thrown a severe loss spike occurs.""" + """Error thrown if a severe loss spike occurs.""" def __init__( self, @@ -398,13 +398,10 @@ def __init__( running_loss_avg: int, outlier_counter: int, ) -> None: - if outlier_multiplier and running_loss_avg and outlier_counter: - message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \ - the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \ - Please try submitting the run again with a lower learning rate.' + message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \ + the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \ + Please try submitting the run again with a lower learning rate.' - else: - message = 'Training stopped due to a loss spike. Please try submitting the run again with a lower learning rate.' 
super().__init__( message, outlier_multiplier=outlier_multiplier, @@ -414,19 +411,16 @@ def __init__( class HighLossError(UserError): - """Error thrown training loss plateaus or is unstable at a high level.""" + """Error thrown if training loss plateaus or is unstable at a high level.""" def __init__( self, loss_cap: float, window_size: int, ) -> None: - if loss_cap and window_size: - message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \ + message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \ for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.' - else: - message = 'Training stopped due to consistently high losses. Please try submitting the run again with a lower learning rate.' super().__init__( message, loss_cap=loss_cap, From 03826b5875c6e5ff028e1ca055e6ee15f02ee56c Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 11:07:54 -0700 Subject: [PATCH 54/83] Add suggestion back to run event message --- llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 222dd85914..adfe38ab7f 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -108,7 +108,7 @@ def batch_end(self, state: State, logger: Logger) -> None: if isinstance(destination, MosaicMLLogger): destination.log_metadata({ 'loss_spike': - f'Training loss spike detected for {self.outlier_counter} consecutive steps.', + f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.', }) if not self.log_only: raise LossSpikeError( @@ -122,7 +122,7 @@ def batch_end(self, state: State, logger: Logger) -> None: if isinstance(destination, MosaicMLLogger): destination.log_metadata({ 'high_loss': - f'Persistently high (>{self.loss_cap}) training losses detected.', + f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.', }) if not self.log_only: raise HighLossError( From 753793e7c53994e43e2937c23fb8062ca93e2a9a Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 11:09:48 -0700 Subject: [PATCH 55/83] Log the loss window when threshold is met --- llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index adfe38ab7f..6fa93b3e69 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -109,6 +109,8 @@ def batch_end(self, state: State, logger: Logger) -> None: destination.log_metadata({ 'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.', + "loss_window": + list(self.loss_window), }) if not self.log_only: raise LossSpikeError( @@ -123,6 +125,8 @@ def batch_end(self, state: State, logger: Logger) -> None: destination.log_metadata({ 'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. 
Consider stopping this run and resubmitting with a lower learning rate.', + "loss_window": + list(self.loss_window), }) if not self.log_only: raise HighLossError( From b8f2634ef7016ab08c30ab9e35b19d1f6131d552 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 11:24:58 -0700 Subject: [PATCH 56/83] Lint --- llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 6fa93b3e69..33a1c43913 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -109,7 +109,7 @@ def batch_end(self, state: State, logger: Logger) -> None: destination.log_metadata({ 'loss_spike': f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.', - "loss_window": + 'loss_window': list(self.loss_window), }) if not self.log_only: @@ -125,7 +125,7 @@ def batch_end(self, state: State, logger: Logger) -> None: destination.log_metadata({ 'high_loss': f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.', - "loss_window": + 'loss_window': list(self.loss_window), }) if not self.log_only: From 74d2cb4c6285b4ff293704aac9e4d0f56fd1dd29 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 15:01:20 -0700 Subject: [PATCH 57/83] Set loss cap to max of first loss window, + only do callback on rank 0 GPU --- llmfoundry/callbacks/kill_loss_spike_callback.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 33a1c43913..67b1816488 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -11,6 +11,7 @@ import torch from composer.core import Callback, State from composer.loggers import Logger, MosaicMLLogger +from composer.utils import dist from llmfoundry.utils.exceptions import HighLossError, LossSpikeError from llmfoundry.utils.warnings import experimental_class @@ -52,13 +53,13 @@ def __init__( patience: int = 4, outlier_multiplier: float = 2, window_size: int = 100, - loss_cap: float = 10, ): + self._enabled = (dist.get_global_rank() == 0) self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier self.window_size = window_size - self.loss_cap = loss_cap + self.loss_cap = None self.outlier_counter = 0 self.loss_window = deque(maxlen=self.window_size) @@ -100,6 +101,12 @@ def batch_end(self, state: State, logger: Logger) -> None: # Only start early stopping once a full window of loss data if len(self.loss_window) == self.window_size: + + current_step = int(state.timestamp.batch) + # Set the loss cap to the maximum loss from the first loss window + if current_step == self.window_size: + self.loss_cap = max(self.loss_window) + running_loss_avg = float(np.mean(self.loss_window)) log.info(f'Running loss average: {running_loss_avg}') @@ -119,7 +126,7 @@ def batch_end(self, state: State, logger: Logger) -> None: outlier_counter=self.outlier_counter, ) - elif self.detect_high_losses(int(state.timestamp.batch)): + elif self.detect_high_losses(current_step): for destination in logger.destinations: if isinstance(destination, MosaicMLLogger): destination.log_metadata({ From 
3d74b99fa64e8aa20b2001a30163ed852ea300f1 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 16:49:05 -0700 Subject: [PATCH 58/83] Specify window size as a fraction of training duration steps --- .../callbacks/kill_loss_spike_callback.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 67b1816488..6864e38493 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -9,7 +9,7 @@ import numpy as np import torch -from composer.core import Callback, State +from composer.core import Callback, State, TimeUnit from composer.loggers import Logger, MosaicMLLogger from composer.utils import dist @@ -20,6 +20,8 @@ __all__ = ['KillLossSpike'] +MIN_WINDOW_SIZE = 100 + @experimental_class('KillLossSpike') class KillLossSpike(Callback): @@ -52,13 +54,12 @@ def __init__( log_only: bool = True, patience: int = 4, outlier_multiplier: float = 2, - window_size: int = 100, ): self._enabled = (dist.get_global_rank() == 0) self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier - self.window_size = window_size + self.window_size = None self.loss_cap = None self.outlier_counter = 0 self.loss_window = deque(maxlen=self.window_size) @@ -92,6 +93,16 @@ def detect_high_losses(self, current_step: int): ) return True return False + + def init(self, state: State, logger: Logger) -> None: + """Set the window to a fraction of the total number of training batches. At least 100 steps. + The unit of window size is number of batches.""" + if state.max_duration.unit == TimeUnit.EPOCH: + self.window_size = max(MIN_WINDOW_SIZE, (state.dataloader_len * state.max_duration.value / 20)) + elif state.max_duration.unit == TimeUnit.BATCH: + self.window_size = max(MIN_WINDOW_SIZE, state.max_duration.value / 20) + elif state.max_duration.unit == TimeUnit.TOKEN: + self.window_size = max(MIN_WINDOW_SIZE, state.max_duration.value / 20) def batch_end(self, state: State, logger: Logger) -> None: @@ -101,8 +112,14 @@ def batch_end(self, state: State, logger: Logger) -> None: # Only start early stopping once a full window of loss data if len(self.loss_window) == self.window_size: - + current_step = int(state.timestamp.batch) + # Only applies to if max_duration is set in tokens. If current batch is less than MIN_WINDOW_SIZE + # as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue. 
+ if current_step < MIN_WINDOW_SIZE: + self.window_size = MIN_WINDOW_SIZE + return + # Set the loss cap to the maximum loss from the first loss window if current_step == self.window_size: self.loss_cap = max(self.loss_window) From 10f9ac7447ea529071444f794203cb91a07f6b0c Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 16:52:46 -0700 Subject: [PATCH 59/83] Formatting stuff --- .../callbacks/kill_loss_spike_callback.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 6864e38493..07dc9679c8 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -93,16 +93,24 @@ def detect_high_losses(self, current_step: int): ) return True return False - + def init(self, state: State, logger: Logger) -> None: - """Set the window to a fraction of the total number of training batches. At least 100 steps. - The unit of window size is number of batches.""" + #Set the window to a fraction of the total number of training batches, minimum 100. if state.max_duration.unit == TimeUnit.EPOCH: - self.window_size = max(MIN_WINDOW_SIZE, (state.dataloader_len * state.max_duration.value / 20)) + self.window_size = max( + MIN_WINDOW_SIZE, + (state.dataloader_len * state.max_duration.value / 20), + ) elif state.max_duration.unit == TimeUnit.BATCH: - self.window_size = max(MIN_WINDOW_SIZE, state.max_duration.value / 20) + self.window_size = max( + MIN_WINDOW_SIZE, + state.max_duration.value / 20, + ) elif state.max_duration.unit == TimeUnit.TOKEN: - self.window_size = max(MIN_WINDOW_SIZE, state.max_duration.value / 20) + self.window_size = max( + MIN_WINDOW_SIZE, + state.max_duration.value / 20, + ) def batch_end(self, state: State, logger: Logger) -> None: @@ -112,7 +120,7 @@ def batch_end(self, state: State, logger: Logger) -> None: # Only start early stopping once a full window of loss data if len(self.loss_window) == self.window_size: - + current_step = int(state.timestamp.batch) # Only applies to if max_duration is set in tokens. If current batch is less than MIN_WINDOW_SIZE # as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue. 
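The two commits above settle on a simple rule for the rolling window: take 1/20 of the run's total training batches, floored at MIN_WINDOW_SIZE (100). A minimal standalone sketch of that rule, assuming the same semantics as the callback's init logic; the helper name suggest_window_size is illustrative and not part of llmfoundry, and the round() call only enters the series later (PATCH 66), shown here so the helper returns an int:

    # Sketch only: restates the window-sizing rule from PATCH 58/59 as a pure
    # function so the arithmetic can be checked in isolation.
    MIN_WINDOW_SIZE = 100

    def suggest_window_size(total_training_steps: float) -> int:
        # 1/20 of the run, never smaller than the 100-batch floor.
        return max(MIN_WINDOW_SIZE, round(total_training_steps / 20))

    assert suggest_window_size(10_000) == 500  # long run: window scales with run length
    assert suggest_window_size(1_000) == 100   # short run: the floor applies

For a run whose max_duration is given in epochs, total_training_steps here corresponds to state.dataloader_len * state.max_duration.value, matching the EPOCH branch above.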
From 4167a74c5b7dd3e7c67c9a8e5fd8763c64b23b1c Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 16:54:47 -0700 Subject: [PATCH 60/83] Formatting --- .../callbacks/kill_loss_spike_callback.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 07dc9679c8..63ee5f53a1 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -6,6 +6,7 @@ import logging from collections import deque +from typing import Optional import numpy as np import torch @@ -54,12 +55,13 @@ def __init__( log_only: bool = True, patience: int = 4, outlier_multiplier: float = 2, + window_size: Optional[int] = None, ): self._enabled = (dist.get_global_rank() == 0) self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier - self.window_size = None + self.window_size = window_size self.loss_cap = None self.outlier_counter = 0 self.loss_window = deque(maxlen=self.window_size) @@ -96,21 +98,22 @@ def detect_high_losses(self, current_step: int): def init(self, state: State, logger: Logger) -> None: #Set the window to a fraction of the total number of training batches, minimum 100. - if state.max_duration.unit == TimeUnit.EPOCH: - self.window_size = max( - MIN_WINDOW_SIZE, - (state.dataloader_len * state.max_duration.value / 20), - ) - elif state.max_duration.unit == TimeUnit.BATCH: - self.window_size = max( - MIN_WINDOW_SIZE, - state.max_duration.value / 20, - ) - elif state.max_duration.unit == TimeUnit.TOKEN: - self.window_size = max( - MIN_WINDOW_SIZE, - state.max_duration.value / 20, - ) + if not self.window_size: + if state.max_duration.unit == TimeUnit.EPOCH: + self.window_size = max( + MIN_WINDOW_SIZE, + (state.dataloader_len * state.max_duration.value / 20), + ) + elif state.max_duration.unit == TimeUnit.BATCH: + self.window_size = max( + MIN_WINDOW_SIZE, + state.max_duration.value / 20, + ) + elif state.max_duration.unit == TimeUnit.TOKEN: + self.window_size = max( + MIN_WINDOW_SIZE, + state.max_duration.value / 20, + ) def batch_end(self, state: State, logger: Logger) -> None: From bf220f461592fc5dc2d9bbbbbec35f6324306c4f Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 17:03:43 -0700 Subject: [PATCH 61/83] fit start? --- llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 63ee5f53a1..06d46adcc1 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -96,7 +96,7 @@ def detect_high_losses(self, current_step: int): return True return False - def init(self, state: State, logger: Logger) -> None: + def fit_start(self, state: State, logger: Logger) -> None: #Set the window to a fraction of the total number of training batches, minimum 100. 
if not self.window_size: if state.max_duration.unit == TimeUnit.EPOCH: From 8ec812c158c35ff233bdc425c19db88a08123302 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 17:14:51 -0700 Subject: [PATCH 62/83] Logging for window size setting check --- llmfoundry/callbacks/kill_loss_spike_callback.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 06d46adcc1..750d8455c7 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -121,6 +121,8 @@ def batch_end(self, state: State, logger: Logger) -> None: raise NotImplementedError('Multiple losses not supported yet') train_loss = state.loss.item() + log.info(f'Window size: {self.window_size}') + # Only start early stopping once a full window of loss data if len(self.loss_window) == self.window_size: @@ -129,6 +131,7 @@ def batch_end(self, state: State, logger: Logger) -> None: # as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue. if current_step < MIN_WINDOW_SIZE: self.window_size = MIN_WINDOW_SIZE + self.loss_window.append(train_loss) return # Set the loss cap to the maximum loss from the first loss window From 93b0fe895ead331afcfc1ebf3aa593d260330c6c Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Fri, 23 Aug 2024 17:26:23 -0700 Subject: [PATCH 63/83] Add back window_size and loss_cap as optl params --- llmfoundry/callbacks/kill_loss_spike_callback.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 750d8455c7..c29d7abfd5 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -55,14 +55,15 @@ def __init__( log_only: bool = True, patience: int = 4, outlier_multiplier: float = 2, - window_size: Optional[int] = None, + window_size: int = None, + loss_cap: float = None, ): self._enabled = (dist.get_global_rank() == 0) self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier self.window_size = window_size - self.loss_cap = None + self.loss_cap = loss_cap self.outlier_counter = 0 self.loss_window = deque(maxlen=self.window_size) From 22141751236855bc814aa302c3fa4491f559cc1f Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 14:30:19 -0700 Subject: [PATCH 64/83] Test that error is not raised when log_only set to true --- .../callbacks/kill_loss_spike_callback.py | 1 - .../test_kill_loss_spike_callback.py | 35 ++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index c29d7abfd5..83fd164f31 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -6,7 +6,6 @@ import logging from collections import deque -from typing import Optional import numpy as np import torch diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py index d492dc4c8f..72cf6166c5 100644 --- a/tests/callbacks/test_kill_loss_spike_callback.py +++ b/tests/callbacks/test_kill_loss_spike_callback.py @@ -2,9 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import unittest from collections import deque -from unittest.mock import patch +from unittest.mock import MagicMock, patch + +import torch 
+from composer.core import State, Timestamp +from composer.devices import DeviceCPU +from composer.loggers import Logger, MosaicMLLogger from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike +from llmfoundry.utils.exceptions import LossSpikeError class TestKillLossSpike(unittest.TestCase): @@ -35,6 +41,33 @@ def test_detect_loss_spike_with_spike(self, _): result = self.callback.detect_loss_spike(train_loss, running_loss_avg) self.assertTrue(result) + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_no_error_raised_with_log_only_true(self, _): + build_tiny_mpt = MagicMock() + build_tiny_mpt.return_value = MagicMock() + state = State( + model=build_tiny_mpt(loss_fn='torch_crossentropy'), + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.loss = torch.tensor(4) + state.timestamp = Timestamp(batch=21) + logger = Logger(state, destinations=[MosaicMLLogger()]) + + # Loss spike detection should trigger + self.callback.outlier_counter = 4 + self.callback.loss_window = deque([2] * 10, maxlen=10) + + result = self.callback.detect_loss_spike(state.loss.item(), 2) + self.assertTrue(result) + + # batch_end should not raise an error due to log_only=True + try: + self.callback.batch_end(state, logger) + except Exception as e: + self.fail(f'batch_end raised an exception {e} with log_only=True') + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') def test_detect_high_losses_no_high_losses(self, _): self.callback.loss_window = deque([2] * 10, maxlen=10) From 57800e6f6405e79ed690de6f854af6c8c54046f5 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 14:47:21 -0700 Subject: [PATCH 65/83] Update docstring --- llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 83fd164f31..f3a852619a 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -40,9 +40,9 @@ class KillLossSpike(Callback): outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an outlier if it is outlier_multiplier times greater than the mean of losses in the current window. Default is 2. - window_size (int): The size of the rolling window used to track recent losses. Default is 100. + window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches by default, with a minimum of 100 steps. loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value, - it is considered a diverging or unstable run. Default is 10. + it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses by default. 
Raises: LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is From c59e5be3c6d99054c34cc60e50be64811fc7837f Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 15:00:58 -0700 Subject: [PATCH 66/83] Round fractional window size --- llmfoundry/callbacks/kill_loss_spike_callback.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index f3a852619a..c0c27320ce 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -102,17 +102,17 @@ def fit_start(self, state: State, logger: Logger) -> None: if state.max_duration.unit == TimeUnit.EPOCH: self.window_size = max( MIN_WINDOW_SIZE, - (state.dataloader_len * state.max_duration.value / 20), + round(state.dataloader_len * state.max_duration.value / 20), ) elif state.max_duration.unit == TimeUnit.BATCH: self.window_size = max( MIN_WINDOW_SIZE, - state.max_duration.value / 20, + round(state.max_duration.value / 20), ) elif state.max_duration.unit == TimeUnit.TOKEN: self.window_size = max( MIN_WINDOW_SIZE, - state.max_duration.value / 20, + round(state.max_duration.value / 20), ) def batch_end(self, state: State, logger: Logger) -> None: From 1aeb46f7e40163965ce5a61a87bfc5facd88b467 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 15:21:36 -0700 Subject: [PATCH 67/83] Window size and loss cap not specifiable --- .../callbacks/kill_loss_spike_callback.py | 41 ++++++++----------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index c0c27320ce..6403ac76c2 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -40,9 +40,9 @@ class KillLossSpike(Callback): outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an outlier if it is outlier_multiplier times greater than the mean of losses in the current window. Default is 2. - window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches by default, with a minimum of 100 steps. + window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps. loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value, - it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses by default. + it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses. 
Raises: LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is @@ -54,17 +54,12 @@ def __init__( log_only: bool = True, patience: int = 4, outlier_multiplier: float = 2, - window_size: int = None, - loss_cap: float = None, ): - self._enabled = (dist.get_global_rank() == 0) + # self._enabled = (dist.get_global_rank() == 0) self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier - self.window_size = window_size - self.loss_cap = loss_cap self.outlier_counter = 0 - self.loss_window = deque(maxlen=self.window_size) def detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier @@ -97,23 +92,19 @@ def detect_high_losses(self, current_step: int): return False def fit_start(self, state: State, logger: Logger) -> None: - #Set the window to a fraction of the total number of training batches, minimum 100. - if not self.window_size: - if state.max_duration.unit == TimeUnit.EPOCH: - self.window_size = max( - MIN_WINDOW_SIZE, - round(state.dataloader_len * state.max_duration.value / 20), - ) - elif state.max_duration.unit == TimeUnit.BATCH: - self.window_size = max( - MIN_WINDOW_SIZE, - round(state.max_duration.value / 20), - ) - elif state.max_duration.unit == TimeUnit.TOKEN: - self.window_size = max( - MIN_WINDOW_SIZE, - round(state.max_duration.value / 20), - ) + # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches. + if state.max_duration.unit == TimeUnit.EPOCH: + self.window_size = max( + MIN_WINDOW_SIZE, + round(state.dataloader_len * state.max_duration.value / 20), + ) + elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: + self.window_size = max( + MIN_WINDOW_SIZE, + round(state.max_duration.value / 20), + ) + self.loss_window = deque(maxlen=self.window_size) + def batch_end(self, state: State, logger: Logger) -> None: From 1b5a757c9c3e4af60aa9a3ceeb7cda76a238f393 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 15:33:40 -0700 Subject: [PATCH 68/83] Adjust tests no window size / loss cap inputs --- llmfoundry/callbacks/kill_loss_spike_callback.py | 11 ++++++----- tests/callbacks/test_kill_loss_spike_callback.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 6403ac76c2..33786fc0de 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -11,7 +11,6 @@ import torch from composer.core import Callback, State, TimeUnit from composer.loggers import Logger, MosaicMLLogger -from composer.utils import dist from llmfoundry.utils.exceptions import HighLossError, LossSpikeError from llmfoundry.utils.warnings import experimental_class @@ -60,6 +59,9 @@ def __init__( self.patience = patience self.outlier_multiplier = outlier_multiplier self.outlier_counter = 0 + self.window_size = None + self.loss_window = None + self.loss_cap = None def detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier @@ -96,15 +98,14 @@ def fit_start(self, state: State, logger: Logger) -> None: if state.max_duration.unit == TimeUnit.EPOCH: self.window_size = max( MIN_WINDOW_SIZE, - round(state.dataloader_len * state.max_duration.value / 20), + round(float(state.dataloader_len * state.max_duration.value / 20)), ) elif state.max_duration.unit == TimeUnit.BATCH or 
state.max_duration.unit == TimeUnit.TOKEN: self.window_size = max( MIN_WINDOW_SIZE, - round(state.max_duration.value / 20), + round(float(state.max_duration.value / 20)), ) - self.loss_window = deque(maxlen=self.window_size) - + self.loss_window = deque(maxlen=self.window_size) def batch_end(self, state: State, logger: Logger) -> None: diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py index 72cf6166c5..f99a48b26d 100644 --- a/tests/callbacks/test_kill_loss_spike_callback.py +++ b/tests/callbacks/test_kill_loss_spike_callback.py @@ -21,9 +21,9 @@ def __init__(self, *args: str, **kwargs: dict): log_only=True, patience=4, outlier_multiplier=2, - window_size=10, - loss_cap=10, ) + self.callback.window_size = 10 + self.callback.loss_cap = 10 @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') def test_detect_loss_spike_no_spike(self, _): From e3c00b35fb011b949c3e1c787168c340a91961e0 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 16:16:52 -0700 Subject: [PATCH 69/83] init vals for window size, loss window, loss cap --- .../callbacks/kill_loss_spike_callback.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 33786fc0de..97557cb123 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -59,9 +59,9 @@ def __init__( self.patience = patience self.outlier_multiplier = outlier_multiplier self.outlier_counter = 0 - self.window_size = None - self.loss_window = None - self.loss_cap = None + self.window_size = MIN_WINDOW_SIZE + self.loss_window = deque(maxlen=self.window_size) + self.loss_cap = float('inf') def detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier @@ -95,16 +95,17 @@ def detect_high_losses(self, current_step: int): def fit_start(self, state: State, logger: Logger) -> None: # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches. 
- if state.max_duration.unit == TimeUnit.EPOCH: - self.window_size = max( - MIN_WINDOW_SIZE, - round(float(state.dataloader_len * state.max_duration.value / 20)), - ) - elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: - self.window_size = max( - MIN_WINDOW_SIZE, - round(float(state.max_duration.value / 20)), - ) + if state.max_duration is not None: + if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None: + self.window_size = max( + MIN_WINDOW_SIZE, + round(float(state.dataloader_len * state.max_duration.value / 20)), + ) + elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: + self.window_size = max( + MIN_WINDOW_SIZE, + round(float(state.max_duration.value / 20)), + ) self.loss_window = deque(maxlen=self.window_size) def batch_end(self, state: State, logger: Logger) -> None: From 04d20fa0db96826b1f333ad2e592afbdbbd473ce Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 18:45:32 -0700 Subject: [PATCH 70/83] decompose spike/high loss handling --- .../callbacks/kill_loss_spike_callback.py | 79 +++++++++++-------- .../test_kill_loss_spike_callback.py | 26 ++++++ 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 97557cb123..90bc66b6f4 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -20,6 +20,7 @@ __all__ = ['KillLossSpike'] MIN_WINDOW_SIZE = 100 +MAX_LOSS_CAP = 10 @experimental_class('KillLossSpike') @@ -41,7 +42,7 @@ class KillLossSpike(Callback): the current window. Default is 2. window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps. loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value, - it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses. + it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses, with a maximum of 10. Raises: LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is @@ -61,7 +62,7 @@ def __init__( self.outlier_counter = 0 self.window_size = MIN_WINDOW_SIZE self.loss_window = deque(maxlen=self.window_size) - self.loss_cap = float('inf') + self.loss_cap = MAX_LOSS_CAP def detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier @@ -93,17 +94,54 @@ def detect_high_losses(self, current_step: int): return True return False + def handle_loss_spike( + self, logger: Logger, running_loss_avg: float + ) -> None: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({ + 'loss_spike': + f'Training loss spike detected for {self.outlier_counter} consecutive steps. 
Consider stopping this run and resubmitting with a lower learning rate.', + 'loss_window': + list(self.loss_window), + }) + if not self.log_only: + raise LossSpikeError( + outlier_multiplier=self.outlier_multiplier, + running_loss_avg=round(running_loss_avg), + outlier_counter=self.outlier_counter, + ) + + def handle_high_losses(self, logger: Logger) -> None: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({ + 'high_loss': + f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.', + 'loss_window': + list(self.loss_window), + }) + if not self.log_only: + raise HighLossError( + loss_cap=self.loss_cap, + window_size=self.window_size, + ) + def fit_start(self, state: State, logger: Logger) -> None: # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches. if state.max_duration is not None: if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None: self.window_size = max( - MIN_WINDOW_SIZE, - round(float(state.dataloader_len * state.max_duration.value / 20)), + self.window_size, + round( + float( + state.dataloader_len * state.max_duration.value / 20 + ) + ), ) elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: self.window_size = max( - MIN_WINDOW_SIZE, + self.window_size, round(float(state.max_duration.value / 20)), ) self.loss_window = deque(maxlen=self.window_size) @@ -129,40 +167,15 @@ def batch_end(self, state: State, logger: Logger) -> None: # Set the loss cap to the maximum loss from the first loss window if current_step == self.window_size: - self.loss_cap = max(self.loss_window) + self.loss_cap = max(max(self.loss_window), self.loss_cap) running_loss_avg = float(np.mean(self.loss_window)) log.info(f'Running loss average: {running_loss_avg}') if self.detect_loss_spike(train_loss, running_loss_avg): - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({ - 'loss_spike': - f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.', - 'loss_window': - list(self.loss_window), - }) - if not self.log_only: - raise LossSpikeError( - outlier_multiplier=self.outlier_multiplier, - running_loss_avg=round(running_loss_avg), - outlier_counter=self.outlier_counter, - ) + self.handle_loss_spike(logger, running_loss_avg) elif self.detect_high_losses(current_step): - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({ - 'high_loss': - f'Persistently high (>{self.loss_cap}) training losses detected. 
Consider stopping this run and resubmitting with a lower learning rate.', - 'loss_window': - list(self.loss_window), - }) - if not self.log_only: - raise HighLossError( - loss_cap=self.loss_cap, - window_size=self.window_size, - ) + self.handle_high_losses(logger) self.loss_window.append(train_loss) diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py index f99a48b26d..21c2b939a8 100644 --- a/tests/callbacks/test_kill_loss_spike_callback.py +++ b/tests/callbacks/test_kill_loss_spike_callback.py @@ -68,6 +68,32 @@ def test_no_error_raised_with_log_only_true(self, _): except Exception as e: self.fail(f'batch_end raised an exception {e} with log_only=True') + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') + def test_error_raised_with_log_only_false(self, _): + build_tiny_mpt = MagicMock() + build_tiny_mpt.return_value = MagicMock() + state = State( + model=build_tiny_mpt(loss_fn='torch_crossentropy'), + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.loss = torch.tensor(4) + state.timestamp = Timestamp(batch=21) + logger = Logger(state, destinations=[MosaicMLLogger()]) + + # Loss spike detection should trigger + self.callback.outlier_counter = 4 + self.callback.loss_window = deque([2] * 10, maxlen=10) + self.callback.log_only = False + + result = self.callback.detect_loss_spike(state.loss.item(), 2) + self.assertTrue(result) + + # batch_end should raise an error due to log_only=False + with self.assertRaises(LossSpikeError): + self.callback.batch_end(state, logger) + @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') def test_detect_high_losses_no_high_losses(self, _): self.callback.loss_window = deque([2] * 10, maxlen=10) From 4658a8f82f961f7afbdad8ba43c5411e3b35fec1 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 18:59:41 -0700 Subject: [PATCH 71/83] decompose logging and rename helpers --- .../callbacks/kill_loss_spike_callback.py | 34 ++++++++----------- .../test_kill_loss_spike_callback.py | 12 +++---- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 90bc66b6f4..4917cf7ef4 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -64,7 +64,7 @@ def __init__( self.loss_window = deque(maxlen=self.window_size) self.loss_cap = MAX_LOSS_CAP - def detect_loss_spike(self, train_loss: float, running_loss_avg: float): + def _detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier if train_loss >= running_loss_avg * self.outlier_multiplier: self.outlier_counter += 1 @@ -82,7 +82,7 @@ def detect_loss_spike(self, train_loss: float, running_loss_avg: float): self.outlier_counter = 0 return False - def detect_high_losses(self, current_step: int): + def _detect_high_losses(self, current_step: int): # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period if (current_step >= self.window_size * 2) and ( sum(1 for loss in self.loss_window if loss > self.loss_cap) >= @@ -93,18 +93,20 @@ def detect_high_losses(self, current_step: int): ) return True return False - - def handle_loss_spike( - self, logger: Logger, running_loss_avg: float - ) -> None: + + def _log_metadata(self, logger: Logger, key: str, message: str) -> None: for destination in logger.destinations: if isinstance(destination, MosaicMLLogger): 
destination.log_metadata({ - 'loss_spike': - f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.', - 'loss_window': - list(self.loss_window), + key: message, + 'loss_window': list(self.loss_window), }) + + def _handle_loss_spike( + self, logger: Logger, running_loss_avg: float + ) -> None: + message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.' + self._log_metadata(logger, 'loss_spike', message) if not self.log_only: raise LossSpikeError( outlier_multiplier=self.outlier_multiplier, @@ -112,15 +114,9 @@ def handle_loss_spike( outlier_counter=self.outlier_counter, ) - def handle_high_losses(self, logger: Logger) -> None: - for destination in logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({ - 'high_loss': - f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.', - 'loss_window': - list(self.loss_window), - }) + def _handle_high_losses(self, logger: Logger) -> None: + message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.' + self._log_metadata(logger, 'high_loss', message) if not self.log_only: raise HighLossError( loss_cap=self.loss_cap, diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py index 21c2b939a8..eeaa36f751 100644 --- a/tests/callbacks/test_kill_loss_spike_callback.py +++ b/tests/callbacks/test_kill_loss_spike_callback.py @@ -30,7 +30,7 @@ def test_detect_loss_spike_no_spike(self, _): self.callback.outlier_counter = 0 train_loss = 4 running_loss_avg = 2 - result = self.callback.detect_loss_spike(train_loss, running_loss_avg) + result = self.callback._detect_loss_spike(train_loss, running_loss_avg) self.assertFalse(result) @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') @@ -38,7 +38,7 @@ def test_detect_loss_spike_with_spike(self, _): self.callback.outlier_counter = 4 # Simulating previous spikes train_loss = 4 running_loss_avg = 2 - result = self.callback.detect_loss_spike(train_loss, running_loss_avg) + result = self.callback._detect_loss_spike(train_loss, running_loss_avg) self.assertTrue(result) @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') @@ -59,7 +59,7 @@ def test_no_error_raised_with_log_only_true(self, _): self.callback.outlier_counter = 4 self.callback.loss_window = deque([2] * 10, maxlen=10) - result = self.callback.detect_loss_spike(state.loss.item(), 2) + result = self.callback._detect_loss_spike(state.loss.item(), 2) self.assertTrue(result) # batch_end should not raise an error due to log_only=True @@ -87,7 +87,7 @@ def test_error_raised_with_log_only_false(self, _): self.callback.loss_window = deque([2] * 10, maxlen=10) self.callback.log_only = False - result = self.callback.detect_loss_spike(state.loss.item(), 2) + result = self.callback._detect_loss_spike(state.loss.item(), 2) self.assertTrue(result) # batch_end should raise an error due to log_only=False @@ -98,7 +98,7 @@ def test_error_raised_with_log_only_false(self, _): def test_detect_high_losses_no_high_losses(self, _): self.callback.loss_window = deque([2] * 10, maxlen=10) current_step = 21 - result = self.callback.detect_high_losses(current_step) + result = self.callback._detect_high_losses(current_step) 
self.assertFalse(result) @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') @@ -108,5 +108,5 @@ def test_detect_high_losses_with_high_losses(self, _): maxlen=10, ) # Simulate mix of losses in loss window current_step = 21 - result = self.callback.detect_high_losses(current_step) + result = self.callback._detect_high_losses(current_step) self.assertTrue(result) From 55dd5bb59b07ae832871b2f137cc23eaa5e6398f Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 23:01:45 -0700 Subject: [PATCH 72/83] Test error raised/not raised depending on log_only --- .../callbacks/kill_loss_spike_callback.py | 20 +++--- .../test_kill_loss_spike_callback.py | 63 +++++-------------- 2 files changed, 26 insertions(+), 57 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 4917cf7ef4..9cc1f6204d 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -21,6 +21,7 @@ MIN_WINDOW_SIZE = 100 MAX_LOSS_CAP = 10 +WINDOW_FRACTION = 20 @experimental_class('KillLossSpike') @@ -93,7 +94,7 @@ def _detect_high_losses(self, current_step: int): ) return True return False - + def _log_metadata(self, logger: Logger, key: str, message: str) -> None: for destination in logger.destinations: if isinstance(destination, MosaicMLLogger): @@ -103,7 +104,9 @@ def _log_metadata(self, logger: Logger, key: str, message: str) -> None: }) def _handle_loss_spike( - self, logger: Logger, running_loss_avg: float + self, + logger: Logger, + running_loss_avg: float, ) -> None: message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.' self._log_metadata(logger, 'loss_spike', message) @@ -131,8 +134,9 @@ def fit_start(self, state: State, logger: Logger) -> None: self.window_size, round( float( - state.dataloader_len * state.max_duration.value / 20 - ) + state.dataloader_len * state.max_duration.value / + 20, + ), ), ) elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: @@ -168,10 +172,10 @@ def batch_end(self, state: State, logger: Logger) -> None: running_loss_avg = float(np.mean(self.loss_window)) log.info(f'Running loss average: {running_loss_avg}') - if self.detect_loss_spike(train_loss, running_loss_avg): - self.handle_loss_spike(logger, running_loss_avg) + if self._detect_loss_spike(train_loss, running_loss_avg): + self._handle_loss_spike(logger, running_loss_avg) - elif self.detect_high_losses(current_step): - self.handle_high_losses(logger) + elif self._detect_high_losses(current_step): + self._handle_high_losses(logger) self.loss_window.append(train_loss) diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py index eeaa36f751..a8746299e8 100644 --- a/tests/callbacks/test_kill_loss_spike_callback.py +++ b/tests/callbacks/test_kill_loss_spike_callback.py @@ -4,11 +4,6 @@ from collections import deque from unittest.mock import MagicMock, patch -import torch -from composer.core import State, Timestamp -from composer.devices import DeviceCPU -from composer.loggers import Logger, MosaicMLLogger - from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike from llmfoundry.utils.exceptions import LossSpikeError @@ -42,57 +37,27 @@ def test_detect_loss_spike_with_spike(self, _): self.assertTrue(result) @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') - def 
test_no_error_raised_with_log_only_true(self, _): - build_tiny_mpt = MagicMock() - build_tiny_mpt.return_value = MagicMock() - state = State( - model=build_tiny_mpt(loss_fn='torch_crossentropy'), - rank_zero_seed=0, - run_name='test_state', - device=DeviceCPU(), - ) - state.loss = torch.tensor(4) - state.timestamp = Timestamp(batch=21) - logger = Logger(state, destinations=[MosaicMLLogger()]) - - # Loss spike detection should trigger - self.callback.outlier_counter = 4 - self.callback.loss_window = deque([2] * 10, maxlen=10) - - result = self.callback._detect_loss_spike(state.loss.item(), 2) - self.assertTrue(result) + def test_handle_loss_spike_logs_only_when_log_only_true(self, _): + logger = MagicMock() + running_loss_avg = 2 + self.callback.log_only = True + self.callback.outlier_counter = 5 - # batch_end should not raise an error due to log_only=True try: - self.callback.batch_end(state, logger) - except Exception as e: - self.fail(f'batch_end raised an exception {e} with log_only=True') + self.callback._handle_loss_spike(logger, running_loss_avg) + except LossSpikeError: + self.fail('LossSpikeError was raised unexpectedly') @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') - def test_error_raised_with_log_only_false(self, _): - build_tiny_mpt = MagicMock() - build_tiny_mpt.return_value = MagicMock() - state = State( - model=build_tiny_mpt(loss_fn='torch_crossentropy'), - rank_zero_seed=0, - run_name='test_state', - device=DeviceCPU(), - ) - state.loss = torch.tensor(4) - state.timestamp = Timestamp(batch=21) - logger = Logger(state, destinations=[MosaicMLLogger()]) - - # Loss spike detection should trigger - self.callback.outlier_counter = 4 - self.callback.loss_window = deque([2] * 10, maxlen=10) + def test_handle_loss_spike_raises_error_log_only_false(self, _): + logger = MagicMock() + running_loss_avg = 2 self.callback.log_only = False + self.callback.outlier_counter = 5 - result = self.callback._detect_loss_spike(state.loss.item(), 2) - self.assertTrue(result) - - # batch_end should raise an error due to log_only=False + # LossSpikeError is raised with self.assertRaises(LossSpikeError): - self.callback.batch_end(state, logger) + self.callback._handle_loss_spike(logger, running_loss_avg) @patch('llmfoundry.callbacks.kill_loss_spike_callback.log') def test_detect_high_losses_no_high_losses(self, _): From 5b044fa0720091be454ea191c3aba5bd0bb0facc Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 23:36:55 -0700 Subject: [PATCH 73/83] Cleanup window frac --- .../callbacks/kill_loss_spike_callback.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 9cc1f6204d..dc5ae1ed5c 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -21,7 +21,7 @@ MIN_WINDOW_SIZE = 100 MAX_LOSS_CAP = 10 -WINDOW_FRACTION = 20 +WINDOW_FRACTION = 0.05 @experimental_class('KillLossSpike') @@ -65,7 +65,11 @@ def __init__( self.loss_window = deque(maxlen=self.window_size) self.loss_cap = MAX_LOSS_CAP - def _detect_loss_spike(self, train_loss: float, running_loss_avg: float): + def _detect_loss_spike( + self, + train_loss: float, + running_loss_avg: float, + ) -> bool: # Train loss is an outlier if train_loss >= running_loss_avg * self.outlier_multiplier: self.outlier_counter += 1 @@ -83,17 +87,22 @@ def _detect_loss_spike(self, train_loss: float, running_loss_avg: float): 
self.outlier_counter = 0 return False - def _detect_high_losses(self, current_step: int): + def _detect_high_losses(self, current_step: int) -> bool: + if current_step < self.window_size * 2: + return False + # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period - if (current_step >= self.window_size * 2) and ( - sum(1 for loss in self.loss_window if loss > self.loss_cap) >= - self.window_size / 2 - ): + high_loss_count = sum( + 1 for loss in self.loss_window if loss > self.loss_cap + ) + is_high_loss = high_loss_count >= self.window_size / 2 + + if is_high_loss: log.info( - f'High losses (train loss consistently greater than {self.loss_cap}) detected.', + f'High losses detected: {high_loss_count}/{self.window_size} losses above {self.loss_cap}.', ) - return True - return False + + return is_high_loss def _log_metadata(self, logger: Logger, key: str, message: str) -> None: for destination in logger.destinations: @@ -134,15 +143,15 @@ def fit_start(self, state: State, logger: Logger) -> None: self.window_size, round( float( - state.dataloader_len * state.max_duration.value / - 20, + state.dataloader_len * state.max_duration.value * + WINDOW_FRACTION, ), ), ) elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: self.window_size = max( self.window_size, - round(float(state.max_duration.value / 20)), + round(float(state.max_duration.value * WINDOW_FRACTION)), ) self.loss_window = deque(maxlen=self.window_size) From 95bed00727e86b7a65e5151cc3ae1fdc0f5fba36 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Mon, 26 Aug 2024 11:14:15 -0700 Subject: [PATCH 74/83] Clean up fit_start & callback rank 0 only --- .../callbacks/kill_loss_spike_callback.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index dc5ae1ed5c..c90a85a42f 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -11,6 +11,7 @@ import torch from composer.core import Callback, State, TimeUnit from composer.loggers import Logger, MosaicMLLogger +from composer.utils import dist from llmfoundry.utils.exceptions import HighLossError, LossSpikeError from llmfoundry.utils.warnings import experimental_class @@ -56,7 +57,7 @@ def __init__( patience: int = 4, outlier_multiplier: float = 2, ): - # self._enabled = (dist.get_global_rank() == 0) + self._enabled = True self.log_only = log_only self.patience = patience self.outlier_multiplier = outlier_multiplier @@ -139,21 +140,15 @@ def fit_start(self, state: State, logger: Logger) -> None: # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches. 
         if state.max_duration is not None:
             if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
-                self.window_size = max(
-                    self.window_size,
-                    round(
-                        float(
-                            state.dataloader_len * state.max_duration.value *
-                            WINDOW_FRACTION,
-                        ),
-                    ),
-                )
+                total_training_steps = state.dataloader_len * state.max_duration.value
             elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
-                self.window_size = max(
-                    self.window_size,
-                    round(float(state.max_duration.value * WINDOW_FRACTION)),
-                )
+                total_training_steps = state.max_duration.value
+        self.window_size = max(
+            self.window_size,
+            round(float(total_training_steps * WINDOW_FRACTION)),
+        )
         self.loss_window = deque(maxlen=self.window_size)
+        self._enabled = (dist.get_global_rank() == 0)

     def batch_end(self, state: State, logger: Logger) -> None:

From 50a6b5efbc2a0b702073d1236dfd823bae0aa6aa Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 12:32:54 -0700
Subject: [PATCH 75/83] Try moving self._enabled setting

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index c90a85a42f..ac9a585538 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -57,7 +57,7 @@ def __init__(
         patience: int = 4,
         outlier_multiplier: float = 2,
     ):
-        self._enabled = True
+        self._enabled = (dist.get_global_rank() == 0)
        self.log_only = log_only
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
@@ -138,6 +138,7 @@ def fit_start(self, state: State, logger: Logger) -> None:
         # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
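+        # e.g. 10 epochs x 1,000 batches/epoch -> max(100, round(10_000 * 0.05)) = a 500-batch window,
+        # while any run under 2,000 total batches keeps the 100-batch minimum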
+        total_training_steps = 0
         if state.max_duration is not None:
             if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
                 total_training_steps = state.dataloader_len * state.max_duration.value
@@ -148,7 +149,6 @@ def fit_start(self, state: State, logger: Logger) -> None:
             round(float(total_training_steps * WINDOW_FRACTION)),
         )
         self.loss_window = deque(maxlen=self.window_size)
-        self._enabled = (dist.get_global_rank() == 0)

     def batch_end(self, state: State, logger: Logger) -> None:

From 9f372811f00eb8eb16a73f6ca9230814b7b9e543 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 12:52:48 -0700
Subject: [PATCH 76/83] Private consts

---
 .../callbacks/kill_loss_spike_callback.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index ac9a585538..1a40e89fd7 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -20,9 +20,9 @@

 __all__ = ['KillLossSpike']

-MIN_WINDOW_SIZE = 100
-MAX_LOSS_CAP = 10
-WINDOW_FRACTION = 0.05
+_MIN_WINDOW_SIZE = 100
+_MAX_LOSS_CAP = 10
+_WINDOW_FRACTION = 0.05

 @experimental_class('KillLossSpike')
@@ -62,9 +62,9 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.window_size = MIN_WINDOW_SIZE
+        self.window_size = _MIN_WINDOW_SIZE
         self.loss_window = deque(maxlen=self.window_size)
-        self.loss_cap = MAX_LOSS_CAP
+        self.loss_cap = _MAX_LOSS_CAP

     def _detect_loss_spike(
@@ -146,7 +146,7 @@ def fit_start(self, state: State, logger: Logger) -> None:
             total_training_steps = state.max_duration.value
         self.window_size = max(
             self.window_size,
-            round(float(total_training_steps * WINDOW_FRACTION)),
+            round(float(total_training_steps * _WINDOW_FRACTION)),
         )
         self.loss_window = deque(maxlen=self.window_size)

@@ -163,9 +163,9 @@ def batch_end(self, state: State, logger: Logger) -> None:
         current_step = int(state.timestamp.batch)

         # Only applies if max_duration is set in tokens. If current batch is less than MIN_WINDOW_SIZE
         # as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue.
-        if current_step < MIN_WINDOW_SIZE:
-            self.window_size = MIN_WINDOW_SIZE
+        if current_step < _MIN_WINDOW_SIZE:
+            self.window_size = _MIN_WINDOW_SIZE
         self.loss_window.append(train_loss)
         return

From acd46f4bcdfcaa6bf7596e045ff9bd3e8074b3a9 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 13:29:30 -0700
Subject: [PATCH 77/83] Add absolute loss cap and loss window as arguments

---
 .../callbacks/kill_loss_spike_callback.py | 46 +++++++++++--------
 tests/utils/test_exceptions.py | 4 ++
 2 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 1a40e89fd7..3ee55a98d4 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -6,6 +6,7 @@

 import logging
 from collections import deque
+from typing import Optional

 import numpy as np
 import torch
@@ -56,15 +57,19 @@ def __init__(
         log_only: bool = True,
         patience: int = 4,
         outlier_multiplier: float = 2,
+        user_window_size: Optional[int] = None,
+        user_loss_cap: Optional[float] = None,
     ):
         self._enabled = (dist.get_global_rank() == 0)
         self.log_only = log_only
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.window_size = _MIN_WINDOW_SIZE
+        self.user_defined_window_size = user_window_size is not None
+        self.window_size = user_window_size or _MIN_WINDOW_SIZE
         self.loss_window = deque(maxlen=self.window_size)
-        self.loss_cap = _MAX_LOSS_CAP
+        self.user_defined_loss_cap = user_loss_cap is not None
+        self.loss_cap = user_loss_cap or _MAX_LOSS_CAP

     def _detect_loss_spike(
@@ -137,17 +142,18 @@ def _handle_high_losses(self, logger: Logger) -> None:
         )

     def fit_start(self, state: State, logger: Logger) -> None:
-        # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
-        total_training_steps = 0
-        if state.max_duration is not None:
-            if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
-                total_training_steps = state.dataloader_len * state.max_duration.value
-            elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
-                total_training_steps = state.max_duration.value
-            self.window_size = max(
-                self.window_size,
-                round(float(total_training_steps * _WINDOW_FRACTION)),
-            )
+        # If user does not provide a window size, set window size to a fraction of the total number of training batches for the run, minimum 100 batches.
+        if not self.user_defined_window_size:
+            total_training_steps = 0
+            if state.max_duration is not None:
+                if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
+                    total_training_steps = state.dataloader_len * state.max_duration.value
+                elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
+                    total_training_steps = state.max_duration.value
+                self.window_size = max(
+                    self.window_size,
+                    round(float(total_training_steps * _WINDOW_FRACTION)),
+                )
         self.loss_window = deque(maxlen=self.window_size)

     def batch_end(self, state: State, logger: Logger) -> None:
@@ -162,16 +168,18 @@ def batch_end(self, state: State, logger: Logger) -> None:
         if len(self.loss_window) == self.window_size:
             current_step = int(state.timestamp.batch)

-            # Only applies if max_duration is set in tokens. If current batch is less than MIN_WINDOW_SIZE
-            # as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue.
-            if current_step < _MIN_WINDOW_SIZE:
+            # Only applies if max_duration is set in tokens and user does not provide window size. If current batch is less than MIN_WINDOW_SIZE as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue.
+            if not self.user_defined_window_size and current_step < _MIN_WINDOW_SIZE:
                 self.window_size = _MIN_WINDOW_SIZE
+                self.loss_window = deque(
+                    self.loss_window, maxlen=self.window_size
+                )
                 self.loss_window.append(train_loss)
                 return

-            # Set the loss cap to the maximum loss from the first loss window
-            if current_step == self.window_size:
-                self.loss_cap = max(max(self.loss_window), self.loss_cap)
+            # If user does not provide a loss cap, set loss cap to the maximum loss from the first loss window. Hard cap at loss=10.
+            if not self.user_defined_loss_cap and current_step == self.window_size:
+                self.loss_cap = min(max(self.loss_window), self.loss_cap)

             running_loss_avg = float(np.mean(self.loss_window))
             log.info(f'Running loss average: {running_loss_avg}')
diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py
index 75c50511dd..bfd13216a5 100644
--- a/tests/utils/test_exceptions.py
+++ b/tests/utils/test_exceptions.py
@@ -35,8 +35,12 @@ def get_default_value(arg_type: Optional[type] = None):
         return 'string'
     elif arg_type == int:
         return 1
+    elif arg_type == Optional[int]:
+        return None
     elif arg_type == float:
         return 1.0
+    elif arg_type == Optional[float]:
+        return None
     elif arg_type == set[str]:
         return {'set'}
     elif arg_type == list[str]:

From 782316b8a7df17d21769af5d2438568be88c36a6 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 13:56:43 -0700
Subject: [PATCH 78/83] Take out optional type

---
 .../callbacks/kill_loss_spike_callback.py | 28 ++++++------------
 tests/utils/test_exceptions.py | 4 ---
 2 files changed, 8 insertions(+), 24 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 3ee55a98d4..12b2abaed9 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -6,7 +6,6 @@

 import logging
 from collections import deque
-from typing import Optional

 import numpy as np
 import torch
@@ -56,19 +56,19 @@ def __init__(
         log_only: bool = True,
         patience: int = 4,
         outlier_multiplier: float = 2,
-        user_window_size: Optional[int] = None,
-        user_loss_cap: Optional[float] = None,
+        window_size: int = _MIN_WINDOW_SIZE,
+        loss_cap: float = _MAX_LOSS_CAP,
     ):
         self._enabled = (dist.get_global_rank() == 0)
         self.log_only = log_only
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.user_defined_window_size = user_window_size is not None
-        self.window_size = user_window_size or _MIN_WINDOW_SIZE
+        self.user_defined_window_size = (window_size != _MIN_WINDOW_SIZE)
+        self.window_size = window_size
         self.loss_window = deque(maxlen=self.window_size)
-        self.user_defined_loss_cap = user_loss_cap is not None
-        self.loss_cap = user_loss_cap or _MAX_LOSS_CAP
+        self.user_defined_loss_cap = (loss_cap != _MAX_LOSS_CAP)
+        self.loss_cap = loss_cap

     def _detect_loss_spike(
@@ -148,7 +147,14 @@ def fit_start(self, state: State, logger: Logger) -> None:
         if state.max_duration is not None:
             if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
                 total_training_steps = state.dataloader_len * state.max_duration.value
-            elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
+            elif state.max_duration.unit == TimeUnit.BATCH:
                total_training_steps = state.max_duration.value
             self.window_size = max(
                 self.window_size,
                 round(float(total_training_steps * _WINDOW_FRACTION)),
             )
         self.loss_window = deque(maxlen=self.window_size)
+        log.info(f'Window size set to: {self.window_size}')

     def batch_end(self, state: State, logger: Logger) -> None:
@@ -162,21 +162,10 @@ def batch_end(self, state: State, logger: Logger) -> None:
             raise NotImplementedError('Multiple losses not supported yet')
         train_loss = state.loss.item()

-        log.info(f'Window size: {self.window_size}')
-
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
             current_step = int(state.timestamp.batch)

-            # Only applies if max_duration is set in tokens and user does not provide window size. If current batch is less than MIN_WINDOW_SIZE as set by tokens, we should raise the window size to the MIN_WINDOW_SIZE and continue.
-            if not self.user_defined_window_size and current_step < _MIN_WINDOW_SIZE:
-                self.window_size = _MIN_WINDOW_SIZE
-                self.loss_window = deque(
-                    self.loss_window, maxlen=self.window_size
-                )
-                self.loss_window.append(train_loss)
-                return
-
             # If user does not provide a loss cap, set loss cap to the maximum loss from the first loss window. Hard cap at loss=10.
             if not self.user_defined_loss_cap and current_step == self.window_size:
                 self.loss_cap = min(max(self.loss_window), self.loss_cap)
@@ -186,7 +175,6 @@ def batch_end(self, state: State, logger: Logger) -> None:

         if self._detect_loss_spike(train_loss, running_loss_avg):
             self._handle_loss_spike(logger, running_loss_avg)
-
         elif self._detect_high_losses(current_step):
             self._handle_high_losses(logger)

diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py
index bfd13216a5..75c50511dd 100644
--- a/tests/utils/test_exceptions.py
+++ b/tests/utils/test_exceptions.py
@@ -35,12 +35,8 @@ def get_default_value(arg_type: Optional[type] = None):
         return 'string'
     elif arg_type == int:
         return 1
-    elif arg_type == Optional[int]:
-        return None
     elif arg_type == float:
         return 1.0
-    elif arg_type == Optional[float]:
-        return None
     elif arg_type == set[str]:
         return {'set'}
     elif arg_type == list[str]:

From 187c8617bb06f940b9cf31a8e7412b126bfa0168 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 17:29:45 -0700
Subject: [PATCH 79/83] Calculate the loss window once the min window size is hit

---
 .../callbacks/kill_loss_spike_callback.py | 44 +++++++++++++------
 1 file changed, 30 insertions(+), 14 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 12b2abaed9..39e5e5a914 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -64,9 +64,9 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.user_defined_window_size = (window_size != _MIN_WINDOW_SIZE)
+        self.window_size_set = (window_size != _MIN_WINDOW_SIZE)
         self.window_size = window_size
-        self.loss_window = deque(maxlen=self.window_size)
+        self.loss_window = deque()
         self.user_defined_loss_cap = (loss_cap != _MAX_LOSS_CAP)
         self.loss_cap = loss_cap

@@ -140,21 +140,28 @@ def _handle_high_losses(self, logger: Logger) -> None:
             window_size=self.window_size,
         )

-    def fit_start(self, state: State, logger: Logger) -> None:
-        # If user does not provide a window size, set window size to a fraction of the total number of training batches for the run, minimum 100 batches.
-        if not self.user_defined_window_size:
-            total_training_steps = 0
-            if state.max_duration is not None:
-                if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
-                    total_training_steps = state.dataloader_len * state.max_duration.value
-                elif state.max_duration.unit == TimeUnit.BATCH:
-                    total_training_steps = state.max_duration.value
-                self.window_size = max(
-                    self.window_size,
-                    round(float(total_training_steps * _WINDOW_FRACTION)),
-                )
+    def _set_window_size(self, state: State) -> None:
+        total_training_steps = 0
+        current_step = int(state.timestamp.batch)
+        current_token = int(state.timestamp.token)
+
+        if state.max_duration is not None:
+            if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
+                total_training_steps = state.dataloader_len * state.max_duration.value
+            elif state.max_duration.unit == TimeUnit.BATCH:
+                total_training_steps = state.max_duration.value
+            elif state.max_duration.unit == TimeUnit.TOKEN:
+                # This is an approximation of the total batches from the total tokens, assuming the tokens:batches ratio is constant.
+                total_training_steps = current_step * (
+                    state.max_duration.value / current_token
+                )
+        self.window_size = max(
+            self.window_size,
+            round(float(total_training_steps * _WINDOW_FRACTION)),
+        )
         self.loss_window = deque(maxlen=self.window_size)
         log.info(f'Window size set to: {self.window_size}')
+        self.window_size_set = True

     def batch_end(self, state: State, logger: Logger) -> None:
@@ -166,6 +173,15 @@ def batch_end(self, state: State, logger: Logger) -> None:
         # Only start early stopping once a full window of loss data
         if len(self.loss_window) == self.window_size:
             current_step = int(state.timestamp.batch)
+
+            # If window size has not yet been set either by user or during run, set window size to a fraction of the total training duration. Minimum 100 batches.
+            if not self.window_size_set:
+                self._set_window_size(state)
+                # Window size has been expanded -- keep adding losses until we reach the window size.
+                if self.window_size > _MIN_WINDOW_SIZE:
+                    self.loss_window.append(train_loss)
+                    return
+
             # If user does not provide a loss cap, set loss cap to the maximum loss from the first loss window. Hard cap at loss=10.
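             # e.g. a calm first window whose max loss is 7.3 lowers the cap to min(7.3, 10) = 7.3,
             # while a noisy first window whose max loss is 14.5 keeps the hard cap of 10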
            if not self.user_defined_loss_cap and current_step == self.window_size:
                 self.loss_cap = min(max(self.loss_window), self.loss_cap)

             running_loss_avg = float(np.mean(self.loss_window))
             log.info(f'Running loss average: {running_loss_avg}')

From 47a461cf780bccbe95a37b77a09b6e440e5a3993 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 21:12:46 -0700
Subject: [PATCH 80/83] Test window size setting

---
 .../test_kill_loss_spike_callback.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index a8746299e8..577dbd33cc 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -4,6 +4,8 @@
 from collections import deque
 from unittest.mock import MagicMock, patch

+from composer.core.time import TimeUnit
+
 from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike
 from llmfoundry.utils.exceptions import LossSpikeError

@@ -75,3 +77,29 @@ def test_detect_high_losses_with_high_losses(self, _):
         current_step = 21
         result = self.callback._detect_high_losses(current_step)
         self.assertTrue(result)
+
+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_set_window_size_from_token(self, _):
+        state = MagicMock()
+        state.max_duration.unit = TimeUnit.TOKEN
+        state.max_duration.value = 100000
+        state.timestamp.batch = 100
+        state.timestamp.token = 4000
+
+        self.callback._set_window_size(state)
+
+        # 100 batches * (100,000 / 4,000 tokens) = 2,500 projected batches; 2,500 * 0.05 = 125
+        self.assertEqual(self.callback.window_size, 125)
+        self.assertTrue(self.callback.window_size_set)
+
+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_set_window_size_from_epoch(self, _):
+        state = MagicMock()
+        state.max_duration.unit = TimeUnit.EPOCH
+        state.dataloader_len = 1000
+        state.max_duration.value = 3
+        state.timestamp.batch = 100
+
+        self.callback._set_window_size(state)
+
+        # 3 epochs * 1,000 batches/epoch = 3,000 projected batches; 3,000 * 0.05 = 150
+        self.assertEqual(self.callback.window_size, 150)
+        self.assertTrue(self.callback.window_size_set)

From fee774b70ce8635f60d7299fc3e82d9f13909042 Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Mon, 26 Aug 2024 21:18:29 -0700
Subject: [PATCH 81/83] Keep old deque

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 39e5e5a914..c17b033146 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -159,7 +159,7 @@ def _set_window_size(self, state: State) -> None:
             self.window_size,
             round(float(total_training_steps * _WINDOW_FRACTION)),
         )
-        self.loss_window = deque(maxlen=self.window_size)
+        self.loss_window = deque(self.loss_window, maxlen=self.window_size)
         log.info(f'Window size set to: {self.window_size}')
         self.window_size_set = True

From 3e2ecc68b263ce7fccb47eaf01fdbc1e7c5e485f Mon Sep 17 00:00:00 2001
From: Joyce Chen
Date: Tue, 27 Aug 2024 09:16:41 -0700
Subject: [PATCH 82/83] Cleanup

---
 llmfoundry/callbacks/kill_loss_spike_callback.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index c17b033146..4708d87598 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -64,11 +64,11 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.window_size_set = (window_size != _MIN_WINDOW_SIZE)
         self.window_size = window_size
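         # deque starts unbounded here; _set_window_size later rebuilds it with maxlen=window_size once the true window is known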
self.loss_window = deque() - self.user_defined_loss_cap = (loss_cap != _MAX_LOSS_CAP) self.loss_cap = loss_cap + self.window_size_set = (window_size != _MIN_WINDOW_SIZE) + self.loss_cap_set = (loss_cap != _MAX_LOSS_CAP) def _detect_loss_spike( self, @@ -100,7 +100,7 @@ def _detect_high_losses(self, current_step: int) -> bool: high_loss_count = sum( 1 for loss in self.loss_window if loss > self.loss_cap ) - is_high_loss = high_loss_count >= self.window_size / 2 + is_high_loss = (high_loss_count >= self.window_size / 2) if is_high_loss: log.info( @@ -183,8 +183,9 @@ def batch_end(self, state: State, logger: Logger) -> None: return # If user does not provide a loss cap, set loss cap to the maximum loss from the first loss window. Hard cap at loss=10. - if not self.user_defined_loss_cap and current_step == self.window_size: + if not self.loss_cap_set and current_step == self.window_size: self.loss_cap = min(max(self.loss_window), self.loss_cap) + self.loss_cap_set = True running_loss_avg = float(np.mean(self.loss_window)) log.info(f'Running loss average: {running_loss_avg}') From 4005a5cd357a36676b2dbd61e638d52f7f72e961 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Tue, 27 Aug 2024 15:42:22 -0700 Subject: [PATCH 83/83] Remove running loss avg log --- llmfoundry/callbacks/kill_loss_spike_callback.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 4708d87598..b0a92c85e5 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -188,7 +188,6 @@ def batch_end(self, state: State, logger: Logger) -> None: self.loss_cap_set = True running_loss_avg = float(np.mean(self.loss_window)) - log.info(f'Running loss average: {running_loss_avg}') if self._detect_loss_spike(train_loss, running_loss_avg): self._handle_loss_spike(logger, running_loss_avg)
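
With PATCH 83 applied, the callback's public surface is settled: log_only, patience, outlier_multiplier, window_size, and loss_cap. A minimal construction sketch against that final state follows; the commented-out Trainer hookup is illustrative and not taken from this series, while the argument names, defaults, and the *_set semantics come from the patches above.

    from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

    # Defaults after PATCH 83: log_only=True, patience=4, outlier_multiplier=2,
    # window_size=100 (_MIN_WINDOW_SIZE), loss_cap=10 (_MAX_LOSS_CAP).
    callback = KillLossSpike(
        log_only=False,        # raise LossSpikeError/HighLossError instead of only logging
        patience=4,            # consecutive outlier batches tolerated before acting
        outlier_multiplier=2,  # a batch is an outlier above 2x the running loss average
        window_size=200,       # any non-default value sets window_size_set=True and pins the window
        loss_cap=8.0,          # any non-default value sets loss_cap_set=True and pins the cap
    )

    # from composer import Trainer
    # trainer = Trainer(model=model, callbacks=[callback], max_duration='10ep')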