Add support for ChatML dataset format in SFTTrainer (#1208)
* Add support for ChatML dataset format in SFTTrainer

* fix formatting

* fix tests

* more comments

* fix indent

* fix doc string

* Update dataset_formatting.py

* Update dataset_formatting.py

* add documentation

* Update sft_trainer.mdx

* add Leonardo's comment and more tests

* added more tests and fixed batching

* style

* comment in
philschmid authored Jan 12, 2024
1 parent 163ca9f commit 776939d
Showing 5 changed files with 331 additions and 6 deletions.
53 changes: 48 additions & 5 deletions docs/source/sft_trainer.mdx
@@ -154,6 +154,49 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```

### Dataset format support

The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer directly, without any pre-processing. The following formats are supported:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
```
* instruction format
```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```

If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you, using the chat template defined in the model's tokenizer via the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
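Under the hood this is roughly what happens for a single conversational sample — a minimal sketch, assuming a tokenizer whose chat template is ChatML (the tokenizer below is the one used in this PR's tests; any tokenizer with a `chat_template` behaves analogously):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")

messages = [
    {"role": "system", "content": "You are helpful"},
    {"role": "user", "content": "What's the capital of France?"},
    {"role": "assistant", "content": "Paris"},
]

# the trainer renders each sample into a single training string
print(tokenizer.apply_chat_template(messages, tokenize=False))
# <|im_start|>system
# You are helpful<|im_end|>
# ...
```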


```python
from datasets import load_dataset
from trl import SFTTrainer

...

# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# or load a dataset from the Hugging Face Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

...

trainer = SFTTrainer(
"facebook/opt-350m",
args=training_args,
train_dataset=dataset,
packing=True,
)
```

If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting (a minimal sketch follows) or pass a formatting function to the [`SFTTrainer`] to do it for you, as the next section shows.
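For the preprocessing route, here is a sketch that maps a dataset with hypothetical `question`/`answer` columns (placeholder names, not from this PR) onto the supported conversational format:

```python
def to_conversational(example):
    # map one raw sample onto the supported "messages" schema
    return {
        "messages": [
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
    }

dataset = dataset.map(to_conversational, remove_columns=["question", "answer"])
```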


### Format your input prompts

For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
@@ -346,11 +389,11 @@ Note that you cannot train your model using Flash Attention 1 on an arbitrary da
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA T4 (16GB).

| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
| x                | facebook/opt-350m | 2048        | 8          | ~59.1s                 |
|                  | facebook/opt-350m | 2048        | 8          | **OOM**                |
| x                | facebook/opt-350m | 2048        | 4          | ~30.3s                 |
|                  | facebook/opt-350m | 2048        | 4          | ~148.9s                |

### Using Flash Attention-2

124 changes: 124 additions & 0 deletions tests/test_dataset_formatting.py
@@ -0,0 +1,124 @@
import unittest
from typing import Callable

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset


class DatasetFormattingTestCase(unittest.TestCase):
def setUp(self):
self.llama_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
self.chatml_tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")

def test_get_formatting_func_from_dataset_with_chatml_messages(self):
dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
]
}
)

# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
self.assertTrue(isinstance(formatting_func, Callable))
formatted_text = formatting_func(dataset[0])
self.assertEqual(
formatted_text,
"<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>",
)
formatted_text = formatting_func(dataset[0:1])
self.assertEqual(
formatted_text,
["<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"],
)

# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
self.assertEqual(
formatted_text,
"<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
)
formatted_text = formatting_func(dataset[0:1])
self.assertEqual(
formatted_text,
[
"<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
],
)

def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
dataset = Dataset.from_dict(
{
"conversations": [
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
]
}
)
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
self.assertTrue(isinstance(formatting_func, Callable))
formatted_text = formatting_func(dataset[0])
self.assertEqual(
formatted_text,
"<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>",
)
formatted_text = formatting_func(dataset[0:1])
self.assertEqual(
formatted_text,
["<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"],
)

# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
self.assertEqual(
formatted_text,
"<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
)
formatted_text = formatting_func(dataset[0:1])
self.assertEqual(
formatted_text,
[
"<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
],
)

def test_get_formatting_func_from_dataset_with_instruction(self):
dataset = Dataset.from_list(
[{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
)
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
self.assertIsNotNone(formatting_func)
self.assertTrue(isinstance(formatting_func, Callable))
formatted_text = formatting_func(dataset[0])
self.assertEqual(formatted_text, "<s>[INST] What is 2+2? [/INST] 4 </s>")
formatted_text = formatting_func(dataset[0:1])
self.assertEqual(formatted_text, ["<s>[INST] What is 2+2? [/INST] 4 </s>"])

def test_get_formatting_func_from_dataset_from_hub(self):
ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
for ds in [ds_1, ds_2]:
formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
self.assertIsNotNone(formatting_func)
self.assertTrue(isinstance(formatting_func, Callable))
ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
self.assertIsNone(formatting_func)

def test_get_formatting_func_from_dataset_with_unknown_format(self):
dataset = Dataset.from_dict({"text": "test"})
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
self.assertIsNone(formatting_func)
66 changes: 65 additions & 1 deletion tests/test_sft_trainer.py
@@ -85,6 +85,42 @@ def setUpClass(cls):
],
}
)
cls.dummy_chatml_dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "What is 3+3?"},
{"role": "assistant", "content": "6"},
],
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
],
]
}
)
cls.dummy_instruction_dataset = Dataset.from_list(
[
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
]
)

cls.train_dataset = ConstantLengthDataset(
cls.tokenizer,
@@ -171,7 +207,35 @@ def test_sft_trainer_uncorrect_data(self):
train_dataset=self.dummy_dataset,
packing=True,
)

# this should work since the dummy chatml dataset includes the correct format
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
max_seq_length=32, # make sure there is at least 1 packed sequence
num_of_sequences=32,
packing=True,
)
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
packing=False,
)
# this should work since the dummy instruction dataset is in the correct format
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_instruction_dataset,
max_seq_length=16, # make sure there is at least 1 packed sequence
packing=True,
)
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_instruction_dataset,
packing=False,
)
# This should work
_ = SFTTrainer(
model=self.model,
88 changes: 88 additions & 0 deletions trl/extras/dataset_formatting.py
@@ -0,0 +1,88 @@
import logging
from typing import Callable, Literal, Optional, Union

from datasets import Dataset, Value
from transformers import AutoTokenizer

from ..trainer.utils import ConstantLengthDataset


FORMAT_MAPPING = {
"chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
}


def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
r"""
Return a callable function that takes in a "messages" dataset and formats it by
applying the tokenizer's chat template to each conversation.
"""

def format_dataset(examples):
if isinstance(examples[messages_field][0], list):
output_texts = []
for i in range(len(examples[messages_field])):
output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
return output_texts
else:
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)

return format_dataset


def instructions_formatting_function(tokenizer: AutoTokenizer):
r"""
Return a callable function that takes in a "prompt"/"completion" instruction dataset and formats it by
converting each sample into a user/assistant exchange and applying the tokenizer's chat template.
"""

def format_dataset(examples):
if isinstance(examples["prompt"], list):
output_texts = []
for i in range(len(examples["prompt"])):
converted_sample = [
{"role": "user", "content": examples["prompt"][i]},
{"role": "assistant", "content": examples["completion"][i]},
]
output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
return output_texts
else:
converted_sample = [
{"role": "user", "content": examples["prompt"]},
{"role": "assistant", "content": examples["completion"]},
]
return tokenizer.apply_chat_template(converted_sample, tokenize=False)

return format_dataset


def get_formatting_func_from_dataset(
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
- `ChatML` with [{"role": str, "content": str}]
- `instruction` with {"prompt": str, "completion": str}
Args:
dataset (Dataset): User dataset
tokenizer (AutoTokenizer): Tokenizer used for formatting
Returns:
Callable: Formatting function if the dataset format is supported, otherwise None
"""
if isinstance(dataset, Dataset):
if "messages" in dataset.features:
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "messages")
if "conversations" in dataset.features:
if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "conversations")
elif dataset.features == FORMAT_MAPPING["instruction"]:
logging.info("Formatting dataset with instruction format")
return instructions_formatting_function(tokenizer)

return None
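
As a quick illustration of how the new helper behaves (not part of the diff; the dataset below mirrors the test fixtures, and the tokenizer is the ChatML one used in the tests above):

```python
from datasets import Dataset
from transformers import AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset

tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")
dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi, how can I help you?"},
            ]
        ]
    }
)

# returns a formatting function for supported formats, otherwise None
formatting_func = get_formatting_func_from_dataset(dataset, tokenizer)
if formatting_func is not None:
    print(formatting_func(dataset[0]))  # ChatML-rendered string
```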
6 changes: 6 additions & 0 deletions trl/trainer/sft_trainer.py
@@ -36,6 +36,7 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction

from ..extras.dataset_formatting import get_formatting_func_from_dataset
from ..import_utils import is_peft_available
from .utils import (
ConstantLengthDataset,
@@ -237,6 +238,11 @@ def make_inputs_require_grad(module, input, output):
elif not self._trainer_supports_neftune:
self.neftune_noise_alpha = neftune_noise_alpha

if formatting_func is None and dataset_text_field is None:
# check if the dataset has a supported ChatML or instruction format;
# if not, formatting_func stays None
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

if not packing:
if dataset_text_field is None and formatting_func is None:
raise ValueError(