diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e2ae622e2b6bf3..30caa2de260cb7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2404,7 +2404,6 @@ def _inner_training_loop(
         if args.eval_on_start:
             self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
 
-        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_dataloader = train_dataloader
             if hasattr(epoch_dataloader, "set_epoch"):
@@ -2447,13 +2446,7 @@ def _inner_training_loop(
                 batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                 for inputs in batch_samples:
                     step += 1
-                    total_batched_samples += 1
-                    is_last_step_and_steps_less_than_grad_acc = (
-                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
-                    )
-                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
-                    )
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                     # Since we perform prefetching, we need to manually set sync_gradients
                     if not do_sync_step:
                         self.accelerator.gradient_state._set_sync_gradients(False)
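
Not part of the patch itself: a minimal standalone sketch of the simplified condition, using made-up values for steps_in_epoch and gradient_accumulation_steps to illustrate that a sync still fires both on each accumulation boundary and on the last (possibly partial) step of the epoch.

# Illustrative values only; in the Trainer these come from the dataloader and args.
steps_in_epoch = 10
gradient_accumulation_steps = 4

for step in range(steps_in_epoch):
    # Sync on every accumulation boundary, and always on the final step of the
    # epoch so a partial trailing accumulation window still triggers an update.
    do_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
    print(f"step={step}: do_sync_step={do_sync_step}")

With these values the condition is True at steps 3, 7, and 9, matching the behavior the removed total_batched_samples bookkeeping was meant to provide without carrying a counter across epochs.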