diff --git a/docs/source/bco_trainer.mdx b/docs/source/bco_trainer.mdx index 6b54028b10..c23365cc00 100644 --- a/docs/source/bco_trainer.mdx +++ b/docs/source/bco_trainer.mdx @@ -6,10 +6,10 @@ TRL supports the Binary Classifier Optimization (BCO). The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For a full example have a look at [`examples/scripts/bco.py`]. -## Expected dataset format +## Expected dataset type The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference). -The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Expected model format The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. diff --git a/docs/source/cpo_trainer.mdx b/docs/source/cpo_trainer.mdx index 38e211856d..587252adf9 100644 --- a/docs/source/cpo_trainer.mdx +++ b/docs/source/cpo_trainer.mdx @@ -42,9 +42,9 @@ Execute the script using the following command: accelerate launch train_cpo.py ``` -## Expected dataset format +## Expected dataset type -CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Example script diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index 4944b1a9b1..cc92ec0ff1 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -1,10 +1,11 @@ -# Dataset formats +# Dataset formats and types -This guide provides an overview of the dataset formats supported by each trainer in TRL. Since conversational datasets are very common, we also provide a guide on how to use them, and how to convert them into a standard dataset format for TRL trainers. +This guide provides an overview of the dataset formats and types supported by each trainer in TRL. ## Overview of the dataset formats and types -The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*. The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
+- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*. +- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table. @@ -78,8 +79,9 @@ The *format* of a dataset refers to how the data is structured, typically catego
+### Formats -### Standard dataset format +#### Standard The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks: @@ -90,7 +92,7 @@ example = {"text": "The sky is blue."} example = {"chosen": "The sky is blue.", "rejected": "The sky is green."} ``` -### Conversational dataset format +#### Conversational Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text). @@ -119,7 +121,9 @@ example = { Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section. -### Language modeling +### Types + +#### Language modeling A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text. @@ -127,7 +131,7 @@ A language modeling dataset consists of a column `"text"` (or `"messages"` for c language_modeling_example = {"text": "The sky is blue."} ``` -### Prompt-only +#### Prompt-only In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input. @@ -137,7 +141,7 @@ prompt_only_example = {"prompt": "The sky is"} -While both the prompt-only and language modeling formats are similar, they differ in how the input is handled. In the prompt-only format, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling format, the input is treated as a complete sentence or sequence. These two formats are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each format: +While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that the model is expected to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL.
Below is an example showing the difference in the output of the `apply_chat_template` function for each type: ```python from transformers import AutoTokenizer @@ -145,12 +149,12 @@ from trl import apply_chat_template tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct") -# Example for prompt-only format +# Example for prompt-only type prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]} apply_chat_template(prompt_only_example, tokenizer) # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'} -# Example for language modeling format +# Example for language modeling type lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]} apply_chat_template(lm_example, tokenizer) # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'} @@ -161,7 +165,7 @@ apply_chat_template(lm_example, tokenizer) -### Prompt-completion +#### Prompt-completion A prompt-completion dataset includes a `"prompt"` and a `"completion"`. @@ -169,7 +173,7 @@ A prompt-completion dataset includes a `"prompt"` and a `"completion"`. prompt_completion_example = {"prompt": "The sky is", "completion": " blue."} ``` -### Preference +#### Preference A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response. Some dataset may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible. @@ -183,7 +187,7 @@ preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is gree Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets. -### Unpaired preference +#### Unpaired preference An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not. @@ -191,11 +195,11 @@ An unpaired preference dataset is similar to a preference dataset but instead of unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True} ``` -## Which dataset format to use? +## Which dataset type to use? -Choosing the right dataset format depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset formats supported by each TRL trainer. +Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer. 
-| Trainer | Expected dataset format | +| Trainer | Expected dataset type | | ----------------------- | ------------------------------------------------------- | | [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) | | [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | @@ -220,7 +224,7 @@ For more information on how to work with conversational datasets, refer to the [ ## Working with conversational datasets in TRL -Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format. +Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format. Fortunately, TRL offers tools to easily handle this conversion, which are detailed below. ### Converting a conversational dataset into a standard dataset @@ -270,7 +274,8 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) -We recommend using the [`apply_chat_template`] function rather than directly calling `tokenizer.apply_chat_template`. Handling chat templates nonlanguage modeling datasets can be tricky and may lead to issues, such as inserting a system prompt in the middle of a conversation. For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks. +We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation. +For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks. @@ -308,7 +313,7 @@ Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty"). -By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference format, and push it to the Hub: +By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub: ```sh python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness @@ -727,10 +732,10 @@ A conversational vision dataset differs from a standard conversational dataset i Example: ```python -# Textual dataset format: +# Textual dataset: "content": "What color is the sky?"
-# Vision dataset format: +# Vision dataset: "content": [ {"type": "image"}, {"type": "text", "text": "What color is the sky in the image?"} diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index f9d01074ed..728e0e39e3 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -79,9 +79,9 @@ The best programming language for specific applications can vary depending on th The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish. -## Expected dataset format +## Expected dataset type -DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md index b4171cf87c..c4f82ff160 100644 --- a/docs/source/gkd_trainer.md +++ b/docs/source/gkd_trainer.md @@ -82,7 +82,7 @@ trainer = GKDTrainer( trainer.train() ``` -### Expected dataset format +### Expected dataset type The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys: * `role`: either `system`, `assistant` or `user` diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 24326f5d46..7c6433be43 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -9,7 +9,7 @@ Depending on how good your base model is, you may or may not need to do SFT befo This is different from standard RLHF and DPO, which always require SFT. You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below). -## Expected dataset format +## Expected dataset type The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns: diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md index 7081771736..38e955639c 100644 --- a/docs/source/nash_md_trainer.md +++ b/docs/source/nash_md_trainer.md @@ -53,9 +53,9 @@ Execute the script using the following command: accelerate launch train_nash_md.py ``` -## Expected dataset format +## Expected dataset type -Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Usage tips diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 6cd17deee2..7d479cb2a9 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -68,9 +68,9 @@ For example, if the server hosting the domain name does not have the correct IP It's worth noting that the exact cause of DNS failure can vary depending on the specific situation, so it's important to carefully check all relevant factors before attempting to resolve the issue. If you suspect that your DNS problem may be caused by a bug in the system, you should report it to the DNS provider directly for further investigation. ``` -## Expected dataset format +## Expected dataset type -Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Usage tips diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index 13b7d7fa25..628383c5bd 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -77,9 +77,9 @@ Here are some other factors to consider when choosing a programming language for -## Expected dataset format +## Expected dataset type -ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
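+As a minimal sketch of what this extraction looks like (assuming the `maybe_extract_prompt` helper exported by recent versions of `trl`; the trainer performs an equivalent transformation internally), the shared prefix of the `"chosen"` and `"rejected"` messages becomes the explicit `"prompt"`:
+
+```python
+from trl import maybe_extract_prompt
+
+# Implicit-prompt preference example: the prompt is repeated in both completions
+example = {
+    "chosen": [
+        {"role": "user", "content": "What color is the sky?"},
+        {"role": "assistant", "content": "It is blue."},
+    ],
+    "rejected": [
+        {"role": "user", "content": "What color is the sky?"},
+        {"role": "assistant", "content": "It is green."},
+    ],
+}
+maybe_extract_prompt(example)
+# {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
+#  'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
+#  'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
+```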
diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.mdx index a17e4ad4da..09c2ac863c 100644 --- a/docs/source/reward_trainer.mdx +++ b/docs/source/reward_trainer.mdx @@ -6,10 +6,10 @@ TRL supports custom reward modeling for anyone to perform reward modeling on the Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). -## Expected dataset format +## Expected dataset type -The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `chosen` and `rejected` (and not `prompt`). -The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +The [`RewardTrainer`] requires an [*implicit prompt* preference dataset](dataset_formats#preference). This means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`). +The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`. diff --git a/docs/source/xpo_trainer.mdx b/docs/source/xpo_trainer.mdx index c7d9116532..0a0cf90387 100644 --- a/docs/source/xpo_trainer.mdx +++ b/docs/source/xpo_trainer.mdx @@ -53,9 +53,9 @@ Execute the script using the following command: accelerate launch train_xpo.py ``` -## Expected dataset format +## Expected dataset type -XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
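+For illustration, a single prompt-only entry looks as follows in each format (these examples mirror the ones in the dataset documentation above):
+
+```python
+# Standard format
+prompt_only_example = {"prompt": "The sky is"}
+
+# Conversational format (the chat template is applied automatically)
+prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
+```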
## Usage tips diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 6f894729c0..cf55e8dd61 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -211,59 +211,59 @@ class UnpairPreferenceDatasetTester(unittest.TestCase): ) def test_unpair_preference_dataset(self): - # Test that a paired-formatted dataset is correctly converted to unpaired format + # Test that a paired dataset is correctly converted to unpaired unpaired_dataset = unpair_preference_dataset(self.paired_dataset) self.assertEqual( unpaired_dataset.to_dict(), self.unpaired_dataset.to_dict(), - "The paired-formatted dataset should be reformatted to unpaired format.", + "The paired dataset should be converted to unpaired.", ) def test_unpair_preference_dataset_dict(self): - # Test that a paired-formatted dataset dict is correctly converted to unpaired format + # Test that a paired dataset dict is correctly converted to unpaired paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict) self.assertEqual( unpaired_dataset_dict["abc"].to_dict(), self.unpaired_dataset.to_dict(), - "The paired-formatted dataset should be reformatted to unpaired format.", + "The paired dataset should be converted to unpaired.", ) def test_maybe_unpair_preference_dataset(self): - # Test that a paired-formatted dataset is correctly reformatted to unpaired format with maybe_unpair_preference_dataset + # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset) self.assertEqual( unpaired_dataset.to_dict(), self.unpaired_dataset.to_dict(), - "The paired-formatted dataset should be reformatted to unpaired format.", + "The paired dataset should be converted to unpaired.", ) def test_maybe_unpair_preference_dataset_dict(self): - # Test that a paired-formatted dataset dict is correctly converted to unpaired format with maybe_unpair_preference_dataset + # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict) self.assertEqual( unpaired_dataset_dict["abc"].to_dict(), self.unpaired_dataset.to_dict(), - "The paired-formatted dataset should be reformatted to unpaired format.", + "The paired dataset should be converted to unpaired.", ) def test_maybe_unpair_preference_dataset_already_paired(self): - # Test that a paired-formatted dataset remains unchanged with maybe_unpair_preference_dataset + # Test that an unpaired dataset remains unchanged with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset) self.assertEqual( unpaired_dataset.to_dict(), self.unpaired_dataset.to_dict(), - "The unpaired-formatted dataset should remain unchanged.", + "The unpaired dataset should remain unchanged.", ) def test_maybe_unpair_preference_dataset_dict_already_paired(self): - # Test that a paired-formatted dataset dict remains unchanged with maybe_unpair_preference_dataset + # Test that an unpaired dataset dict remains unchanged with maybe_unpair_preference_dataset unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset})) self.assertEqual( unpaired_dataset_dict["abc"].to_dict(), self.unpaired_dataset.to_dict(), - "The unpaired-formatted dataset should remain unchanged.", + "The unpaired dataset should remain unchanged.", ) diff --git a/trl/data_utils.py b/trl/data_utils.py index 266dceaad7..569d398b52 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -27,7 +27,7 @@ def is_conversational(example: Dict[str, Any]) -> bool: Args: example (`Dict[str, Any]`): A single data entry of a dataset. The example can have different keys depending on the - dataset format. + dataset type. Returns: `bool`: `True` if the data is in a conversational format, `False` otherwise. @@ -147,7 +147,7 @@ def maybe_apply_chat_template( Args: example (`Dict[str, List[Dict[str, str]]`): Dictionary representing a single data entry of a conversational dataset. Each data entry can have different - keys depending on the dataset format. The supported dataset formats are: + keys depending on the dataset type. The supported dataset types are: - Language modeling dataset: `"messages"`. - Prompt-only dataset: `"prompt"`.
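+
+    Example (illustrative; assumes the current `trl` API and the Phi-3 chat template shown earlier in this diff):
+
+    ```python
+    >>> from transformers import AutoTokenizer
+    >>> from trl import is_conversational, maybe_apply_chat_template
+    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
+    >>> example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
+    >>> is_conversational(example)
+    True
+    >>> maybe_apply_chat_template(example, tokenizer)
+    {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
+    ```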