From 88be2c07e5024cc47e7777d23815602ae22c5f11 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 22 Oct 2024 13:29:32 +0200
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A9=20`setup=5Fchat=5Fformat`:=20throw?=
 =?UTF-8?q?=20error=20if=20there=20is=20already=20a=20template=20in=20base?=
 =?UTF-8?q?=20model=20(#2252)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* setup_chat_format: throw error if there was already a template

* fix lint

* clarify in docs

* fix test?

---------

Co-authored-by: Kashif Rasul
---
 tests/test_dataset_formatting.py | 2 ++
 trl/models/utils.py              | 8 ++++++++
 2 files changed, 10 insertions(+)

diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index 517da43f55..f1e9bcb4d8 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -119,6 +119,8 @@ class SetupChatFormatTestCase(unittest.TestCase):
     def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
         self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        # remove built-in chat_template to simulate a model having no chat_template
+        self.tokenizer.chat_template = None

     def test_setup_chat_format(self):
         original_tokenizer_len = len(self.tokenizer)
diff --git a/trl/models/utils.py b/trl/models/utils.py
index afdc944154..562b8617ed 100644
--- a/trl/models/utils.py
+++ b/trl/models/utils.py
@@ -84,6 +84,8 @@ def setup_chat_format(
     """
     Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.

+    If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.
+
     Args:
         model (`~transformers.PreTrainedModel`): The model to be modified.
         tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
@@ -94,6 +96,12 @@ def setup_chat_format(
         model (`~transformers.PreTrainedModel`): The modified model.
         tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
     """
+    # check if model already had a chat template
+    if tokenizer.chat_template is not None:
+        raise ValueError(
+            "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
+        )
+
     # check if format available and retrieve
     if format not in FORMAT_MAPPING:
         raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
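
Note: the sketch below is not part of the patch; it is a minimal caller-side illustration of the behavior this change introduces. After this patch, `setup_chat_format` raises a `ValueError` when the tokenizer already carries a `chat_template`, and clearing that attribute first is the explicit opt-in for overwriting it. The checkpoint name is a placeholder, not a recommendation.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import setup_chat_format

    # Placeholder checkpoint; substitute the causal LM you actually use.
    model_id = "your-org/your-base-model"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # With this patch, setup_chat_format no longer silently replaces an
    # existing template; clear it first to opt in to overwriting it.
    if tokenizer.chat_template is not None:
        tokenizer.chat_template = None

    # Adds the ChatML special tokens, sets the template, and resizes the
    # model's embedding layer accordingly.
    model, tokenizer = setup_chat_format(model, tokenizer)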