init vals for window size, loss window, loss cap
joyce-chen-uni committed Aug 25, 2024
1 parent 1b5a757 commit e56b797
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -59,9 +59,9 @@ def __init__(
         self.patience = patience
         self.outlier_multiplier = outlier_multiplier
         self.outlier_counter = 0
-        self.window_size = None
-        self.loss_window = None
-        self.loss_cap = None
+        self.window_size = MIN_WINDOW_SIZE
+        self.loss_window = deque(maxlen=self.window_size)
+        self.loss_cap = float('inf')
 
     def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
         # Train loss is an outlier
@@ -95,16 +95,17 @@ def detect_high_losses(self, current_step: int):
 
     def fit_start(self, state: State, logger: Logger) -> None:
         # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
-        if state.max_duration.unit == TimeUnit.EPOCH:
-            self.window_size = max(
-                MIN_WINDOW_SIZE,
-                round(float(state.dataloader_len * state.max_duration.value / 20)),
-            )
-        elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
-            self.window_size = max(
-                MIN_WINDOW_SIZE,
-                round(float(state.max_duration.value / 20)),
-            )
+        if state.max_duration is not None:
+            if state.max_duration.unit == TimeUnit.EPOCH:
+                self.window_size = max(
+                    MIN_WINDOW_SIZE,
+                    round(float(state.dataloader_len * state.max_duration.value / 20)),
+                )
+            elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
+                self.window_size = max(
+                    MIN_WINDOW_SIZE,
+                    round(float(state.max_duration.value / 20)),
+                )
         self.loss_window = deque(maxlen=self.window_size)
 
     def batch_end(self, state: State, logger: Logger) -> None:
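In short, the callback's rolling-loss state now gets safe defaults at construction time instead of None placeholders, and fit_start only refines them when state.max_duration is known. A minimal sketch of the resulting behavior, assuming MIN_WINDOW_SIZE = 100 (the diff's comment says "minimum 100 batches"; the actual constant is defined elsewhere in the file) and a hypothetical run length of 4000 batches:

    from collections import deque

    MIN_WINDOW_SIZE = 100  # assumed to match the module-level constant referenced in the diff

    # Construction-time defaults (new behavior): usable even if fit_start
    # never refines them, e.g. when state.max_duration is None.
    window_size = MIN_WINDOW_SIZE
    loss_window = deque(maxlen=window_size)  # fixed-size buffer; oldest loss evicted automatically
    loss_cap = float('inf')                  # effectively "no cap" until one is computed

    # fit_start-style refinement for a hypothetical run of 4000 batches:
    max_duration_batches = 4000
    window_size = max(MIN_WINDOW_SIZE, round(max_duration_batches / 20))
    loss_window = deque(maxlen=window_size)

    print(window_size)  # 200: 1/20 of the run, floored at MIN_WINDOW_SIZE

The deque(maxlen=...) choice makes the window self-maintaining: once window_size losses have accumulated, appending a new loss silently evicts the oldest, so the rolling average always reflects the most recent window_size batches.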
