Adjust tests: no window size / loss cap inputs
joyce-chen-uni committed Aug 25, 2024
1 parent 1aeb46f commit 1b5a757
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -11,7 +11,6 @@
import torch
from composer.core import Callback, State, TimeUnit
from composer.loggers import Logger, MosaicMLLogger
from composer.utils import dist

from llmfoundry.utils.exceptions import HighLossError, LossSpikeError
from llmfoundry.utils.warnings import experimental_class
@@ -60,6 +59,9 @@ def __init__(
self.patience = patience
self.outlier_multiplier = outlier_multiplier
self.outlier_counter = 0
self.window_size = None
self.loss_window = None
self.loss_cap = None

def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
# Train loss is an outlier
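
The hunk above adds window_size, loss_window, and loss_cap as attributes initialized to None, deferring their values until training starts; detect_loss_spike then works off the patience, outlier_multiplier, and outlier_counter fields set in the constructor. The stand-alone sketch below only illustrates the spike-detection pattern those names suggest; the callback's actual method body is collapsed in this view and may differ.

def detect_loss_spike_sketch(
    train_loss: float,
    running_loss_avg: float,
    outlier_counter: int,
    outlier_multiplier: float = 2.0,
    patience: int = 4,
) -> tuple[bool, int]:
    # Hypothetical logic: a batch counts as an outlier when its loss exceeds
    # the running average by the multiplier; a spike is reported once more
    # than `patience` consecutive outliers have been seen.
    if train_loss >= running_loss_avg * outlier_multiplier:
        outlier_counter += 1
    else:
        outlier_counter = 0
    return outlier_counter > patience, outlier_counter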
@@ -96,15 +98,14 @@ def fit_start(self, state: State, logger: Logger) -> None:
if state.max_duration.unit == TimeUnit.EPOCH:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.dataloader_len * state.max_duration.value / 20),
round(float(state.dataloader_len * state.max_duration.value / 20)),
)
elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.max_duration.value / 20),
round(float(state.max_duration.value / 20)),
)
self.loss_window = deque(maxlen=self.window_size)

self.loss_window = deque(maxlen=self.window_size)

def batch_end(self, state: State, logger: Logger) -> None:

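In the fit_start hunk above, the window-size arithmetic is wrapped in float(...) before round(...), presumably so round receives a plain Python number, and the loss deque is sized from the computed window. A minimal sketch of that heuristic, keeping only the divide-by-20 rule and the MIN_WINDOW_SIZE floor visible in the changed lines (the constant's value of 100 below is an assumption, and compute_window_size is an illustrative helper, not part of the module):

from collections import deque

MIN_WINDOW_SIZE = 100  # assumed value; the real constant lives in the callback module

def compute_window_size(total_steps: float) -> int:
    # Roughly one twentieth of the run, floored at the minimum window size,
    # mirroring the changed lines in fit_start.
    return max(MIN_WINDOW_SIZE, round(float(total_steps / 20)))

# Example: a 1000-batch run gives max(100, 50) == 100, so the rolling window
# keeps the last 100 batch losses.
window_size = compute_window_size(1000)
loss_window = deque(maxlen=window_size)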
4 changes: 2 additions & 2 deletions tests/callbacks/test_kill_loss_spike_callback.py
@@ -21,9 +21,9 @@ def __init__(self, *args: str, **kwargs: dict):
log_only=True,
patience=4,
outlier_multiplier=2,
window_size=10,
loss_cap=10,
)
self.callback.window_size = 10
self.callback.loss_cap = 10

@patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
def test_detect_loss_spike_no_spike(self, _):
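The test change is what the commit title refers to: the callback constructor no longer accepts window_size or loss_cap, so the test builds the callback without them and assigns the attributes directly afterwards. A sketch of the adjusted setup (the class name KillLossSpike is inferred from the module path and may not be exact; the keyword arguments mirror the diff):

from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

callback = KillLossSpike(
    log_only=True,
    patience=4,
    outlier_multiplier=2,
)
# Set directly on the instance, standing in for the values fit_start would
# normally compute from the training run's duration.
callback.window_size = 10
callback.loss_cap = 10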
