From 776939dcc43355507cd4e7adbdfbad44e7a5481b Mon Sep 17 00:00:00 2001
From: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
Date: Fri, 12 Jan 2024 08:05:32 +0100
Subject: [PATCH] Add support for ChatML dataset format in (#1208)

* Add support for ChatML dataset format in SFTTrainer
* fix formatting
* fix tests
* more comment
* fix intent
* fix doc string
* Update dataset_formatting.py
* Update dataset_formatting.py
* add documentation
* Update sft_trainer.mdx
* add leonardos comment and more tests
* added more tests and fixed batching
* style
* comment in

---
 docs/source/sft_trainer.mdx      |  53 +++++++++++--
 tests/test_dataset_formatting.py | 124 +++++++++++++++++++++++++++++++
 tests/test_sft_trainer.py        |  66 +++++++++++++++-
 trl/extras/dataset_formatting.py |  88 ++++++++++++++++++++++
 trl/trainer/sft_trainer.py       |   6 ++
 5 files changed, 331 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_dataset_formatting.py
 create mode 100644 trl/extras/dataset_formatting.py

diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index 0f9e145019..9c4df57677 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -154,6 +154,49 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
 data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
 ```
 
+### Dataset format support
+
+The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset directly to the trainer without any pre-processing. The following formats are supported:
+* conversational format
+```json
+{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
+{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
+{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
+```
+* instruction format
+```json
+{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
+{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
+{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
+```
+
+If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the chat template defined in the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
+
+
+```python
+from datasets import load_dataset
+from trl import SFTTrainer
+
+...
+
+# load jsonl dataset
+dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
+# load dataset from the HuggingFace Hub
+dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
+
+...
+
+trainer = SFTTrainer(
+    "facebook/opt-350m",
+    args=training_args,
+    train_dataset=dataset,
+    packing=True,
+)
+```
+
+If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the [`SFTTrainer`] to do it for you. Let's have a look.
+
+
 ### Format your input prompts
 
 For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
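For illustration, a minimal sketch of such a formatting function in the style of the docs above (the `question`/`answer` column names, the prompt template, and the file path are assumptions, not part of this patch):

```python
from datasets import load_dataset
from trl import SFTTrainer

# hypothetical two-column dataset with "question" and "answer" fields
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")


def formatting_prompts_func(examples):
    # Merge each question/answer pair into a single training text.
    output_texts = []
    for i in range(len(examples["question"])):
        output_texts.append(
            f"### Question: {examples['question'][i]}\n### Answer: {examples['answer'][i]}"
        )
    return output_texts


trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)
```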
@@ -346,11 +389,11 @@ Note that you cannot train your model using Flash Attention 1 on an arbitrary da
 Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
 
 | use_flash_attn_1 | model_name        | max_seq_len | batch_size | time per training step |
-|----------------|-------------------|-------------|------------|------------------------|
-| x              | facebook/opt-350m | 2048        | 8          | ~59.1s                 |
-|                | facebook/opt-350m | 2048        | 8          | **OOM**                |
-| x              | facebook/opt-350m | 2048        | 4          | ~30.3s                 |
-|                | facebook/opt-350m | 2048        | 4          | ~148.9s                |
+| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
+| x                | facebook/opt-350m | 2048        | 8          | ~59.1s                 |
+|                  | facebook/opt-350m | 2048        | 8          | **OOM**                |
+| x                | facebook/opt-350m | 2048        | 4          | ~30.3s                 |
+|                  | facebook/opt-350m | 2048        | 4          | ~148.9s                |
 
 ### Using Flash Attention-2
 
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
new file mode 100644
index 0000000000..156afdef09
--- /dev/null
+++ b/tests/test_dataset_formatting.py
@@ -0,0 +1,124 @@
+import unittest
+from typing import Callable
+
+from datasets import Dataset, load_dataset
+from transformers import AutoTokenizer
+
+from trl.extras.dataset_formatting import get_formatting_func_from_dataset
+
+
+class DatasetFormattingTestCase(unittest.TestCase):
+    def setUp(self):
+        self.llama_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        self.chatml_tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")
+
+    def test_get_formatting_func_from_dataset_with_chatml_messages(self):
+        dataset = Dataset.from_dict(
+            {
+                "messages": [
+                    [
+                        {"role": "system", "content": "You are helpful"},
+                        {"role": "user", "content": "Hello"},
+                        {"role": "assistant", "content": "Hi, how can I help you?"},
+                    ]
+                ]
+            }
+        )
+
+        # Llama tokenizer
+        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
+        self.assertTrue(isinstance(formatting_func, Callable))
+        formatted_text = formatting_func(dataset[0])
+        self.assertEqual(
+            formatted_text,
+            "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>",
+        )
+        formatted_text = formatting_func(dataset[0:1])
+        self.assertEqual(
+            formatted_text,
+            ["<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"],
+        )
+
+        # ChatML tokenizer
+        formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
+        formatted_text = formatting_func(dataset[0])
+        self.assertEqual(
+            formatted_text,
+            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
+        )
+        formatted_text = formatting_func(dataset[0:1])
+        self.assertEqual(
+            formatted_text,
+            [
+                "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+            ],
+        )
+
+    def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
+        dataset = Dataset.from_dict(
+            {
+                "conversations": [
+                    [
+                        {"role": "system", "content": "You are helpful"},
+                        {"role": "user", "content": "Hello"},
+                        {"role": "assistant", "content": "Hi, how can I help you?"},
+                    ]
+                ]
+            }
+        )
+        # Llama tokenizer
+        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
+        self.assertTrue(isinstance(formatting_func, Callable))
+        formatted_text = formatting_func(dataset[0])
+        self.assertEqual(
+            formatted_text,
+            "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>",
+        )
+        formatted_text = formatting_func(dataset[0:1])
+        self.assertEqual(
+            formatted_text,
+            ["<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"],
+        )
+
+        # ChatML tokenizer
+        formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
+        formatted_text = formatting_func(dataset[0])
+        self.assertEqual(
+            formatted_text,
+            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
+        )
+        formatted_text = formatting_func(dataset[0:1])
+        self.assertEqual(
+            formatted_text,
+            [
+                "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+            ],
+        )
+
+    def test_get_formatting_func_from_dataset_with_instruction(self):
+        dataset = Dataset.from_list(
+            [{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
+        )
+        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
+        self.assertIsNotNone(formatting_func)
+        self.assertTrue(isinstance(formatting_func, Callable))
+        formatted_text = formatting_func(dataset[0])
+        self.assertEqual(formatted_text, "<s>[INST] What is 2+2? [/INST] 4 </s>")
+        formatted_text = formatting_func(dataset[0:1])
+        self.assertEqual(formatted_text, ["<s>[INST] What is 2+2? [/INST] 4 </s>"])
+
+    def test_get_formatting_func_from_dataset_from_hub(self):
+        ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
+        ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
+        for ds in [ds_1, ds_2]:
+            formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
+            self.assertIsNotNone(formatting_func)
+            self.assertTrue(isinstance(formatting_func, Callable))
+        ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
+        formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
+        self.assertIsNone(formatting_func)
+
+    def test_get_formatting_func_from_dataset_with_unknown_format(self):
+        dataset = Dataset.from_dict({"text": "test"})
+        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
+        self.assertIsNone(formatting_func)
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index f430c7b486..3aa873caf5 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -85,6 +85,42 @@ def setUpClass(cls):
                 ],
             }
         )
+        cls.dummy_chatml_dataset = Dataset.from_dict(
+            {
+                "messages": [
+                    [
+                        {"role": "system", "content": "You are helpful"},
+                        {"role": "user", "content": "Hello"},
+                        {"role": "assistant", "content": "Hi, how can I help you?"},
+                        {"role": "user", "content": "What is 2+2?"},
+                        {"role": "assistant", "content": "4"},
+                        {"role": "user", "content": "What is 3+3?"},
+                        {"role": "assistant", "content": "6"},
+                    ],
+                    [
+                        {"role": "system", "content": "You are helpful"},
+                        {"role": "user", "content": "Hello"},
+                        {"role": "assistant", "content": "Hi, how can I help you?"},
+                    ],
+                ]
+            }
+        )
+        cls.dummy_instruction_dataset = Dataset.from_list(
+            [
+                {"prompt": "What is 2+2?", "completion": "4"},
+                {"prompt": "What is 3+3?", "completion": "6"},
+                {"prompt": "What is 4+4?", "completion": "8"},
+                {"prompt": "What is 2+2?", "completion": "4"},
+                {"prompt": "What is 3+3?", "completion": "6"},
+                {"prompt": "What is 4+4?", "completion": "8"},
+                {"prompt": "What is 2+2?", "completion": "4"},
+                {"prompt": "What is 3+3?", "completion": "6"},
+                {"prompt": "What is 4+4?", "completion": "8"},
+                {"prompt": "What is 2+2?", "completion": "4"},
+                {"prompt": "What is 3+3?", "completion": "6"},
+                {"prompt": "What is 4+4?", "completion": "8"},
+            ]
+        )
 
         cls.train_dataset = ConstantLengthDataset(
             cls.tokenizer,
@@ -171,7 +207,35 @@ def test_sft_trainer_uncorrect_data(self):
                 train_dataset=self.dummy_dataset,
                 packing=True,
             )
-
+        # this should work since the dummy chatml dataset includes the correct format
+        _ = SFTTrainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dummy_chatml_dataset,
+            max_seq_length=32,  # make sure there is at least 1 packed sequence
+            num_of_sequences=32,
+            packing=True,
+        )
+        _ = SFTTrainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dummy_chatml_dataset,
+            packing=False,
+        )
+        # this should work since the dummy instruction dataset is in the correct format
+        _ = SFTTrainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dummy_instruction_dataset,
+            max_seq_length=16,  # make sure there is at least 1 packed sequence
+            packing=True,
+        )
+        _ = SFTTrainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dummy_instruction_dataset,
+            packing=False,
+        )
         # This should work
         _ = SFTTrainer(
             model=self.model,
diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py
new file mode 100644
index 0000000000..31cf567209
--- /dev/null
+++ b/trl/extras/dataset_formatting.py
@@ -0,0 +1,88 @@
+import logging
+from typing import Callable, Literal, Optional, Union
+
+from datasets import Dataset, Value
+from transformers import AutoTokenizer
+
+from ..trainer.utils import ConstantLengthDataset
+
+
+FORMAT_MAPPING = {
+    "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
+    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+}
+
+
+def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
+    r"""
+    Return a callable that formats a conversational dataset by applying the tokenizer's chat template to the given
+    messages field. Works on single examples and on batches.
+    """
+
+    def format_dataset(examples):
+        if isinstance(examples[messages_field][0], list):
+            output_texts = []
+            for i in range(len(examples[messages_field])):
+                output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
+            return output_texts
+        else:
+            return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
+
+    return format_dataset
+
+
+def instructions_formatting_function(tokenizer: AutoTokenizer):
+    r"""
+    Return a callable that converts "prompt"/"completion" pairs into a user/assistant conversation and formats it
+    with the tokenizer's chat template. Works on single examples and on batches.
+    """
+
+    def format_dataset(examples):
+        if isinstance(examples["prompt"], list):
+            output_texts = []
+            for i in range(len(examples["prompt"])):
+                converted_sample = [
+                    {"role": "user", "content": examples["prompt"][i]},
+                    {"role": "assistant", "content": examples["completion"][i]},
+                ]
+                output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
+            return output_texts
+        else:
+            converted_sample = [
+                {"role": "user", "content": examples["prompt"]},
+                {"role": "assistant", "content": examples["completion"]},
+            ]
+            return tokenizer.apply_chat_template(converted_sample, tokenize=False)
+
+    return format_dataset
+
+
+def get_formatting_func_from_dataset(
+    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
+) -> Optional[Callable]:
+    r"""
+    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
+    - `ChatML` with [{"role": str, "content": str}]
+    - `instruction` with {"prompt": str, "completion": str}
+
+    Args:
+        dataset (Dataset): User dataset
+        tokenizer (AutoTokenizer): Tokenizer used for formatting
+
+    Returns:
+        Callable: Formatting function if the dataset format is supported, else None
+    """
+    if isinstance(dataset, Dataset):
+        if "messages" in dataset.features:
+            if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
+                logging.info("Formatting dataset with chatml format")
+                return conversations_formatting_function(tokenizer, "messages")
+        if "conversations" in dataset.features:
+            if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
+                logging.info("Formatting dataset with chatml format")
+                return conversations_formatting_function(tokenizer, "conversations")
+        elif dataset.features == FORMAT_MAPPING["instruction"]:
+            logging.info("Formatting dataset with instruction format")
+            return instructions_formatting_function(tokenizer)
+
+    return None
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 9c06b102ff..8d3320969e 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -36,6 +36,7 @@
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction
 
+from ..extras.dataset_formatting import get_formatting_func_from_dataset
 from ..import_utils import is_peft_available
 from .utils import (
     ConstantLengthDataset,
@@ -237,6 +238,11 @@ def make_inputs_require_grad(module, input, output):
         elif not self._trainer_supports_neftune:
             self.neftune_noise_alpha = neftune_noise_alpha
 
+        if formatting_func is None and dataset_text_field is None:
+            # check if the dataset has a supported ChatML or instruction format;
+            # if not, formatting_func stays None
+            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+
         if not packing:
             if dataset_text_field is None and formatting_func is None:
                 raise ValueError(
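As a usage sketch of the detection helper added above, the snippet below builds a tiny ChatML-style dataset and lets `get_formatting_func_from_dataset` pick the formatting function (the dataset contents are placeholders; the tokenizer checkpoint is the one used in the new tests):

```python
from datasets import Dataset
from transformers import AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset

tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")

# A "messages" column holding lists of {"role", "content"} dicts matches the chatml format
dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ]
        ]
    }
)

# Returns a formatting function when the structure is recognized, otherwise None
formatting_func = get_formatting_func_from_dataset(dataset, tokenizer)
print(formatting_func(dataset[0]))
# e.g. "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there!<|im_end|>\n"
```

The `SFTTrainer` change above performs exactly this lookup when neither `formatting_func` nor `dataset_text_field` is given.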