Skip to content

Commit

Permalink
[Judges] use the pair-judges in online-preference trainers (#2243)
Browse files Browse the repository at this point in the history
* use the pair-judges

* add test

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* decode and skip special characters

* initial nash

* return tensors

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* add back the logging

* use batch_decode

* add judges api to XPO trainer

* Update tests/test_online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* judge in examples

* judge in config

* add back logs when using reward model

* typo

* add back model_scores logging when using reward model

* log scores for reward model only

* better cond on what to log

* same for rlhf reward

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* use decode_and_strip_padding

* error if both reward and judge or none are set

* remove unused check

* Uniform way to pass conversation into judge

* heading -> leading

* LogCompletionsCallback compat with online method

* Update Online DPO doc

* check if data is conversational for judges

* update example

* remove comment

* use zip

* fix stats xpo

* Replace judge with PairRMJudge and import AutoModelForSequenceClassification

* update xpo documentation

* Remove doc duplication

* update nash doc

* XPO trl chat

* nash md doc

* HfPairwiseJudge

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2024
1 parent 1699473 commit 9c376c5
Show file tree
Hide file tree
Showing 15 changed files with 502 additions and 161 deletions.
69 changes: 47 additions & 22 deletions docs/source/nash_md_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.

## Quick start

This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:

<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
Expand All @@ -28,21 +28,17 @@ Below is the script to train the model:
```python
# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = NashMDConfig(output_dir="nash-md-qwen2", logging_steps=10)
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10)
trainer = NashMDTrainer(
model=model,
reward_model=reward_model,
args=training_args,
processing_class=tokenizer,
train_dataset=train_dataset,
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Expand All @@ -53,15 +49,47 @@ Execute the script using the following command:
accelerate launch train_nash_md.py
```

Distributed across 8 GPUs, the training takes approximately 3 hours.

To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [TRL Chat CLI](clis#chat-interface).

<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?

<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-NashMD&gt;:</span></strong>
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
</code></pre>

## Expected dataset type

Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

### ⚠️ Use the same chat template
### Use a reward model

Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:

```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)

trainer = NashMDTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
```

<Tip warning={true}>

Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.

Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>

### Encourage EOS token generation

Expand Down Expand Up @@ -89,21 +117,17 @@ This callback logs the model's generated completions directly to Weights & Biase

We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)

To test the Nash-MD script with the [Pythia 14M model](https://huggingface.co/EleutherAI/pythia-14m) on the TL;DR summarization task, run the following command:
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:

```bash
python examples/scripts/nash_md.py \
--model_name_or_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--dataset_name trl-lib/tldr \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--output_dir pythia-14m-tldr-nash-md \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 32 \
--num_train_epochs 3 \
--max_new_tokens 64 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-NashMD-PairRM \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--push_to_hub
```

Expand All @@ -116,6 +140,7 @@ The logged metrics are as follows:
* `loss/score`: The mean reinforce score loss.
* `rewards/chosen`: The mean scores (according to the reward model) of the model completions.
* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions.
* `rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion.
* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model.
* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
Expand Down
81 changes: 46 additions & 35 deletions docs/source/online_dpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ The abstract from the paper is the following:

> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.

The current implementation uses reward models for scoring completions -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use.

This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).

## Quick start

This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:

<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
Expand All @@ -30,17 +28,17 @@ Below is the script to train the model:
```python
# train_online_dpo.py
from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model, reward_model=reward_model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Expand All @@ -53,34 +51,51 @@ accelerate launch train_online_dpo.py

Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online-dpo-qwen2-reward.png)
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online-dpo-qwen2.png)

To see how the trained model performs, use the following code to generate completions:
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).

```python
>>> from transformers import pipeline
>>> generator = pipeline("text-generation", model="online-dpo-qwen2/checkpoint-1773", device="cuda")
>>> question = "Why is the problem always DNS?"
>>> output = generator([{"role": "user", "content": question}], max_new_tokens=200, return_full_text=False)[0]
>>> print(output["generated_text"])
The reason why the problem of DNS (Domain Name System) can always be encountered is that it is designed to provide reliable and accurate information about the availability, ownership, or expiration of domain names. However, there may be some circumstances where the system fails to resolve an IP address correctly, leading to the problem of DNS.
For example, if the server hosting the domain name does not have the correct IP address associated with it, or if the IP address is incorrectly formatted, then the DNS system will fail to resolve the domain name correctly. Additionally, if the server hosting the domain name has been compromised, then the DNS system may also fail to resolve the domain name correctly.
It's worth noting that the exact cause of DNS failure can vary depending on the specific situation, so it's important to carefully check all relevant factors before attempting to resolve the issue. If you suspect that your DNS problem may be caused by a bug in the system, you should report it to the DNS provider directly for further investigation.
```
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?

<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-OnlineDPO&gt;:</span></strong>
The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
</code></pre>

## Expected dataset type

Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

### ⚠️ Use the same chat template
### Use a reward model

Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:

```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)

trainer = OnlineDPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
```

<Tip warning={true}>

Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.

Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>

### Encourage EOS token generation

We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:

```python
training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
Expand All @@ -105,33 +120,29 @@ This callback logs the model's generated completions directly to Weights & Biase

We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)

To test the online DPO script with the [Pythia 1B model](https://huggingface.co/trl-lib/pythia-1b-deduped-tldr-sft) on the TL;DR summarization task, run the following command:
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:

```bash
python examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-tldr-online-dpo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 32 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-Online-DPO-PairRM \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--push_to_hub
```

## Logged metrics

The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)

* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
* `objective/scores`: The mean scores returned by the reward mode.
* `objective/scores`: The mean scores returned by the reward model.
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions.
* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
Expand Down
Loading

0 comments on commit 9c376c5

Please sign in to comment.