From a25a19c1e054da52a35a5226cc46e96f20c1840c Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Sun, 25 Aug 2024 23:36:55 -0700 Subject: [PATCH] Cleanup window frac --- .../callbacks/kill_loss_spike_callback.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 9cc1f6204d..e2a8c77545 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -21,7 +21,7 @@ MIN_WINDOW_SIZE = 100 MAX_LOSS_CAP = 10 -WINDOW_FRACTION = 20 +WINDOW_FRACTION = 0.05 @experimental_class('KillLossSpike') @@ -65,7 +65,7 @@ def __init__( self.loss_window = deque(maxlen=self.window_size) self.loss_cap = MAX_LOSS_CAP - def _detect_loss_spike(self, train_loss: float, running_loss_avg: float): + def _detect_loss_spike(self, train_loss: float, running_loss_avg: float) -> bool: # Train loss is an outlier if train_loss >= running_loss_avg * self.outlier_multiplier: self.outlier_counter += 1 @@ -83,17 +83,20 @@ def _detect_loss_spike(self, train_loss: float, running_loss_avg: float): self.outlier_counter = 0 return False - def _detect_high_losses(self, current_step: int): + def _detect_high_losses(self, current_step: int) -> bool: + if current_step < self.window_size * 2: + return False + # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period - if (current_step >= self.window_size * 2) and ( - sum(1 for loss in self.loss_window if loss > self.loss_cap) >= - self.window_size / 2 - ): + high_loss_count = sum(1 for loss in self.loss_window if loss > self.loss_cap) + is_high_loss = high_loss_count >= self.window_size / 2 + + if is_high_loss: log.info( - f'High losses (train loss consistently greater than {self.loss_cap}) detected.', + f'High losses detected: {high_loss_count}/{self.window_size} losses above {self.loss_cap}.', ) - return True - return False + + return is_high_loss def _log_metadata(self, logger: Logger, key: str, message: str) -> None: for destination in logger.destinations: @@ -134,15 +137,14 @@ def fit_start(self, state: State, logger: Logger) -> None: self.window_size, round( float( - state.dataloader_len * state.max_duration.value / - 20, + state.dataloader_len * state.max_duration.value * WINDOW_FRACTION, ), ), ) elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN: self.window_size = max( self.window_size, - round(float(state.max_duration.value / 20)), + round(float(state.max_duration.value * WINDOW_FRACTION)), ) self.loss_window = deque(maxlen=self.window_size)