Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversational dataset support for Online DPO #2075

Merged
merged 24 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
- local: kto_trainer
title: KTO
- local: nash_md_trainer
title: Nash MD
title: Nash-MD
- local: orpo_trainer
title: ORPO
- local: ppo_trainer
Expand Down
4 changes: 4 additions & 0 deletions docs/source/callbacks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@
## WinRateCallback

[[autodoc]] WinRateCallback

## LogCompletionsCallback

[[autodoc]] LogCompletionsCallback
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ Choosing the right dataset format depends on the task you are working on and the

<Tip>

TRL trainers only support standard dataset formats. If you have a conversational dataset, you must first convert it into a standard format.
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.

</Tip>
Expand Down
98 changes: 88 additions & 10 deletions docs/source/nash_md_trainer.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,93 @@
# Nash MD Trainer
# Nash-MD Trainer

## Overview

## Overview
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.

The abstract from the paper is the following:

> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.


This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).

## Get started
## Quick start

This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:

<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>

Below is the script to train the model:

```python
# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

args = NashMDConfig(output_dir="nash-md-qwen2", logging_steps=10)
trainer = NashMDTrainer(
model=model,
reward_model=reward_model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_nash_md.py
```

## Expected dataset format

Nash-MD requires a [prompt-only dataset](dataset_format#preference). The [`NashMDTrainer`] supports both [conversational](dataset_format#conversational-dataset-format) and [standard](dataset_format#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

To just run the Nash MD script to make sure this trainer can run, you can run the following command to train a Nash MD model with a dummy reward model.
### ⚠️ Use the same chat template

Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.

### Encourage EOS token generation

We can want the model to generate completion within a given length. During the learning, the model will generate completion up to the maximum completion length specified in the `max_new_tokens` argument of [`NashMDConfig`]. I you want to penalize for not generating an EOS token before the maximum completion length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:

```python
args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
```

### Logging Completions

To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`].

```python
trainer = NashMDTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
```

This callback logs the model's generated completions directly to Weights & Biases.

![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)

## Example script

We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)

To test the Nash-MD script with the [Pythia 14M model](https://huggingface.co/EleutherAI/pythia-14m) on the TL;DR summarization task, run the following command:

```bash
python examples/scripts/nash_md.py \
Expand All @@ -26,24 +101,27 @@ python examples/scripts/nash_md.py \
--num_train_epochs 3 \
--max_new_tokens 64 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0
--missing_eos_penalty 1.0 \
--push_to_hub
```

## Explanation of the logged metrics
## Logged metrics

The logged metrics are as follows:

* `loss/score`: The mean reinforce score loss.
* `loss/kl_div`: The mean kl divergence loss.
* `loss/kl`: The mean KL divergence between the model and reference data.
* `objective/entropy`: The mean entropy of the model and reference data.
* `rewards/accuracies`: The accuracies of the Nash MD's implicit reward model.
* `loss/score`: The mean reinforce score loss.
* `rewards/chosen`: The mean scores (according to the reward model) of the model completions.
* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions.
* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model.
* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the reference completions.
* `val/model_contain_eos_token`: The amount of times the model's output contains the eos token.
* `val/ref_contain_eos_token`: The amount of times the mixture's output contains the eos token.
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].

## NashMDTrainer

Expand Down
Loading
Loading