Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix step shifting when accumulate gradient #33673

Merged
9 changes: 1 addition & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,7 +2404,6 @@ def _inner_training_loop(
if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
epoch_dataloader = train_dataloader
if hasattr(epoch_dataloader, "set_epoch"):
Expand Down Expand Up @@ -2447,13 +2446,7 @@ def _inner_training_loop(
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
step += 1
total_batched_samples += 1
is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)
do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
total_batched_samples % args.gradient_accumulation_steps == 0
)
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
# Since we perform prefetching, we need to manually set sync_gradients
if not do_sync_step:
self.accelerator.gradient_state._set_sync_gradients(False)
Expand Down
Loading