Commit
Cleanup window frac
joyce-chen-uni committed Aug 26, 2024
1 parent 55dd5bb commit 44787df
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -21,7 +21,7 @@
 
 MIN_WINDOW_SIZE = 100
 MAX_LOSS_CAP = 10
-WINDOW_FRACTION = 20
+WINDOW_FRACTION = 0.05
 
 
 @experimental_class('KillLossSpike')
@@ -65,7 +65,9 @@ def __init__(
         self.loss_window = deque(maxlen=self.window_size)
         self.loss_cap = MAX_LOSS_CAP
 
-    def _detect_loss_spike(self, train_loss: float, running_loss_avg: float):
+    def _detect_loss_spike(
+        self, train_loss: float, running_loss_avg: float
+    ) -> bool:
         # Train loss is an outlier
         if train_loss >= running_loss_avg * self.outlier_multiplier:
             self.outlier_counter += 1
@@ -83,17 +85,22 @@ def _detect_loss_spike(self, train_loss: float, running_loss_avg: float):
             self.outlier_counter = 0
         return False
 
-    def _detect_high_losses(self, current_step: int):
+    def _detect_high_losses(self, current_step: int) -> bool:
+        if current_step < self.window_size * 2:
+            return False
+
         # Half of the running losses are greater than our "high loss" threshold, after an initial buffer period
-        if (current_step >= self.window_size * 2) and (
-            sum(1 for loss in self.loss_window if loss > self.loss_cap) >=
-            self.window_size / 2
-        ):
+        high_loss_count = sum(
+            1 for loss in self.loss_window if loss > self.loss_cap
+        )
+        is_high_loss = high_loss_count >= self.window_size / 2
+
+        if is_high_loss:
             log.info(
-                f'High losses (train loss consistently greater than {self.loss_cap}) detected.',
+                f'High losses detected: {high_loss_count}/{self.window_size} losses above {self.loss_cap}.',
             )
-            return True
-        return False
+
+        return is_high_loss
 
     def _log_metadata(self, logger: Logger, key: str, message: str) -> None:
         for destination in logger.destinations:
@@ -134,15 +141,15 @@ def fit_start(self, state: State, logger: Logger) -> None:
                 self.window_size,
                 round(
                     float(
-                        state.dataloader_len * state.max_duration.value /
-                        20,
+                        state.dataloader_len * state.max_duration.value *
+                        WINDOW_FRACTION,
                     ),
                 ),
             )
         elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
             self.window_size = max(
                 self.window_size,
-                round(float(state.max_duration.value / 20)),
+                round(float(state.max_duration.value * WINDOW_FRACTION)),
             )
         self.loss_window = deque(maxlen=self.window_size)

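The _detect_loss_spike hunk above only tightens the signature and adds a return type; the spike rule itself is visible in the surrounding context lines: consecutive outlier steps (loss at or above outlier_multiplier times the running average) increment a counter, and any normal step resets it. A condensed sketch of that rule, assuming hypothetical patience and outlier_multiplier defaults (neither value appears in this diff):

    def detect_loss_spike(
        train_loss: float,
        running_loss_avg: float,
        outlier_counter: int,
        outlier_multiplier: float = 2.0,  # assumed default, not from the diff
        patience: int = 4,  # assumed default, not from the diff
    ) -> tuple[bool, int]:
        # An outlier step increments the counter; a normal step resets it.
        if train_loss >= running_loss_avg * outlier_multiplier:
            outlier_counter += 1
            return outlier_counter > patience, outlier_counter
        return False, 0

    spike, counter = False, 0
    for loss in [9.0, 9.5, 9.8, 9.9, 10.0]:  # running average taken as 2.0
        spike, counter = detect_loss_spike(loss, 2.0, counter)
    print(spike)  # True: five consecutive outlier steps exceed patience=4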
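Similarly, the refactored _detect_high_losses body can be exercised on its own. A minimal standalone approximation (hypothetical detect_high_losses helper mirroring the new method) of the majority-above-cap rule:

    from collections import deque

    MAX_LOSS_CAP = 10

    def detect_high_losses(loss_window: deque, window_size: int, current_step: int) -> bool:
        # Wait out the initial buffer period before the window is representative.
        if current_step < window_size * 2:
            return False
        # Flag when at least half of the recent losses exceed the cap.
        high_loss_count = sum(1 for loss in loss_window if loss > MAX_LOSS_CAP)
        return high_loss_count >= window_size / 2

    losses = deque([11.0, 12.0, 3.0, 14.0, 2.0], maxlen=5)
    print(detect_high_losses(losses, window_size=5, current_step=12))  # True: 3/5 above cap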

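The fit_start hunk swaps the magic divisor 20 for the named constant; since 0.05 equals 1/20, the computed window size is unchanged. A rough sketch of the sizing rule, assuming the window starts at the MIN_WINDOW_SIZE floor and max duration is measured in batches:

    MIN_WINDOW_SIZE = 100
    WINDOW_FRACTION = 0.05

    def compute_window_size(max_duration_batches: int) -> int:
        # Window covers 5% of the run, but never shrinks below the floor.
        return max(MIN_WINDOW_SIZE, round(max_duration_batches * WINDOW_FRACTION))

    print(compute_window_size(10_000))  # 500
    print(compute_window_size(1_000))   # 100 (floor applies)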