diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 12b2abaed9..3e3877d9a8 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -64,9 +64,9 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.user_defined_window_size = (window_size != _MIN_WINDOW_SIZE)
+        self.window_size_set = (window_size != _MIN_WINDOW_SIZE)
         self.window_size = window_size
-        self.loss_window = deque(maxlen=self.window_size)
+        self.loss_window = deque()
         self.user_defined_loss_cap = (loss_cap != _MAX_LOSS_CAP)
         self.loss_cap = loss_cap
@@ -140,21 +140,28 @@ def _handle_high_losses(self, logger: Logger) -> None:
             window_size=self.window_size,
         )
 
-    def fit_start(self, state: State, logger: Logger) -> None:
-        # If the user does not provide a window size, set it to a fraction of the total number of training batches for the run, minimum 100 batches.
-        if not self.user_defined_window_size:
-            total_training_steps = 0
-            if state.max_duration is not None:
-                if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
-                    total_training_steps = state.dataloader_len * state.max_duration.value
-                elif state.max_duration.unit == TimeUnit.BATCH:
-                    total_training_steps = state.max_duration.value
-            self.window_size = max(
-                self.window_size,
-                round(float(total_training_steps * _WINDOW_FRACTION)),
+    def _set_window_size(self, state: State) -> None:
+        total_training_steps = 0
+        current_step = int(state.timestamp.batch)
+        current_token = int(state.timestamp.token)
+
+        if state.max_duration is not None:
+            if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
+                total_training_steps = state.dataloader_len * state.max_duration.value
+            elif state.max_duration.unit == TimeUnit.BATCH:
+                total_training_steps = state.max_duration.value
+            elif state.max_duration.unit == TimeUnit.TOKEN:
+                # This approximates the total batches from the total tokens, assuming the tokens:batch ratio is constant.
+                total_training_steps = current_step * (
+                    state.max_duration.value / current_token
             )
+        self.window_size = max(
+            self.window_size,
+            round(float(total_training_steps * _WINDOW_FRACTION)),
+        )
         self.loss_window = deque(maxlen=self.window_size)
         log.info(f'Window size set to: {self.window_size}')
+        self.window_size_set = True
 
     def batch_end(self, state: State, logger: Logger) -> None:
@@ -166,6 +173,15 @@ def batch_end(self, state: State, logger: Logger) -> None:
 
         if len(self.loss_window) == self.window_size:
            current_step = int(state.timestamp.batch)
+
+            # If the window size has not yet been set, either by the user or during the run, set it to a fraction of the total training duration, minimum 100 batches.
+            if not self.window_size_set:
+                self._set_window_size(state)
+                # Window size has been expanded -- keep adding losses until we reach the window size.
+                if self.window_size > _MIN_WINDOW_SIZE:
+                    self.loss_window.append(train_loss)
+                    return
+
             # If the user does not provide a loss cap, set the loss cap to the maximum loss from the first loss window. Hard cap at loss=10.
             if not self.user_defined_loss_cap and current_step == self.window_size:
                 self.loss_cap = min(max(self.loss_window), self.loss_cap)
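
Note on the `__init__` change: a `deque`'s `maxlen` is fixed at construction, so the window cannot be resized in place. That is why the patch starts with a bare `deque()` and rebuilds it (empty) in `_set_window_size` once the final size is known, refilling it in `batch_end`. A minimal sketch of the relevant `collections.deque` semantics, with hypothetical loss values:

    from collections import deque

    # Unbounded while the window size is still undecided.
    loss_window = deque()
    loss_window.extend([2.5, 2.4, 2.3])  # hypothetical per-batch losses

    # maxlen is read-only, so resizing means constructing a new deque.
    window_size = 2
    loss_window = deque(maxlen=window_size)  # starts empty, as in the patch

    for loss in (2.2, 2.1, 2.0):
        loss_window.append(loss)  # once full, the oldest loss is evicted from the left

    print(list(loss_window))  # [2.1, 2.0]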
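
The new `TimeUnit.TOKEN` branch estimates total batches by scaling the batches seen so far by the ratio of the token budget to the tokens seen so far. A worked sketch with made-up numbers; only the formula comes from the patch, and the constants here, including the `_WINDOW_FRACTION` value, are illustrative:

    _WINDOW_FRACTION = 0.05  # illustrative; the real constant is defined in the callback module
    _MIN_WINDOW_SIZE = 100   # "minimum 100 batches", per the removed fit_start comment

    current_step = 2_000                # batches completed when _set_window_size runs
    current_token = 4_096_000           # tokens consumed so far (~2,048 tokens/batch)
    max_duration_tokens = 100_000_000   # max_duration expressed in tokens

    # Assumes the tokens:batch ratio stays constant for the rest of the run.
    total_training_steps = current_step * (max_duration_tokens / current_token)  # ~48,828

    window_size = max(_MIN_WINDOW_SIZE, round(total_training_steps * _WINDOW_FRACTION))
    print(window_size)  # 2441 with these numbers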
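
Unchanged but worth noting given the new control flow: once the (possibly expanded) window first fills, the loss cap is still derived from the observed losses, clamped by the hard cap. The arithmetic, with a hypothetical window:

    _MAX_LOSS_CAP = 10  # per the "Hard cap at loss=10" comment

    loss_window = [2.5, 3.1, 11.7, 2.8]  # hypothetical first full window of losses
    loss_cap = min(max(loss_window), _MAX_LOSS_CAP)
    print(loss_cap)  # 10 -- the observed max (11.7) exceeds the hard cap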