💾 Fix _save_checkpoint for online methods (#2288)
* Update trainer_utils import and save strategy in online_dpo_trainer.py

* fix back-compat for online-dpo

* better comment

* Update transformers dependency to commit f33904
qgallouedec authored Oct 31, 2024
1 parent 06be6f4 commit bb56c6e
Showing 4 changed files with 52 additions and 5 deletions.
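In short: transformers 4.47 removes the `metrics` keyword from `Trainer._save_checkpoint` and moves best-metric bookkeeping into a new `Trainer._determine_best_metric` helper, so TRL's online trainers must stop passing `metrics` when saving. Paraphrasing the call-site change from the diffs below:

    # Before (transformers < 4.47): saving and best-metric tracking were coupled.
    if self.control.should_save:
        self._save_checkpoint(model, trial, metrics=metrics)

    # After (transformers >= 4.47): the best metric is determined at evaluation
    # time, and "best" becomes a save strategy in its own right.
    is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
    if self.args.save_strategy == "best":
        self.control.should_save = is_new_best_metric
    if self.control.should_save:
        self._save_checkpoint(model, trial)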
.github/workflows/tests.yml (2 changes: 1 addition & 1 deletion)

@@ -112,4 +112,4 @@ jobs:
         slack_channel: ${{ env.CI_SLACK_CHANNEL }}
         title: 🤗 Results of the TRL CI with dev dependencies
         status: ${{ job.status }}
-        slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
\ No newline at end of file
+        slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl/trainer/online_dpo_trainer.py (51 changes: 49 additions & 2 deletions)

@@ -20,6 +20,7 @@

 import datasets
 import jinja2
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -40,7 +41,7 @@
     is_apex_available,
     is_wandb_available,
 )
-from transformers.trainer_utils import EvalPrediction, seed_worker
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
 from transformers.training_args import OptimizerNames
 from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

@@ -614,11 +615,57 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         metrics = None
         if self.control.should_evaluate:
             metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == "best":
+                self.control.should_save = is_new_best_metric
 
         if self.control.should_save:
-            self._save_checkpoint(model, trial, metrics=metrics)
+            self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
+    # This can be removed once the minimum transformers version is updated to 4.47.
+    # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+        If args.metric_for_best_model is not set, the loss is used.
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                run_dir = self._get_output_dir(trial=trial)
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
     def create_model_card(
         self,
         model_name: Optional[str] = None,
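For context, here is a minimal sketch of the configuration this new code path serves, assuming a transformers release that accepts the "best" save strategy (4.47+, per the compatibility comment above). The argument values and output path are illustrative, not from this commit:

    from trl import OnlineDPOConfig

    training_args = OnlineDPOConfig(
        output_dir="online-dpo-model",      # illustrative path
        eval_strategy="steps",              # evaluations must run to produce metrics
        eval_steps=500,
        save_strategy="best",               # checkpoint only on a new best metric
        metric_for_best_model="eval_loss",  # the metric _determine_best_metric checks
        greater_is_better=False,            # lower loss is better
    )

With this configuration, `_maybe_log_save_evaluate` sets `control.should_save` only when `_determine_best_metric` reports an improvement, and `state.best_model_checkpoint` tracks the winning checkpoint directory.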
trl/trainer/ppo_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -566,7 +566,7 @@ def repeat_generator():
             self.lr_scheduler.step()
             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
             if self.control.should_save:
-                self._save_checkpoint(model, trial=None, metrics=metrics)
+                self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
             torch.cuda.empty_cache()
trl/trainer/rloo_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -482,7 +482,7 @@ def repeat_generator():
             self.lr_scheduler.step()
             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
             if self.control.should_save:
-                self._save_checkpoint(model, trial=None, metrics=metrics)
+                self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             torch.cuda.empty_cache()
             gc.collect()
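Note that dropping the `metrics` argument is safe on both sides of the 4.47 boundary because earlier `_save_checkpoint` signatures default `metrics=None`. If external code needed to keep forwarding metrics on older releases, one version-tolerant option is to inspect the signature at call time; the wrapper below is a hypothetical sketch, not part of this commit:

    import inspect

    def save_checkpoint_compat(trainer, model, trial=None, metrics=None):
        """Call Trainer._save_checkpoint with or without `metrics`,
        whichever the installed transformers version expects."""
        if "metrics" in inspect.signature(trainer._save_checkpoint).parameters:
            # transformers < 4.47: _save_checkpoint still accepts `metrics`
            trainer._save_checkpoint(model, trial, metrics=metrics)
        else:
            # transformers >= 4.47: best-metric bookkeeping moved to
            # _determine_best_metric, and the `metrics` parameter was dropped
            trainer._save_checkpoint(model, trial)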
