📒 Fix type/format confusions (#2223)
qgallouedec authored Oct 11, 2024
1 parent b81a612 commit 5e24101
Showing 13 changed files with 59 additions and 54 deletions.
4 changes: 2 additions & 2 deletions docs/source/bco_trainer.mdx
@@ -6,10 +6,10 @@ TRL supports the Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
For a full example, have a look at [`examples/scripts/bco.py`].

## Expected dataset format
## Expected dataset type

The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
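
For reference, a record of this type in the standard format looks like the unpaired preference example shown in the [dataset formats guide](dataset_formats#unpaired-preference):

```python
# Standard-format unpaired preference record: one completion plus a binary label
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```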

## Expected model format
The BCO trainer expects a model of type `AutoModelForCausalLM`, whereas PPO expects `AutoModelForCausalLMWithValueHead` for the value function.
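
As a rough sketch, the policy model is loaded like any causal language model (the checkpoint name below is illustrative):

```python
from transformers import AutoModelForCausalLM

# Checkpoint name is illustrative; any causal LM checkpoint can be used
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
```
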
4 changes: 2 additions & 2 deletions docs/source/cpo_trainer.mdx
@@ -42,9 +42,9 @@ Execute the script using the following command:
accelerate launch train_cpo.py
```

## Expected dataset format
## Expected dataset type

CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
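
As a sketch, a conversational preference record might look like the following (the message content is illustrative); the standard-format equivalent uses plain strings for `"prompt"`, `"chosen"`, and `"rejected"`:

```python
# Conversational preference record; the trainer applies the chat template automatically
preference_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}
```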

## Example script

47 changes: 26 additions & 21 deletions docs/source/dataset_formats.mdx
@@ -1,10 +1,11 @@
# Dataset formats
# Dataset formats and types

This guide provides an overview of the dataset formats supported by each trainer in TRL. Since conversational datasets are very common, we also provide a guide on how to use them, and how to convert them into a standard dataset format for TRL trainers.
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.

## Overview of the dataset formats and types

The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*. The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.

<table>
<tr>
@@ -78,8 +79,9 @@ The *format* of a dataset refers to how the data is structured, typically catego
</tr>
</table>

### Formats

### Standard dataset format
#### Standard

The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:

@@ -90,7 +92,7 @@ example = {"text": "The sky is blue."}
example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```

### Conversational dataset format
#### Conversational

Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).

@@ -119,15 +121,17 @@ example = {

Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.

### Language modeling
### Types

#### Language modeling

A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.

```python
language_modeling_example = {"text": "The sky is blue."}
```

### Prompt-only
#### Prompt-only

In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.

@@ -137,20 +141,20 @@ prompt_only_example = {"prompt": "The sky is"}

<Tip>

While both the prompt-only and language modeling formats are similar, they differ in how the input is handled. In the prompt-only format, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling format, the input is treated as a complete sentence or sequence. These two formats are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each format:
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:

```python
from transformers import AutoTokenizer
from trl import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

# Example for prompt-only format
# Example for prompt-only type
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
apply_chat_template(prompt_only_example, tokenizer)
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}

# Example for language modeling format
# Example for language modeling type
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
apply_chat_template(lm_example, tokenizer)
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
@@ -161,15 +165,15 @@ apply_chat_template(lm_example, tokenizer)

</Tip>

### Prompt-completion
#### Prompt-completion

A prompt-completion dataset includes a `"prompt"` and a `"completion"`.

```python
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
```

### Preference
#### Preference

A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
@@ -183,19 +187,19 @@ preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is gree

Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
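
For example, such a dataset can be loaded with the `datasets` library (the repository id below is illustrative):

```python
from datasets import load_dataset

# Repository id is illustrative; any dataset tagged `dpo` on the Hub works the same way
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
```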

### Unpaired preference
#### Unpaired preference

An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.

```python
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```

## Which dataset format to use?
## Which dataset type to use?

Choosing the right dataset format depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset formats supported by each TRL trainer.
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.

| Trainer | Expected dataset format |
| Trainer | Expected dataset type |
| ----------------------- | ------------------------------------------------------- |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
@@ -220,7 +224,7 @@ For more information on how to work with conversational datasets, refer to the [

## Working with conversational datasets in TRL

Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format.
Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format.
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.

### Converting a conversational dataset into a standard dataset
@@ -270,7 +274,8 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

<Tip warning={true}>

We recommend using the [`apply_chat_template`] function rather than directly calling `tokenizer.apply_chat_template`. Handling chat templates nonlanguage modeling datasets can be tricky and may lead to issues, such as inserting a system prompt in the middle of a conversation. For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.

</Tip>
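
As a further sketch, the same helper also handles prompt-completion examples; the exact rendered strings depend on the model's chat template:

```python
from transformers import AutoTokenizer
from trl import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

# Prompt-completion example: the prompt and the completion are templated separately
example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
}
apply_chat_template(example, tokenizer)
# Returns a dict with string "prompt" and "completion" fields; the exact
# special tokens depend on the tokenizer's chat template.
```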

@@ -308,7 +313,7 @@ Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb

As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").

By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference format, and push it to the Hub:
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:

```sh
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
@@ -727,10 +732,10 @@ A conversational vision dataset differs from a standard conversational dataset i
Example:

```python
# Textual dataset format:
# Textual dataset:
"content": "What color is the sky?"

# Vision dataset format:
# Vision dataset:
"content": [
{"type": "image"},
{"type": "text", "text": "What color is the sky in the image?"}
4 changes: 2 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -79,9 +79,9 @@ The best programming language for specific applications can vary depending on th
The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish.
</code></pre>

## Expected dataset format
## Expected dataset type

DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
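
For illustration, here is the same preference pair in both styles, reusing the values from the dataset formats guide:

```python
# Explicit prompt (recommended): the shared prompt has its own column
explicit_prompt_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}

# Implicit prompt: the prompt is repeated inside both completions and is
# extracted automatically by the trainer
implicit_prompt_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```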

2 changes: 1 addition & 1 deletion docs/source/gkd_trainer.md
@@ -82,7 +82,7 @@ trainer = GKDTrainer(
trainer.train()
```

### Expected dataset format
### Expected dataset type

Each example in the dataset should contain a list of `"messages"`, where each message is a dictionary with the following keys (a short sketch of one record follows the list):
* `role`: either `system`, `assistant` or `user`
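
A minimal sketch of such a record (roles and message content are illustrative):

```python
# Each example holds a list of messages; each message is a dict with a role and its content
example = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]
}
```
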
2 changes: 1 addition & 1 deletion docs/source/kto_trainer.mdx
@@ -9,7 +9,7 @@ Depending on how good your base model is, you may or may not need to do SFT befo
This is different from standard RLHF and DPO, which always require SFT.
You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below).

## Expected dataset format
## Expected dataset type

The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:

4 changes: 2 additions & 2 deletions docs/source/nash_md_trainer.md
@@ -53,9 +53,9 @@ Execute the script using the following command:
accelerate launch train_nash_md.py
```

## Expected dataset format
## Expected dataset type

Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

4 changes: 2 additions & 2 deletions docs/source/online_dpo_trainer.md
@@ -68,9 +68,9 @@ For example, if the server hosting the domain name does not have the correct IP
It's worth noting that the exact cause of DNS failure can vary depending on the specific situation, so it's important to carefully check all relevant factors before attempting to resolve the issue. If you suspect that your DNS problem may be caused by a bug in the system, you should report it to the DNS provider directly for further investigation.
```

## Expected dataset format
## Expected dataset type

Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

4 changes: 2 additions & 2 deletions docs/source/orpo_trainer.md
@@ -77,9 +77,9 @@ Here are some other factors to consider when choosing a programming language for

</code></pre>

## Expected dataset format
## Expected dataset type

ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.

6 changes: 3 additions & 3 deletions docs/source/reward_trainer.mdx
@@ -6,10 +6,10 @@ TRL supports custom reward modeling for anyone to perform reward modeling on the

Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).

## Expected dataset format
## Expected dataset type

The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `chosen` and `rejected` (and not `prompt`).
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
The [`RewardTrainer`] requires an [*implicit prompt* preference dataset](dataset_formats#preference). This means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
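
As a sketch, an implicit-prompt preference record therefore contains only the two completion columns:

```python
# Implicit-prompt preference record: no "prompt" column
reward_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```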

You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
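
A pretokenized record carries token ids and attention masks directly (the ids below are dummy values for illustration):

```python
# Pretokenized record; token ids are dummy values for illustration
pretokenized_example = {
    "input_ids_chosen": [1, 1024, 305, 2],
    "attention_mask_chosen": [1, 1, 1, 1],
    "input_ids_rejected": [1, 1024, 781, 2],
    "attention_mask_rejected": [1, 1, 1, 1],
}
```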
