From 39fdbfb40900ee37f9cb08e3577e1e667d8ab118 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Wed, 21 Aug 2024 16:02:09 -0700 Subject: [PATCH] edit docstring --- llmfoundry/callbacks/kill_loss_spike_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 0038c65439..b93506dd43 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -17,13 +17,13 @@ class KillLossSpike(Callback): """ - A callback for detecting and handling loss spikes or persistently high training losses during model training. + This callback detects and handles loss spikes or persistently high training losses during model training. Monitors the training loss at the end of each batch and maintains a rolling window of recent losses. If recent training losses exceed a specified cap or if a significant spike in loss is detected, the callback can either log a warning (displayed as a message on the run event) or raise a LossSpikeError to stop the run without retry. - Parameters: + Args: log_only (bool): If True, the callback will only log warnings without interrupting training. If False, a LossSpikeError will be raised to stop training upon detecting a loss spike or persistently high loss.