diff --git a/llmfoundry/callbacks/kill_loss_spike_callback.py b/llmfoundry/callbacks/kill_loss_spike_callback.py
index 6403ac76c2..33786fc0de 100644
--- a/llmfoundry/callbacks/kill_loss_spike_callback.py
+++ b/llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -11,7 +11,6 @@
 import torch
 from composer.core import Callback, State, TimeUnit
 from composer.loggers import Logger, MosaicMLLogger
-from composer.utils import dist
 
 from llmfoundry.utils.exceptions import HighLossError, LossSpikeError
 from llmfoundry.utils.warnings import experimental_class
@@ -60,6 +59,9 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
+        self.window_size = None
+        self.loss_window = None
+        self.loss_cap = None
 
     def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
         # Train loss is an outlier
@@ -96,14 +98,13 @@ def fit_start(self, state: State, logger: Logger) -> None:
         if state.max_duration.unit == TimeUnit.EPOCH:
             self.window_size = max(
                 MIN_WINDOW_SIZE,
-                round(state.dataloader_len * state.max_duration.value / 20),
+                round(float(state.dataloader_len * state.max_duration.value / 20)),
             )
         elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
             self.window_size = max(
                 MIN_WINDOW_SIZE,
-                round(state.max_duration.value / 20),
+                round(float(state.max_duration.value / 20)),
             )
-        self.loss_window = deque(maxlen=self.window_size)
-
+        self.loss_window = deque(maxlen=self.window_size)
 
     def batch_end(self, state: State, logger: Logger) -> None:
diff --git a/tests/callbacks/test_kill_loss_spike_callback.py b/tests/callbacks/test_kill_loss_spike_callback.py
index 72cf6166c5..f99a48b26d 100644
--- a/tests/callbacks/test_kill_loss_spike_callback.py
+++ b/tests/callbacks/test_kill_loss_spike_callback.py
@@ -21,9 +21,9 @@ def __init__(self, *args: str, **kwargs: dict):
             log_only=True,
             patience=4,
             outlier_multiplier=2,
-            window_size=10,
-            loss_cap=10,
         )
+        self.callback.window_size = 10
+        self.callback.loss_cap = 10
 
     @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
     def test_detect_loss_spike_no_spike(self, _):