diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c878d2b345cc31..5957f8025d2a0b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3698,10 +3698,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss, **kwargs)
             # Finally we need to normalize the loss for reporting
             if num_items_in_batch is None:
-                return loss.detach() / self.args.gradient_accumulation_steps
+                loss /= self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+            return loss.detach()
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
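
A minimal sketch in plain PyTorch (not the Trainer's actual code path, and ignoring any extra scaling accelerate may apply) of what the reordering changes: dividing the loss by gradient_accumulation_steps before backward() scales the accumulated gradients themselves, whereas the old order passed the undivided loss to backward and divided only the detached value returned for reporting.

import torch

gradient_accumulation_steps = 4
weight = torch.ones(1, requires_grad=True)
micro_batches = [torch.tensor([2.0]), torch.tensor([4.0]),
                 torch.tensor([6.0]), torch.tensor([8.0])]

# New order: normalize first, then backpropagate, then detach for reporting.
for x in micro_batches:
    loss = (weight * x).mean()                 # stand-in for a per-micro-batch mean loss
    loss = loss / gradient_accumulation_steps  # normalization now happens before backward()
    loss.backward()                            # gradients accumulate in weight.grad
    reported = loss.detach()                   # normalized value used for reporting

# The accumulated gradient equals the mean gradient over the window:
print(weight.grad)  # tensor([5.]) == mean of [2., 4., 6., 8.]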