diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 33786fc0de..b2a9fc20dd 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -59,9 +59,9 @@ def __init__( self.patience = patience self.outlier_multiplier = outlier_multiplier self.outlier_counter = 0 - self.window_size = None - self.loss_window = None - self.loss_cap = None + self.window_size = MIN_WINDOW_SIZE + self.loss_window = deque(maxlen=self.window_size) + self.loss_cap = float('inf') def detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier