From 9c376c571f2660342b965a5e417fe0010ba3ff4f Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Thu, 24 Oct 2024 16:47:10 +0200
Subject: [PATCH] [Judges] use the pair-judges in online-preference trainers
 (#2243)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* use the pair-judges

* add test

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* decode and skip special characters

* initial nash

* return tensors

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* add back the logging

* use batch_decode

* add judges api to XPO trainer

* Update tests/test_online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* judge in examples

* judge in config

* add back logs when using reward model

* typo

* add back model_scores logging when using reward model

* log scores for reward model only

* better cond on what to log

* same for rlhf reward

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* use decode_and_strip_padding

* error if both reward and judge or none are set

* remove unused check

* Uniform way to pass conversation into judge

* heading -> leading

* LogCompletionsCallback compat with online method

* Update Online DPO doc

* check if data is conversational for judges

* update example

* remove comment

* use zip

* fix stats xpo

* Replace judge with PairRMJudge and import AutoModelForSequenceClassification

* update xpo documentation

* Remove doc duplication

* update nash doc

* XPO trl chat

* nash md doc

* HfPairwiseJudge

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec
---
 docs/source/nash_md_trainer.md    | 69 +++++++++++++++--------
 docs/source/online_dpo_trainer.md | 81 +++++++++++++++------------
 docs/source/xpo_trainer.mdx       | 71 ++++++++++++++++--------
 examples/scripts/dpo_online.py    | 27 +++++++--
 examples/scripts/nash_md.py       | 26 ++++++++-
 examples/scripts/xpo.py           | 27 ++++++++-
 tests/test_nash_md_trainer.py     | 34 +++++++++++-
 tests/test_online_dpo_trainer.py  | 28 +++++++++-
 tests/test_xpo_trainer.py         | 34 +++++++++++-
 trl/trainer/callbacks.py          |  2 +
 trl/trainer/judges.py             |  4 +-
 trl/trainer/nash_md_trainer.py    | 92 ++++++++++++++++++++++++-------
 trl/trainer/online_dpo_config.py  |  5 +-
 trl/trainer/online_dpo_trainer.py | 78 ++++++++++++++++++--------
 trl/trainer/xpo_trainer.py        | 85 ++++++++++++++++++++++------
 15 files changed, 502 insertions(+), 161 deletions(-)

diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md
index 38e955639c..881e57e69c 100644
--- a/docs/source/nash_md_trainer.md
+++ b/docs/source/nash_md_trainer.md
@@ -14,7 +14,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
 
 ## Quick start
 
-This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
+This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
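For readers of this patch, here is roughly what the rewritten quick start amounts to in code. This is an illustrative sketch, not part of the diff: it assumes the judge-based `NashMDTrainer` API this PR introduces (`judge=` replacing `reward_model=`) and the `trl-lib/ultrafeedback-prompt` prompt-only dataset used elsewhere in the TRL docs; the tokenizer keyword (`tokenizer=` vs. `processing_class=`) differs across TRL versions.

```python
# Hypothetical quick-start script (train_nash_md.py), not part of the diff.
# PairRMJudge wraps llm-blender's PairRM, so it needs `pip install llm-blender`.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import NashMDConfig, NashMDTrainer, PairRMJudge

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# The pairwise judge stands in for the reward model: it ranks two completions per prompt.
judge = PairRMJudge()

# Prompt-only UltraFeedback split; the exact dataset id is an assumption.
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10)
trainer = NashMDTrainer(
    model=model,
    judge=judge,  # this PR: pass a judge instead of reward_model=...
    args=training_args,
    tokenizer=tokenizer,  # later TRL versions rename this to processing_class=
    train_dataset=train_dataset,
)
trainer.train()
```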