diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 0a57345e395453..7eb89b42a3ab72 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1692,6 +1692,7 @@ def apply_chat_template(
         max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_dict: bool = False,
+        return_assistant_mask: bool = False,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
@@ -1743,6 +1744,9 @@ def apply_chat_template(
                 "of tokenizer outputs to return."
             )
 
+        if return_assistant_mask and not return_dict:
+            raise ValueError("`return_assistant_mask=True` is incompatible with `return_dict=False`")
+
         if tokenizer_kwargs is None:
             tokenizer_kwargs = {}
 
@@ -1804,15 +1808,20 @@ def apply_chat_template(
             is_batched = False
 
         rendered = []
+        all_generation_indices = []
         template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
         for chat in conversations:
             if hasattr(chat, "messages"):
                 # Indicates it's a Conversation object
                 chat = chat.messages
-            rendered_chat = compiled_template.render(
-                messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
+            rendered_chat, generation_indices = self._render_with_assistant_indices(
+                compiled_template=compiled_template,
+                messages=chat,
+                add_generation_prompt=add_generation_prompt,
+                **template_kwargs,
             )
             rendered.append(rendered_chat)
+            all_generation_indices.append(generation_indices)
 
         if not is_batched:
             rendered = rendered[0]
@@ -1828,17 +1837,45 @@ def apply_chat_template(
                 **tokenizer_kwargs,
             )
             if return_dict:
+                if return_assistant_mask:
+                    assistant_mask = []
+                    if is_batched or return_tensors:
+                        input_ids = out["input_ids"]
+                    else:
+                        input_ids = [out["input_ids"]]
+                    for i in range(len(input_ids)):
+                        current_assistant_mask = [0] * len(input_ids[i])
+                        for assistant_start_char, assistant_end_char in all_generation_indices[i]:
+                            for char_index in range(assistant_start_char, assistant_end_char):
+                                token_index = out.char_to_token(i, char_index)
+                                if token_index is None:
+                                    continue  # can happen on whitespace in BPE tokenizers
+                                current_assistant_mask[token_index] = 1
+                        assistant_mask.append(current_assistant_mask)
+                    out["assistant_mask"] = assistant_mask if is_batched else assistant_mask[0]
                 return out
             else:
                 return out["input_ids"]
         else:
             return rendered
 
+    def _render_with_assistant_indices(self, compiled_template, messages, add_generation_prompt, **template_kwargs):
+        rendered_blocks, generation_indices = compiled_template.environment.new_generation_trackers()
+        for block in compiled_template.generate(
+            messages=messages, add_generation_prompt=add_generation_prompt, **template_kwargs
+        ):
+            rendered_blocks.append(block)
+        rendered_chat = "".join(rendered_blocks)
+        # Copy for safety: the list is mutable and still referenced by the Jinja environment.
+        return rendered_chat, generation_indices.copy()
+
     @lru_cache
     def _compile_jinja_template(self, chat_template):
         try:
             import jinja2
+            from jinja2 import nodes
             from jinja2.exceptions import TemplateError
+            from jinja2.ext import Extension
             from jinja2.sandbox import ImmutableSandboxedEnvironment
         except ImportError:
             raise ImportError("apply_chat_template requires jinja2 to be installed.")
@@ -1851,7 +1888,34 @@ def _compile_jinja_template(self, chat_template):
 
         def raise_exception(message):
             raise TemplateError(message)
 
-        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+        class AssistantTracker(Extension):
+            tags = {"generation"}
+
+            def __init__(self, environment):
+                super().__init__(environment)
+                environment.extend(new_generation_trackers=self.new_generation_trackers)
+                self.rendered_blocks = []
+                self.generation_indices = []
+
+            def new_generation_trackers(self):
+                self.rendered_blocks = []
+                self.generation_indices = []
+                return self.rendered_blocks, self.generation_indices
+
+            def parse(self, parser):
+                lineno = next(parser.stream).lineno
+                body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
+                return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
+
+            @jinja2.pass_eval_context
+            def _generation_support(self, context, caller):
+                rv = caller()
+                start_index = len("".join(self.rendered_blocks))
+                end_index = start_index + len(rv)
+                self.generation_indices.append((start_index, end_index))
+                return rv
+
+        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
         jinja_env.globals["raise_exception"] = raise_exception
         return jinja_env.from_string(chat_template)
@@ -1863,11 +1927,19 @@ def default_chat_template(self):
         """
         return (
             "{% for message in messages %}"
+            "{% if (message['role'] != 'assistant') %}"
             "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
-            "{% endfor %}"
-            "{% if add_generation_prompt %}"
-            "{{ '<|im_start|>assistant\n' }}"
+            "{% elif (message['role'] == 'assistant') %}"
+            "{{'<|im_start|>' + message['role'] + '\n'}}"
+            "{% generation %}"
+            "{{message['content'] + '<|im_end|>'}}"
+            "{% endgeneration %}"
+            "{{'\n'}}"
             "{% endif %}"
+            "{% endfor %}"
+            "{% if add_generation_prompt %}"
+            "{{ '<|im_start|>assistant\n' }}"
+            "{% endif %}"
         )
 
     @classmethod
diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
index ce6bbd0f01f24f..07a715eb078a51 100644
--- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -2483,3 +2483,7 @@ def test_np_encode_plus_sent_to_model(self):
     @unittest.skip("Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template_return_mask(self):
+        pass
diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
index 80d29d3a46b176..da3afcbfe36e25 100644
--- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -2436,3 +2436,7 @@ def test_tf_encode_plus_sent_to_model(self):
     @unittest.skip("Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template_return_mask(self):
+        pass
diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
index 8f1d353efd57f9..677867036cc9e6 100644
--- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1954,3 +1954,7 @@ def test_sentencepiece_tokenize_and_decode(self):
     @unittest.skip("Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template_return_mask(self):
+        pass
diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py
index 370b1c569226d6..2ab94b8ab7124f 100644
--- a/tests/models/markuplm/test_tokenization_markuplm.py
+++ b/tests/models/markuplm/test_tokenization_markuplm.py
@@ -2316,3 +2316,7 @@ def test_chat_template(self):
     @unittest.skip("The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
     def test_added_tokens_serialization(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template_return_mask(self):
+        pass
diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
index 8f2bf9bb69d333..c396fd30b6d3d4 100644
--- a/tests/models/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -1275,3 +1275,7 @@ def test_np_encode_plus_sent_to_model(self):
     @unittest.skip("Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template_return_mask(self):
+        pass
diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py
index d022128ed14756..3f46ea1e5b447c 100644
--- a/tests/models/udop/test_tokenization_udop.py
+++ b/tests/models/udop/test_tokenization_udop.py
@@ -1157,6 +1157,10 @@ def test_offsets_mapping(self):
     def test_chat_template(self):
         pass
 
+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_return_mask(self):
+        pass
+
     @unittest.skip("Chat template tests don't play well with table/layout models.")
     def test_chat_template_batched(self):
         pass
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 76402cd092b6d1..8dba6226d3dade 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -1154,6 +1154,86 @@ def test_chat_template_batched(self):
             dummy_conversations, chat_template=dummy_template, tokenize=True
         )  # Check that no error raised
 
+    @require_jinja
+    def test_chat_template_return_mask(self):
+        template = (
+            "{% for message in messages %}"
+            "{% if (message['role'] != 'assistant') %}"
+            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
+            "{% elif (message['role'] == 'assistant') %}"
+            "{{'<|im_start|>' + message['role'] + '\n'}}"
+            "{% generation %}"
+            "{{message['content'] + '<|im_end|>'}}"
+            "{% endgeneration %}"
+            "{{'\n'}}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        conversations = [
+            [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "assistant message"},
+                {"role": "user", "content": "user message 2"},
+                {"role": "assistant", "content": "assistant message 2"},
+            ],
+            [
+                {"role": "system", "content": "system message 3"},
+                {"role": "user", "content": "user message 3"},
+                {"role": "assistant", "content": "assistant message 3"},
+                {"role": "user", "content": "user message 4"},
+                {"role": "assistant", "content": "assistant message 4"},
+            ],
+        ]
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                if self.test_rust_tokenizer:
+                    tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                        pretrained_name, padding_side="left", **kwargs
+                    )
+
+                    # check batched
+                    output = tokenizer_r.apply_chat_template(
+                        conversations,
+                        chat_template=template,
+                        tokenize=True,
+                        return_assistant_mask=True,
+                        return_dict=True,
+                    )
+                    for i, conv in enumerate(conversations):
+                        labels = [
+                            output["input_ids"][i][index] if mask == 1 else -100
+                            for index, mask in enumerate(output["assistant_mask"][i])
+                        ]
+                        expected = [
+                            tokenizer_r(m["content"] + "<|im_end|>", add_special_tokens=False)["input_ids"]
+                            for m in conv
+                            if m["role"] == "assistant"
+                        ]
+                        expected = expected[0] + expected[1]
+                        self.assertEqual([t for t in labels if t != -100], expected)
+
+                    # check not batched
+                    output = tokenizer_r.apply_chat_template(
+                        conversations[0],
+                        chat_template=template,
+                        tokenize=True,
+                        return_assistant_mask=True,
+                        return_dict=True,
+                    )
+
+                    labels = [
+                        output["input_ids"][index] if mask == 1 else -100
+                        for index, mask in enumerate(output["assistant_mask"])
+                    ]
+                    expected = [
+                        tokenizer_r(m["content"] + "<|im_end|>", add_special_tokens=False)["input_ids"]
+                        for m in conversations[0]
+                        if m["role"] == "assistant"
+                    ]
+                    expected = expected[0] + expected[1]
+                    self.assertEqual([t for t in labels if t != -100], expected)
+
     @require_jinja
     def test_chat_template_dict(self):
         dummy_template_1 = "{{'a'}}"
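
With this patch applied, `return_assistant_mask=True` lets a training script recover exactly which tokens were rendered inside `{% generation %}` blocks and mask everything else out of the loss. The sketch below is illustrative rather than part of the diff: the checkpoint name is a placeholder for any fast (Rust-backed) tokenizer, since the mask is built with `char_to_token` and slow Python tokenizers do not provide it, and -100 is the usual PyTorch cross-entropy ignore index.

    from transformers import AutoTokenizer

    # Placeholder checkpoint: any fast tokenizer works; slow ones lack char_to_token().
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

    # A template that wraps assistant content in {% generation %} blocks, as in the diff.
    template = (
        "{% for message in messages %}"
        "{% if message['role'] != 'assistant' %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% else %}"
        "{{'<|im_start|>' + message['role'] + '\n'}}"
        "{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{{'\n'}}"
        "{% endif %}"
        "{% endfor %}"
    )

    chat = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]

    out = tokenizer.apply_chat_template(
        chat,
        chat_template=template,
        tokenize=True,
        return_dict=True,  # required: return_assistant_mask without return_dict raises ValueError
        return_assistant_mask=True,
    )

    # Keep assistant tokens as labels, ignore everything else in the loss.
    labels = [tok if mask == 1 else -100 for tok, mask in zip(out["input_ids"], out["assistant_mask"])]

In the unbatched case shown here, `out["assistant_mask"]` is a flat list aligned with `out["input_ids"]`; when a list of conversations is passed, both become lists of lists.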
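The span tracking itself is plain Jinja2 and can be exercised outside of transformers. Below is a minimal standalone sketch of the same technique, with hypothetical names (`SpanTracker`, the demo template): the custom `{% generation %}` tag compiles to a `CallBlock`, and because rendering is driven through `generate()` rather than `render()`, the extension can measure how many characters were emitted before each block fires.

    import jinja2
    from jinja2 import nodes
    from jinja2.ext import Extension
    from jinja2.sandbox import ImmutableSandboxedEnvironment


    class SpanTracker(Extension):
        # Registers a {% generation %} ... {% endgeneration %} tag pair.
        tags = {"generation"}

        def __init__(self, environment):
            super().__init__(environment)
            environment.extend(new_generation_trackers=self.new_generation_trackers)
            self.rendered_blocks = []
            self.generation_indices = []

        def new_generation_trackers(self):
            # Reset per render; the caller appends generated chunks into rendered_blocks.
            self.rendered_blocks = []
            self.generation_indices = []
            return self.rendered_blocks, self.generation_indices

        def parse(self, parser):
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

        @jinja2.pass_eval_context
        def _generation_support(self, context, caller):
            rv = caller()  # render the block body to a string
            start = len("".join(self.rendered_blocks))  # characters emitted so far
            self.generation_indices.append((start, start + len(rv)))
            return rv


    env = ImmutableSandboxedEnvironment(extensions=[SpanTracker])
    template = env.from_string("user: hi\n{% generation %}assistant: hello{% endgeneration %}\n")

    blocks, spans = env.new_generation_trackers()
    for chunk in template.generate():
        blocks.append(chunk)  # _generation_support reads this list to compute offsets

    text = "".join(blocks)
    print([text[start:end] for start, end in spans])  # ['assistant: hello']

This is also why `_render_with_assistant_indices` in the patch iterates over `compiled_template.generate(...)` and appends each chunk as it is produced: the recorded offsets are only correct if all previously yielded text has already been collected when a generation block executes.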