From 4dc1a69349c02bf1c39497e2bcd0c2ac1d80b285 Mon Sep 17 00:00:00 2001
From: kang sheng
Date: Mon, 25 Nov 2024 18:27:13 +0800
Subject: [PATCH] Sum gathered input tokens (#34554)

* sum gathered input tokens

* ruff line-length is 119, format the code

---------

Co-authored-by: kangsheng
---
 src/transformers/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 46add00b018e3a..ed45624983ad20 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2489,7 +2489,9 @@ def _inner_training_loop(
                     else:
                         input_tokens = inputs[main_input_name].numel()
                         input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
-                        self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
+                        self.state.num_input_tokens_seen += (
+                            self.accelerator.gather(input_tokens).sum().cpu().item()
+                        )
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = false
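
A minimal sketch (separate from the patch above, with hypothetical example values) of why the gathered tensor needs .sum(): when more than one process is running, accelerator.gather() on a per-process scalar returns a tensor with one entry per process, and .item() only converts a single-element tensor to a Python number, so the per-process token counts must be summed first to obtain the global count.

    # sketch_gather_sum.py -- hypothetical standalone example, not part of trainer.py
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Pretend each process saw a different number of input tokens this step.
    local_tokens = torch.tensor(
        1024 + accelerator.process_index, device=accelerator.device, dtype=torch.int64
    )

    # gather() collects the scalar from every process -> one entry per process.
    gathered = accelerator.gather(local_tokens)

    # Summing reduces the gathered counts to a single value that .item() can convert.
    global_tokens = gathered.sum().cpu().item()
    if accelerator.is_main_process:
        print(f"input tokens seen this step across all processes: {global_tokens}")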