Window size and loss cap not specifiable
joyce-chen-uni committed Aug 25, 2024
1 parent c59e5be commit 1aeb46f
Showing 1 changed file with 16 additions and 25 deletions.
41 changes: 16 additions & 25 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -40,9 +40,9 @@ class KillLossSpike(Callback):
outlier_multiplier (int): The multiplier used to determine if a loss is an outlier. A loss is considered an
outlier if it is outlier_multiplier times greater than the mean of losses in
the current window. Default is 2.
window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches by default, with a minimum of 100 steps.
window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps.
loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses by default.
it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses.
Raises:
LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
@@ -54,17 +54,12 @@ def __init__(
log_only: bool = True,
patience: int = 4,
outlier_multiplier: float = 2,
window_size: int = None,
loss_cap: float = None,
):
self._enabled = (dist.get_global_rank() == 0)
# self._enabled = (dist.get_global_rank() == 0)
self.log_only = log_only
self.patience = patience
self.outlier_multiplier = outlier_multiplier
self.window_size = window_size
self.loss_cap = loss_cap
self.outlier_counter = 0
self.loss_window = deque(maxlen=self.window_size)

def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
# Train loss is an outlier
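
The body of detect_loss_spike is collapsed in this view. The docstring's rule (patience consecutive losses at or above outlier_multiplier times the running window average count as a spike) suggests a minimal sketch like the following; this is reconstructed from the docstring, not the actual llm-foundry body:

def detect_loss_spike(self, train_loss: float, running_loss_avg: float) -> bool:
    # Hypothetical sketch reconstructed from the docstring; the real body may differ.
    if train_loss >= self.outlier_multiplier * running_loss_avg:
        self.outlier_counter += 1  # another consecutive outlier batch
    else:
        self.outlier_counter = 0  # a normal batch resets the streak
    # More than `patience` consecutive outliers is treated as a spike.
    return self.outlier_counter > self.patience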
@@ -97,23 +92,19 @@ def detect_high_losses(self, current_step: int):
return False

def fit_start(self, state: State, logger: Logger) -> None:
#Set the window to a fraction of the total number of training batches, minimum 100.
if not self.window_size:
if state.max_duration.unit == TimeUnit.EPOCH:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.dataloader_len * state.max_duration.value / 20),
)
elif state.max_duration.unit == TimeUnit.BATCH:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.max_duration.value / 20),
)
elif state.max_duration.unit == TimeUnit.TOKEN:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.max_duration.value / 20),
)
# Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
if state.max_duration.unit == TimeUnit.EPOCH:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.dataloader_len * state.max_duration.value / 20),
)
elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
self.window_size = max(
MIN_WINDOW_SIZE,
round(state.max_duration.value / 20),
)
self.loss_window = deque(maxlen=self.window_size)
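
Concretely, the new sizing rule works out as follows; the run lengths here are illustrative, not from the commit:

MIN_WINDOW_SIZE = 100  # module-level floor referenced in fit_start
# 2-epoch run at 5,000 batches per epoch: 1/20 of 10,000 total batches
max(MIN_WINDOW_SIZE, round(5_000 * 2 / 20))  # 500
# Short 1,000-batch run: 1/20 is only 50, so the floor applies
max(MIN_WINDOW_SIZE, round(1_000 / 20))  # 100

Rebuilding self.loss_window here also means the deque evicts the oldest loss automatically once window_size batches have been recorded.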


def batch_end(self, state: State, logger: Logger) -> None:

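With this change, window_size and loss_cap are no longer constructor arguments, so enabling the callback only involves the remaining knobs. A hypothetical usage sketch follows; the direct import path and Trainer wiring are assumptions based on this file's location, and llm-foundry runs typically attach callbacks through their YAML config instead:

from composer import Trainer
from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

kill_loss_spike = KillLossSpike(
    log_only=True,  # log detections instead of raising LossSpikeError
    patience=4,  # consecutive outlier batches tolerated
    outlier_multiplier=2,  # outlier = loss at least 2x the window average
)

trainer = Trainer(
    model=model,  # placeholder: your ComposerModel
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration='1ep',
    callbacks=[kill_loss_spike],
)
trainer.fit()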
