From e3c00b35fb011b949c3e1c787168c340a91961e0 Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 16:16:52 -0700 Subject: [PATCH] init vals for window size, loss window, loss cap --- .../callbacks/kill_loss_spike_callback.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 33786fc0de..97557cb123 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -59,9 +59,9 @@ def __init__( self.patience = patience self.outlier_multiplier = outlier_multiplier self.outlier_counter = 0 - self.window_size = None - self.loss_window = None - self.loss_cap = None + self.window_size = MIN_WINDOW_SIZE + self.loss_window = deque(maxlen=self.window_size) + self.loss_cap = float('inf') def detect_loss_spike(self, train_loss: float, running_loss_avg: float): # Train loss is an outlier @@ -95,16 +95,17 @@ def detect_high_losses(self, current_step: int): def fit_start(self, state: State, logger: Logger) -> None: # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches. - if state.max_duration.unit == TimeUnit.EPOCH: - self.window_size = max( - MIN_WINDOW_SIZE, - round(float(state.dataloader_len * state.max_duration.value / 20)), - ) - elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: - self.window_size = max( - MIN_WINDOW_SIZE, - round(float(state.max_duration.value / 20)), - ) + if state.max_duration is not None: + if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None: + self.window_size = max( + MIN_WINDOW_SIZE, + round(float(state.dataloader_len * state.max_duration.value / 20)), + ) + elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: + self.window_size = max( + MIN_WINDOW_SIZE, + round(float(state.max_duration.value / 20)), + ) self.loss_window = deque(maxlen=self.window_size) def batch_end(self, state: State, logger: Logger) -> None: