From a0066f47f82f7af0145e3b5ebc06cf2a45b97352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:49:49 +0100 Subject: [PATCH] =?UTF-8?q?=E2=8F=B0=20Add=20`start=5Ftime`=20to=20`=5Fmay?= =?UTF-8?q?be=5Flog=5Fsave=5Fevaluate`=20(#2373)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/online_dpo_trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index eb6a07db7e..23ca6dd047 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -25,6 +25,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.data +import transformers from datasets import Dataset from packaging import version from torch.utils.data import DataLoader, IterableDataset @@ -587,8 +588,9 @@ def training_step( return loss.detach() / self.args.gradient_accumulation_steps - # Same as Trainer.evaluate but log our metrics - def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): + # Same as Trainer._maybe_log_save_evaluate but log our metrics + # start_time defaults to None to allow compatibility with transformers<=4.46 + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None): if self.control.should_log and self.state.global_step > self._globalstep_last_logged: logs: Dict[str, float] = {} @@ -612,7 +614,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno self._globalstep_last_logged = self.state.global_step self.store_flos() - self.log(logs) + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + self.log(logs, start_time) + else: # transformers<=4.46 + self.log(logs) metrics = None if self.control.should_evaluate: