diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index c0c27320ce..6403ac76c2 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -40,9 +40,9 @@ class KillLossSpike(Callback):
         outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an
             outlier if it is outlier_multiplier times greater than the mean of losses in the current window. Default is 2.
-        window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches by default, with a minimum of 100 steps.
+        window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps.
         loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
-            it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses by default.
+            it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses.

     Raises:
         LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
@@ -54,17 +54,12 @@ def __init__(
         log_only: bool = True,
         patience: int = 4,
         outlier_multiplier: float = 2,
-        window_size: int = None,
-        loss_cap: float = None,
     ):
-        self._enabled = (dist.get_global_rank() == 0)
+        # self._enabled = (dist.get_global_rank() == 0)
         self.log_only = log_only
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
-        self.window_size = window_size
-        self.loss_cap = loss_cap
         self.outlier_counter = 0
-        self.loss_window = deque(maxlen=self.window_size)

     def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
         # Train loss is an outlier
@@ -97,23 +92,19 @@ def detect_high_losses(self, current_step: int):
         return False

     def fit_start(self, state: State, logger: Logger) -> None:
-        #Set the window to a fraction of the total number of training batches, minimum 100.
-        if not self.window_size:
-            if state.max_duration.unit == TimeUnit.EPOCH:
-                self.window_size = max(
-                    MIN_WINDOW_SIZE,
-                    round(state.dataloader_len * state.max_duration.value / 20),
-                )
-            elif state.max_duration.unit == TimeUnit.BATCH:
-                self.window_size = max(
-                    MIN_WINDOW_SIZE,
-                    round(state.max_duration.value / 20),
-                )
-            elif state.max_duration.unit == TimeUnit.TOKEN:
-                self.window_size = max(
-                    MIN_WINDOW_SIZE,
-                    round(state.max_duration.value / 20),
-                )
+        # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
+        if state.max_duration.unit == TimeUnit.EPOCH:
+            self.window_size = max(
+                MIN_WINDOW_SIZE,
+                round(state.dataloader_len * state.max_duration.value / 20),
+            )
+        elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
+            self.window_size = max(
+                MIN_WINDOW_SIZE,
+                round(state.max_duration.value / 20),
+            )
+        self.loss_window = deque(maxlen=self.window_size)

     def batch_end(self, state: State, logger: Logger) -> None:
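
For context, here is a minimal standalone sketch of the window-size rule the reworked fit_start applies (1/20 of the run's total training batches, floored at 100). Because the window can only be sized once state.max_duration is known, the loss_window deque now moves from __init__ to fit_start. The helper compute_window_size and the standalone MIN_WINDOW_SIZE constant below are illustrative only and are not part of this diff or of llm-foundry's API:

from collections import deque

MIN_WINDOW_SIZE = 100  # illustrative stand-in for the callback's module-level minimum

def compute_window_size(total_train_batches: int) -> int:
    # Window is 1/20 of the run's total batches, but never smaller than the minimum.
    return max(MIN_WINDOW_SIZE, round(total_train_batches / 20))

# A 10,000-batch run gets a 500-batch window; a 1,000-batch run hits the 100-batch floor.
assert compute_window_size(10_000) == 500
assert compute_window_size(1_000) == 100

# fit_start then bounds the rolling loss history to that many recent batches.
loss_window = deque(maxlen=compute_window_size(10_000))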