Skip to content

Commit

Permalink
Round fractional window size
Browse files Browse the repository at this point in the history
  • Loading branch information
joyce-chen-uni committed Aug 25, 2024
1 parent 57800e6 commit c59e5be
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,17 @@ def fit_start(self, state: State, logger: Logger) -> None:
if state.max_duration.unit == TimeUnit.EPOCH:
self.window_size = max(
MIN_WINDOW_SIZE,
(state.dataloader_len * state.max_duration.value / 20),
round(state.dataloader_len * state.max_duration.value / 20),
)
elif state.max_duration.unit == TimeUnit.BATCH:
self.window_size = max(
MIN_WINDOW_SIZE,
state.max_duration.value / 20,
round(state.max_duration.value / 20),
)
elif state.max_duration.unit == TimeUnit.TOKEN:
self.window_size = max(
MIN_WINDOW_SIZE,
state.max_duration.value / 20,
round(state.max_duration.value / 20),
)

def batch_end(self, state: State, logger: Logger) -> None:
Expand Down

0 comments on commit c59e5be

Please sign in to comment.