generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for ChatML dataset format in SFTTrainer (#1208)
* Add support for ChatML dataset format in SFTTrainer * fix formatting * fix tests * more comment * fix intent * fix doc string * Update dataset_formatting.py * Update dataset_formatting.py * add documentation * Update sft_trainer.mdx * add leonardos comment and more tests * added more tests and fixed batching * style * comment in
- Loading branch information
1 parent
163ca9f
commit 776939d
Showing
5 changed files
with
331 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import unittest | ||
from typing import Callable | ||
|
||
from datasets import Dataset, load_dataset | ||
from transformers import AutoTokenizer | ||
|
||
from trl.extras.dataset_formatting import get_formatting_func_from_dataset | ||
|
||
|
||
class DatasetFormattingTestCase(unittest.TestCase):
    """Checks that `get_formatting_func_from_dataset` detects supported dataset
    layouts (chatml-style message lists, prompt/completion pairs) and that the
    returned callables render both single rows and batched slices correctly."""

    # Expected renderings of the shared three-turn conversation fixture.
    _LLAMA_RENDERING = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
    _CHATML_RENDERING = (
        "<|im_start|>system\nYou are helpful<|im_end|>\n"
        "<|im_start|>user\nHello<|im_end|>\n"
        "<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
    )

    def setUp(self):
        self.llama_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        self.chatml_tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")

    def _make_conversation_dataset(self, column_name):
        # One three-turn conversation stored under `column_name`
        # ("messages" or "conversations").
        conversation = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi, how can I help you?"},
        ]
        return Dataset.from_dict({column_name: [conversation]})

    def _assert_chat_formatting(self, dataset):
        # Llama tokenizer: detection succeeds, and both a single row and a
        # batched slice are rendered with the llama chat template.
        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
        self.assertTrue(isinstance(formatting_func, Callable))
        self.assertEqual(formatting_func(dataset[0]), self._LLAMA_RENDERING)
        self.assertEqual(formatting_func(dataset[0:1]), [self._LLAMA_RENDERING])
        # ChatML tokenizer: same dataset, chatml-template output.
        formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
        self.assertEqual(formatting_func(dataset[0]), self._CHATML_RENDERING)
        self.assertEqual(formatting_func(dataset[0:1]), [self._CHATML_RENDERING])

    def test_get_formatting_func_from_dataset_with_chatml_messages(self):
        self._assert_chat_formatting(self._make_conversation_dataset("messages"))

    def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
        self._assert_chat_formatting(self._make_conversation_dataset("conversations"))

    def test_get_formatting_func_from_dataset_with_instruction(self):
        dataset = Dataset.from_list(
            [{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
        )
        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
        self.assertIsNotNone(formatting_func)
        self.assertTrue(isinstance(formatting_func, Callable))
        self.assertEqual(formatting_func(dataset[0]), "<s>[INST] What is 2+2? [/INST] 4 </s>")
        self.assertEqual(formatting_func(dataset[0:1]), ["<s>[INST] What is 2+2? [/INST] 4 </s>"])

    def test_get_formatting_func_from_dataset_from_hub(self):
        # Hub datasets in supported layouts yield a formatting function ...
        supported = [
            load_dataset("philschmid/trl-test-instruction", split="train"),
            load_dataset("philschmid/dolly-15k-oai-style", split="train"),
        ]
        for ds in supported:
            formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
            self.assertIsNotNone(formatting_func)
            self.assertTrue(isinstance(formatting_func, Callable))
        # ... while an unsupported sharegpt-style layout yields None.
        unsupported = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
        self.assertIsNone(get_formatting_func_from_dataset(unsupported, self.llama_tokenizer))

    def test_get_formatting_func_from_dataset_with_unknown_format(self):
        dataset = Dataset.from_dict({"text": "test"})
        self.assertIsNone(get_formatting_func_from_dataset(dataset, self.llama_tokenizer))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import logging | ||
from typing import Callable, Literal, Optional, Union | ||
|
||
from datasets import Dataset, Value | ||
from transformers import AutoTokenizer | ||
|
||
from ..trainer.utils import ConstantLengthDataset | ||
|
||
|
||
# Canonical `datasets` feature schemas used to auto-detect supported formats:
# "chatml" is a list of {role, content} message dicts; "instruction" is a flat
# {prompt, completion} string pair. Compared against `dataset.features` in
# `get_formatting_func_from_dataset`.
FORMAT_MAPPING = {
    "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
}
|
||
|
||
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
    r"""
    Build a formatting function for conversational (chatml-style) datasets.

    Args:
        tokenizer (`AutoTokenizer`): Tokenizer whose chat template is applied.
        messages_field (`Literal["messages", "conversations"]`): Name of the
            dataset column that holds the list of message dicts.

    Returns:
        Callable: A function that accepts either a single example (one list of
        `{"role": ..., "content": ...}` dicts) or a batch (a list of such
        lists) and returns the chat-template-rendered string(s).
    """

    def format_dataset(examples):
        messages = examples[messages_field]
        # A batched call provides a list of conversations (each a list of
        # message dicts); a single example provides one conversation whose
        # first element is a dict. An empty batch is treated as batched so it
        # returns [] instead of raising IndexError on messages[0].
        if not messages or isinstance(messages[0], list):
            return [tokenizer.apply_chat_template(conversation, tokenize=False) for conversation in messages]
        return tokenizer.apply_chat_template(messages, tokenize=False)

    return format_dataset
|
||
|
||
def instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Build a formatting function for instruction datasets with "prompt" and
    "completion" columns, rendering each pair as a two-turn chat
    (user prompt, assistant completion) via the tokenizer's chat template.
    """

    def format_dataset(examples):
        prompts = examples["prompt"]
        # Single example: "prompt" is a plain string.
        if not isinstance(prompts, list):
            pair = [
                {"role": "user", "content": prompts},
                {"role": "assistant", "content": examples["completion"]},
            ]
            return tokenizer.apply_chat_template(pair, tokenize=False)
        # Batched call: render one string per prompt/completion pair.
        completions = examples["completion"]
        return [
            tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": prompts[i]},
                    {"role": "assistant", "content": completions[i]},
                ],
                tokenize=False,
            )
            for i in range(len(prompts))
        ]

    return format_dataset
|
||
|
||
def get_formatting_func_from_dataset(
    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML` with [{"role": str, "content": str}]
    - `instruction` with [{"prompt": str, "completion": str}]

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    # Detection relies on `dataset.features`, so only plain `Dataset` objects
    # can be inspected.
    if not isinstance(dataset, Dataset):
        return None

    features = dataset.features
    chatml_schema = FORMAT_MAPPING["chatml"]
    # Either column name may carry a chatml-style conversation list.
    for field in ("messages", "conversations"):
        if field in features and features[field] == chatml_schema:
            logging.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, field)

    # A flat {prompt, completion} schema means an instruction dataset.
    if features == FORMAT_MAPPING["instruction"]:
        logging.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)

    return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters