From 3579d4dae665d8def6e422e4e72c12cc2385e845 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 11 Dec 2024 10:39:01 +0000
Subject: [PATCH] Scale loss before backward

---
 src/transformers/trainer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index be41a415e5a710..1ede681e791e4c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3682,10 +3682,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss, **kwargs)
             # Finally we need to normalize the loss for reporting
             if num_items_in_batch is None:
-                return loss.detach() / self.args.gradient_accumulation_steps
+                loss /= self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+            return loss.detach()
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
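
Note (not part of the patch): the hunk above moves the division by gradient_accumulation_steps so it is applied to the loss tensor before self.accelerator.backward(loss, **kwargs) is called, instead of only to the detached value returned for logging. The intent appears to be that the gradients produced by backward are scaled consistently with the reported loss. Below is a minimal, self-contained sketch of the scaling behaviour under gradient accumulation; the toy weight, micro_batches, and grad_accum_steps names are made up for illustration and are not Trainer code.

    # Toy illustration: dividing the loss before backward() makes the gradient
    # accumulated across micro-batches approximate the mean of per-micro-batch
    # gradients rather than their sum. All names here are hypothetical.
    import torch

    grad_accum_steps = 4
    weight = torch.tensor([1.0], requires_grad=True)
    micro_batches = [torch.tensor([float(i)]) for i in range(grad_accum_steps)]

    # New behaviour: scale first, then backpropagate the scaled loss
    # (written out-of-place here; equivalent to the in-place /= in the patch).
    for x in micro_batches:
        loss = (weight * x).sum()
        loss = loss / grad_accum_steps
        loss.backward()
    print(weight.grad)  # tensor([1.5]) -> mean of per-micro-batch gradients

    weight.grad = None

    # Old behaviour: backpropagate the raw loss, divide only the reported value.
    for x in micro_batches:
        loss = (weight * x).sum()
        loss.backward()
        reported = loss.detach() / grad_accum_steps  # affects logging only
    print(weight.grad)  # tensor([6.]) -> sum of per-micro-batch gradients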