From 45eaa0ee8c1f3aedcba87ae4e4c5d447a948eaee Mon Sep 17 00:00:00 2001 From: Joyce Chen Date: Wed, 21 Aug 2024 14:32:46 -0700 Subject: [PATCH] Remove hardcoded const --- llmfoundry/callbacks/kill_loss_spike_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py index 987926cbcd..654e6ffc05 100644 --- a/llmfoundry/callbacks/kill_loss_spike_callback.py +++ b/llmfoundry/callbacks/kill_loss_spike_callback.py @@ -54,7 +54,7 @@ def batch_end(self, state: State, logger: Logger) -> None: self.outlier_counter = 0 # Half of the running losses are greater than our "high loss" threshold, after the first window - elif state.timestamp.batch >= 200 and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2): + elif (state.timestamp.batch >= self.window_size * 2) and (sum(1 for loss in self.loss_window if loss > self.loss_cap) >= self.window_size / 2): log.info(f'High losses >{self.loss_cap} detected.') for destination in logger.destinations: if isinstance(destination, MosaicMLLogger):