diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a708d8deb4efcc..85052326a71147 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3697,10 +3697,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss, **kwargs)
             # Finally we need to normalize the loss for reporting
             if num_items_in_batch is None:
-                return loss.detach() / self.args.gradient_accumulation_steps
+                loss /= self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+            return loss.detach()
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
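
For context, a minimal standalone sketch of why the division by `gradient_accumulation_steps` belongs before the backward call. This is illustrative only, not part of the patch: the toy model, data, and `accumulation_steps` name are invented for the demo. With equal-sized micro-batches and a mean-reduced loss, normalizing each micro-batch loss before `backward()` makes the accumulated gradient match a single full-batch step.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()
accumulation_steps = 4  # stand-in for args.gradient_accumulation_steps

# Reference: one backward over the full batch.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulation: normalize each micro-batch loss *before* backward.
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = loss_fn(model(xb), yb) / accumulation_steps  # normalize first
    loss.backward()                                      # grads accumulate in .grad

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True
```

In the patch itself, the same (already-normalized) loss is then detached and returned, so the value logged for reporting matches what was passed to `self.accelerator.backward` when `num_items_in_batch` is `None`.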