From 3375acdcd8bc3688f114fc4dc1888ca023adefe6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Mon, 28 Oct 2024 16:09:54 +0000
Subject: [PATCH 1/4] Update trainer_utils import and save strategy in online_dpo_trainer.py

---
 trl/trainer/online_dpo_trainer.py | 8 ++++++--
 trl/trainer/ppo_trainer.py        | 2 +-
 trl/trainer/rloo_trainer.py       | 2 +-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 2575c2f886..1a76ba8c76 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -40,7 +40,7 @@
     is_apex_available,
     is_wandb_available,
 )
-from transformers.trainer_utils import EvalPrediction, seed_worker
+from transformers.trainer_utils import EvalPrediction, SaveStrategy, seed_worker
 from transformers.training_args import OptimizerNames
 from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
 
@@ -614,9 +614,13 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
         metrics = None
         if self.control.should_evaluate:
             metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == SaveStrategy.BEST:
+                self.control.should_save = is_new_best_metric
 
         if self.control.should_save:
-            self._save_checkpoint(model, trial, metrics=metrics)
+            self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
     def create_model_card(
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index e491b0622a..b36be8ffff 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -566,7 +566,7 @@ def repeat_generator():
             self.lr_scheduler.step()
             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
             if self.control.should_save:
-                self._save_checkpoint(model, trial=None, metrics=metrics)
+                self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
             torch.cuda.empty_cache()
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 18066976ca..941a90e0a7 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -482,7 +482,7 @@ def repeat_generator():
             self.lr_scheduler.step()
             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
             if self.control.should_save:
-                self._save_checkpoint(model, trial=None, metrics=metrics)
+                self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             torch.cuda.empty_cache()
             gc.collect()

From bc6fe3bdb3105d7991df24c7e0f6dfa820ebe4d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 30 Oct 2024 15:08:52 +0000
Subject: [PATCH 2/4] fix back-compat for online-dpo

---
 trl/trainer/online_dpo_trainer.py | 47 +++++++++++++++++++++++++++++--
 1 file changed, 45 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 1a76ba8c76..126a6bc9fc 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -20,6 +20,7 @@
 
 import datasets
 import jinja2
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -40,7 +41,7 @@
     is_apex_available,
     is_wandb_available,
 )
-from transformers.trainer_utils import EvalPrediction, SaveStrategy, seed_worker
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
 from transformers.training_args import OptimizerNames
 from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
 
@@ -616,13 +617,55 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             metrics = self._evaluate(trial, ignore_keys_for_eval)
             is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
 
-            if self.args.save_strategy == SaveStrategy.BEST:
+            if self.args.save_strategy == "best":
                 self.control.should_save = is_new_best_metric
 
         if self.control.should_save:
             self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    # This is copy-pasted from transformers.Trainer to ensure retro-compatibility with the previous versions.
+    # It can be removed once we bump the minimum version of transformers to 4.47.
+    # See https://github.com/huggingface/trl/pull/2288
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+        If args.metric_for_best_model is not set, the loss is used.
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                run_dir = self._get_output_dir(trial=trial)
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
     def create_model_card(
         self,
         model_name: Optional[str] = None,

From 4b3f081ce4ab9106e3c7dc545a6c55a78a4cbb9b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 30 Oct 2024 15:10:20 +0000
Subject: [PATCH 3/4] better comment

---
 trl/trainer/online_dpo_trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 126a6bc9fc..790e546387 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -624,9 +624,9 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
-    # This is copy-pasted from transformers.Trainer to ensure retro-compatibility with the previous versions.
-    # It can be removed once we bump the minimum version of transformers to 4.47.
-    # See https://github.com/huggingface/trl/pull/2288
+    # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
+    # This can be removed once the minimum transformers version is updated to 4.47.
+    # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
     def _determine_best_metric(self, metrics, trial):
         """
         Determine if the model should be saved based on the evaluation metrics.

From 2ce8a0c8fda3db7bc747d835064e3eb10f3407a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 30 Oct 2024 15:45:21 +0000
Subject: [PATCH 4/4] Update transformers dependency to commit f33904

---
 .github/workflows/tests-main.yml | 2 +-
 .github/workflows/tests.yml      | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/tests-main.yml b/.github/workflows/tests-main.yml
index d47251fd25..f23ce167fc 100644
--- a/.github/workflows/tests-main.yml
+++ b/.github/workflows/tests-main.yml
@@ -30,7 +30,7 @@ jobs:
           python -m pip install --upgrade pip
           # install PEFT & transformers from source
           pip install -U git+https://github.com/huggingface/peft.git
-          pip install -U git+https://github.com/huggingface/transformers.git
+          pip install -U git+https://github.com/huggingface/transformers.git@f339042b0b8bdc0b57a70d37f67cafbea960a2ab
           # cpu version of pytorch
           pip install ".[test, diffusers]"
       - name: Test with pytest
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 46140b0450..ac9428b3aa 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -58,7 +58,7 @@ jobs:
           python -m pip install --upgrade pip
           # install PEFT & transformers from source
           pip install -U git+https://github.com/huggingface/peft.git
-          pip install -U git+https://github.com/huggingface/transformers.git
+          pip install -U git+https://github.com/huggingface/transformers.git@f339042b0b8bdc0b57a70d37f67cafbea960a2ab
           # cpu version of pytorch
           pip install ".[test, diffusers]"
       - name: Test with pytest
@@ -82,7 +82,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           # install transformers from source
-          pip install -U git+https://github.com/huggingface/transformers.git
+          pip install -U git+https://github.com/huggingface/transformers.git@f339042b0b8bdc0b57a70d37f67cafbea960a2ab
           # cpu version of pytorch
           pip install .[test]
       - name: Test with pytest