From bb56c6e6afcec01207dd721c7939eb4f0370507c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
 <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 31 Oct 2024 12:35:25 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=92=BE=20Fix=20`=5Fsave=5Fcheckpoint`=20f?=
 =?UTF-8?q?or=20online=20methods=20(#2288)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update trainer_utils import and save strategy in online_dpo_trainer.py

* fix back-compat for online-dpo

* better comment

* Update transformers dependency to commit f33904
---
 .github/workflows/tests.yml       |  2 +-
 trl/trainer/online_dpo_trainer.py | 51 +++++++++++++++++++++++++++++--
 trl/trainer/ppo_trainer.py        |  2 +-
 trl/trainer/rloo_trainer.py       |  2 +-
 4 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 6ba5bcab42..ffa634ad15 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -112,4 +112,4 @@ jobs:
           slack_channel: ${{ env.CI_SLACK_CHANNEL }}
           title: 🤗 Results of the TRL CI with dev dependencies
           status: ${{ job.status }}
-          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
\ No newline at end of file
+          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 2575c2f886..790e546387 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -20,6 +20,7 @@
 
 import datasets
 import jinja2
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -40,7 +41,7 @@
     is_apex_available,
     is_wandb_available,
 )
-from transformers.trainer_utils import EvalPrediction, seed_worker
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
 from transformers.training_args import OptimizerNames
 from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
 
@@ -614,11 +615,57 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
         metrics = None
         if self.control.should_evaluate:
             metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == "best":
+                self.control.should_save = is_new_best_metric
 
         if self.control.should_save:
-            self._save_checkpoint(model, trial, metrics=metrics)
+            self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
+    # This can be removed once the minimum transformers version is updated to 4.47.
+    # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+        If args.metric_for_best_model is not set, the loss is used.
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                run_dir = self._get_output_dir(trial=trial)
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
     def create_model_card(
         self,
         model_name: Optional[str] = None,
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index e491b0622a..b36be8ffff 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -566,7 +566,7 @@ def repeat_generator():
             self.lr_scheduler.step()
             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
             if self.control.should_save:
-                self._save_checkpoint(model, trial=None, metrics=metrics)
+                self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
             torch.cuda.empty_cache()
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 18066976ca..941a90e0a7 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -482,7 +482,7 @@ def repeat_generator():
             self.lr_scheduler.step()
             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
             if self.control.should_save:
-                self._save_checkpoint(model, trial=None, metrics=metrics)
+                self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             torch.cuda.empty_cache()
             gc.collect()
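
Usage note (not part of the patch): a minimal sketch of how the new "best" save strategy is expected
to be driven from user code, assuming a transformers release recent enough for
`TrainingArguments.save_strategy` to accept "best" (4.47+). `OnlineDPOConfig` inherits these fields
from `transformers.TrainingArguments`; the output directory and step values below are illustrative only.

    from trl import OnlineDPOConfig

    training_args = OnlineDPOConfig(
        output_dir="online-dpo-best-ckpt",  # hypothetical output directory
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="best",               # save only when _determine_best_metric reports a new best
        metric_for_best_model="loss",       # looked up as "eval_loss" in the evaluation metrics
        greater_is_better=False,            # lower is better, so np.less is used as the comparison operator
    )

With this configuration, `_maybe_log_save_evaluate` calls `_determine_best_metric` after each
evaluation and only sets `control.should_save` when the tracked metric improves, so
`_save_checkpoint(model, trial)` runs exclusively for new best checkpoints.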