Calculate the loss window once hit min loss window size
joyce-chen-uni committed Aug 27, 2024
1 parent 782316b commit 184311a
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -64,9 +64,9 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.user_defined_window_size = (window_size != _MIN_WINDOW_SIZE)
+        self.window_size_set = (window_size != _MIN_WINDOW_SIZE)
         self.window_size = window_size
-        self.loss_window = deque(maxlen=self.window_size)
+        self.loss_window = deque()
         self.user_defined_loss_cap = (loss_cap != _MAX_LOSS_CAP)
         self.loss_cap = loss_cap
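A side note on the deque change: a deque's maxlen is fixed at construction and silently evicts its oldest entries once full, so the window is now created unbounded and only re-created with a maxlen once the final window size is known (in _set_window_size below). A minimal standalone sketch of that behavior, not taken from the repository:

from collections import deque

bounded = deque(maxlen=3)            # evicts the oldest entry once full
for loss in [2.0, 1.8, 1.7, 1.6]:
    bounded.append(loss)
print(list(bounded))                 # [1.8, 1.7, 1.6]

unbounded = deque()                  # keeps everything until replaced later
for loss in [2.0, 1.8, 1.7, 1.6]:
    unbounded.append(loss)
print(list(unbounded))               # [2.0, 1.8, 1.7, 1.6]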

@@ -140,21 +140,28 @@ def _handle_high_losses(self, logger: Logger) -> None:
                 window_size=self.window_size,
             )
 
-    def fit_start(self, state: State, logger: Logger) -> None:
-        # If user does not provide a window size, set window size to a fraction of the total number of training batches for the run, minimum 100 batches.
-        if not self.user_defined_window_size:
-            total_training_steps = 0
-            if state.max_duration is not None:
-                if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
-                    total_training_steps = state.dataloader_len * state.max_duration.value
-                elif state.max_duration.unit == TimeUnit.BATCH:
-                    total_training_steps = state.max_duration.value
-            self.window_size = max(
-                self.window_size,
-                round(float(total_training_steps * _WINDOW_FRACTION)),
-            )
+    def _set_window_size(self, state: State) -> None:
+        total_training_steps = 0
+        current_step = int(state.timestamp.batch)
+        current_token = int(state.timestamp.token)
+
+        if state.max_duration is not None:
+            if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
+                total_training_steps = state.dataloader_len * state.max_duration.value
+            elif state.max_duration.unit == TimeUnit.BATCH:
+                total_training_steps = state.max_duration.value
+            elif state.max_duration.unit == TimeUnit.TOKEN:
+                # This is an approximation of the total batches from the total tokens, assuming the ratio of tokens:batch is constant.
+                total_training_steps = current_step * (
+                    state.max_duration.value / current_token
+                )
+        self.window_size = max(
+            self.window_size,
+            round(float(total_training_steps * _WINDOW_FRACTION)),
+        )
+        self.loss_window = deque(maxlen=self.window_size)
+        log.info(f'Window size set to: {self.window_size}')
+        self.window_size_set = True
 
     def batch_end(self, state: State, logger: Logger) -> None:
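For a token-based max_duration, the new branch scales the current batch count by max_tokens / current_tokens to estimate the total number of batches for the run, and the window becomes a fraction of that, never below the 100-batch minimum. A rough standalone sketch of the arithmetic with made-up numbers; the value of _WINDOW_FRACTION is assumed here for illustration, only the 100-batch minimum is stated in the code comments:

_MIN_WINDOW_SIZE = 100            # minimum window, per the code comments
_WINDOW_FRACTION = 0.05           # assumed value, for illustration only

current_step = 500                # batches completed so far
current_token = 1_000_000         # tokens consumed so far
max_duration_tokens = 20_000_000  # hypothetical token budget for the run

# Assume tokens per batch stays constant for the rest of the run.
total_training_steps = current_step * (max_duration_tokens / current_token)
window_size = max(
    _MIN_WINDOW_SIZE,
    round(float(total_training_steps * _WINDOW_FRACTION)),
)
print(total_training_steps, window_size)  # 10000.0 500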

@@ -166,6 +173,15 @@ def batch_end(self, state: State, logger: Logger) -> None:
         if len(self.loss_window) == self.window_size:
 
             current_step = int(state.timestamp.batch)
+
+            # If window size has not yet been set either by user or during run, set window size to a fraction of the total training duration. Minimum 100 batches.
+            if not self.window_size_set:
+                self._set_window_size(state)
+                # Window size has been expanded -- keep adding losses until we reach the window size.
+                if self.window_size > _MIN_WINDOW_SIZE:
+                    self.loss_window.append(train_loss)
+                    return
+
             # If user does not provide a loss cap, set loss cap to the maximum loss from the first loss window. Hard cap at loss=10.
             if not self.user_defined_loss_cap and current_step == self.window_size:
                 self.loss_cap = min(max(self.loss_window), self.loss_cap)
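Since loss_cap defaults to _MAX_LOSS_CAP (a hard cap of 10, per the comment), the min(max(...), ...) line can only lower the cap to whatever the first full window actually produced. A tiny illustration with made-up numbers:

loss_cap = 10.0                      # assumed _MAX_LOSS_CAP default
first_window = [2.3, 2.1, 7.2, 2.0]  # hypothetical losses in the first full window
loss_cap = min(max(first_window), loss_cap)
print(loss_cap)                      # 7.2 -- the window max, since it is below the hard cap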

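The fields outlier_multiplier, outlier_counter, and patience suggest how the filled window is then consumed: compare each new loss to a multiple of the window's running mean, count consecutive outliers, and stop the run once patience is exceeded. A generic sketch of that pattern, assuming this shape of the check rather than quoting the exact logic in this file:

from collections import deque

def is_loss_spike(loss_window: deque, train_loss: float, outlier_multiplier: float) -> bool:
    # Flag the current loss as an outlier relative to the recent-loss window.
    running_mean = sum(loss_window) / len(loss_window)
    return train_loss > outlier_multiplier * running_mean

window = deque([2.1, 2.0, 1.9, 2.0], maxlen=100)
print(is_loss_spike(window, 4.5, outlier_multiplier=2.0))  # True: 4.5 > 2.0 * 2.0
print(is_loss_spike(window, 2.2, outlier_multiplier=2.0))  # False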