diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5b8ffeafc7c8ea..45b45992bf425a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3048,7 +3048,7 @@ def log(self, logs: Dict[str, float]) -> None:
                 The values to log.
         """
         if self.state.epoch is not None:
-            logs["epoch"] = round(self.state.epoch, 2)
+            logs["epoch"] = self.state.epoch
         if self.args.include_num_input_tokens_seen:
             logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
 
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index f5bbcdbd4218d5..225f645d631e41 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -15,6 +15,7 @@
 """
 Callbacks to use with the Trainer class and customize the training loop.
 """
+import copy
 import dataclasses
 import json
 from dataclasses import dataclass
@@ -520,7 +521,12 @@ def on_predict(self, args, state, control, **kwargs):
 
     def on_log(self, args, state, control, logs=None, **kwargs):
         if state.is_world_process_zero and self.training_bar is not None:
+            # avoid modifying the logs object as it is shared between callbacks
+            logs = copy.deepcopy(logs)
             _ = logs.pop("total_flos", None)
+            # round numbers so that it looks better in console
+            if "epoch" in logs:
+                logs["epoch"] = round(logs["epoch"], 2)
             self.training_bar.write(str(logs))
 
     def on_train_end(self, args, state, control, **kwargs):
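For context, a minimal self-contained sketch of why the `copy.deepcopy(logs)` matters (the classes and names below are illustrative stand-ins, not the transformers API): `Trainer.log` hands the same `logs` dict to every callback, so a callback that pops `total_flos` or rounds `epoch` in place changes what every later callback, and anything else holding a reference to that dict such as `state.log_history`, observes. With this patch, the full-precision epoch is preserved in the shared dict, and rounding happens only at display time inside `ProgressCallback.on_log`:

```python
import copy

# Illustrative stand-ins for logging callbacks; these are NOT the
# transformers classes, just a demonstration of shared-dict mutation.
class InPlaceCallback:
    """Mutates the shared logs dict, like the pre-patch behavior."""

    def on_log(self, logs):
        logs.pop("total_flos", None)  # later callbacks lose this key
        if "epoch" in logs:
            logs["epoch"] = round(logs["epoch"], 2)  # precision lost for everyone


class CopyingCallback:
    """Works on a private copy, like ProgressCallback after this patch."""

    def on_log(self, logs):
        logs = copy.deepcopy(logs)  # shield the shared object
        logs.pop("total_flos", None)
        if "epoch" in logs:
            logs["epoch"] = round(logs["epoch"], 2)
        print(logs)  # rounded view for the console only


logs = {"loss": 0.4213, "epoch": 1.6666667, "total_flos": 1.2e15}

CopyingCallback().on_log(logs)
assert logs == {"loss": 0.4213, "epoch": 1.6666667, "total_flos": 1.2e15}  # untouched

InPlaceCallback().on_log(logs)
assert logs == {"loss": 0.4213, "epoch": 1.67}  # corrupted for later consumers
```

Since the pop and the rounding only touch top-level keys, a shallow `dict(logs)` would also protect the shared object here; `deepcopy` is simply the more conservative choice in case a callback ever mutates nested values.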