Fix step shifting when accumulate gradient (#33673)
* replace total_batched_samples with step when counting gradient accumulation steps

* remove unused variable

* simplify condition for update step

* fix format by ruff

* simplify update step condition using accelerator.sync_gradients

* simplify update condition using do_sync_step

* remove print for test

---------

Co-authored-by: Zach Mueller <[email protected]>
kibitzing and muellerzr authored Oct 31, 2024
1 parent 1b86772 commit dca93ca
Showing 1 changed file with 1 addition and 8 deletions.
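To make the fix concrete, here is a minimal standalone sketch (not Trainer code; the epoch length, accumulation factor, and epoch count below are illustrative assumptions). It compares which epoch-local steps trigger a gradient sync under the old running-counter condition and under the new per-epoch condition. Because total_batched_samples never resets, the old rule's sync points drift from epoch to epoch whenever steps_in_epoch is not a multiple of gradient_accumulation_steps; the new rule syncs at the same steps in every epoch and always on the epoch's last step.

# Standalone sketch (not Trainer code): compare which epoch-local steps trigger a
# gradient sync under the old cross-epoch counter vs. the new per-epoch condition.
# Illustrative numbers only: 3 epochs of 10 micro-batches, accumulation of 4.
steps_in_epoch = 10
gradient_accumulation_steps = 4
num_epochs = 3

total_batched_samples = 0  # old counter: never resets, so sync points drift across epochs
for epoch in range(num_epochs):
    old_sync, new_sync = [], []
    for step in range(steps_in_epoch):  # step is epoch-local, as in the Trainer loop
        total_batched_samples += 1

        # Old rule: sync on the running total; the epoch's last step only syncs
        # when the whole epoch fits inside a single accumulation window.
        is_last_step_and_steps_less_than_grad_acc = (
            steps_in_epoch <= gradient_accumulation_steps and (step + 1) == steps_in_epoch
        )
        if is_last_step_and_steps_less_than_grad_acc or (
            total_batched_samples % gradient_accumulation_steps == 0
        ):
            old_sync.append(step + 1)

        # New rule: sync every gradient_accumulation_steps epoch-local steps,
        # plus always on the epoch's last step.
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch:
            new_sync.append(step + 1)

    print(f"epoch {epoch}: old sync at {old_sync}, new sync at {new_sync}")

# Expected output (old sync points shift; new ones stay aligned every epoch):
# epoch 0: old sync at [4, 8], new sync at [4, 8, 10]
# epoch 1: old sync at [2, 6, 10], new sync at [4, 8, 10]
# epoch 2: old sync at [4, 8], new sync at [4, 8, 10]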
src/transformers/trainer.py (1 addition, 8 deletions)

@@ -2404,7 +2404,6 @@ def _inner_training_loop(
         if args.eval_on_start:
             self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
 
-        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_dataloader = train_dataloader
             if hasattr(epoch_dataloader, "set_epoch"):
@@ -2447,13 +2446,7 @@ def _inner_training_loop(
                 batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                 for inputs in batch_samples:
                     step += 1
-                    total_batched_samples += 1
-                    is_last_step_and_steps_less_than_grad_acc = (
-                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
-                    )
-                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
-                    )
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                     # Since we perform prefetching, we need to manually set sync_gradients
                     if not do_sync_step:
                         self.accelerator.gradient_state._set_sync_gradients(False)
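As a usage note, the sketch below shows how a do_sync_step flag like the one introduced above is typically consumed in a plain PyTorch training loop; this is a hedged illustration, not the actual Trainer/Accelerate internals. Gradients accumulate on every micro-batch, and the optimizer only steps (and gradients are only cleared) when the flag is set, which with the new condition also happens on the final, possibly short, accumulation window of each epoch.

# Generic sketch (assumed setup, not the Trainer): consume do_sync_step in a
# plain PyTorch loop with gradient accumulation.
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4
steps_in_epoch = 10

for step in range(steps_in_epoch):
    inputs = torch.randn(2, 8)
    targets = torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    # Scale the loss so the accumulated gradient matches one larger batch.
    (loss / gradient_accumulation_steps).backward()

    # Same condition as the new Trainer code: sync on every accumulation boundary
    # and on the last step of the epoch.
    do_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
    if do_sync_step:
        optimizer.step()
        optimizer.zero_grad()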
