From dca93ca076c68372dcf3ad1239a2119afdda629c Mon Sep 17 00:00:00 2001
From: kibitzing
Date: Thu, 31 Oct 2024 22:53:23 +0900
Subject: [PATCH] Fix step shifting when accumulate gradient (#33673)

* replace total_batched_samples with step while counting grad accum step

* remove unused variable

* simplify condition for update step

* fix format by ruff

* simplify update step condition using accelerator.sync_gradients

* simplify update condition using do_sync_step

* remove print for test

---------

Co-authored-by: Zach Mueller
---
 src/transformers/trainer.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e2ae622e2b6bf3..30caa2de260cb7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2404,7 +2404,6 @@ def _inner_training_loop(
         if args.eval_on_start:
             self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

-        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_dataloader = train_dataloader
             if hasattr(epoch_dataloader, "set_epoch"):
@@ -2447,13 +2446,7 @@ def _inner_training_loop(
                 batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                 for inputs in batch_samples:
                     step += 1
-                    total_batched_samples += 1
-                    is_last_step_and_steps_less_than_grad_acc = (
-                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
-                    )
-                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
-                    )
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                     # Since we perform prefetching, we need to manually set sync_gradients
                     if not do_sync_step:
                         self.accelerator.gradient_state._set_sync_gradients(False)
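
For context, the condition introduced by this patch can be exercised in isolation. The sketch below is a standalone illustration, not code from src/transformers/trainer.py: sync_steps is a made-up helper, and steps_in_epoch=10 with gradient_accumulation_steps=4 are arbitrary values chosen only to show which micro-batches would trigger an optimizer update under the new per-epoch rule.

# Standalone sketch of the new do_sync_step condition (assumed names and values).
def sync_steps(steps_in_epoch: int, gradient_accumulation_steps: int) -> list[int]:
    """Return the zero-based in-epoch steps on which gradients would be synced."""
    synced = []
    for step in range(steps_in_epoch):
        # Condition added by the patch: sync every N-th micro-batch, and always on
        # the last micro-batch of the epoch so leftover gradients are applied.
        do_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
        if do_sync_step:
            synced.append(step)
    return synced


if __name__ == "__main__":
    # With 10 micro-batches and accumulation of 4, updates fall on steps 3, 7 and 9.
    # The removed total_batched_samples counter persisted across epochs, so when an
    # epoch length was not a multiple of gradient_accumulation_steps the update
    # boundaries could shift in later epochs; deriving the condition from the
    # in-epoch step keeps them fixed.
    print(sync_steps(steps_in_epoch=10, gradient_accumulation_steps=4))  # [3, 7, 9]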