diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 0038c65439..b93506dd43 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -17,13 +17,13 @@ class KillLossSpike(Callback): """ - A callback for detecting and handling loss spikes or persistently high training losses during model training. + This callback detects and handles loss spikes or persistently high training losses during model training. Monitors the training loss at the end of each batch and maintains a rolling window of recent losses. If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry. - Parameters: + Args: log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a LossSpikeError will be raised to stop training upon detecting a loss spike or persistently high loss.