From 30d8ae999f40e0eb84b6afcf61a560ebeeb451e1 Mon Sep 17 00:00:00 2001
From: Yoni Gottesman
Date: Mon, 22 Jul 2024 20:24:43 +0300
Subject: [PATCH] Return assistant generated tokens mask in apply_chat_template
 (#30650)

return assistant generated tokens mask in apply_chat_template
---
 src/transformers/tokenization_utils_base.py   | 120 ++++++++++++++--
 .../test_tokenization_layoutlmv2.py           |   4 +
 .../test_tokenization_layoutlmv3.py           |   4 +
 .../layoutxlm/test_tokenization_layoutxlm.py  |   4 +
 .../markuplm/test_tokenization_markuplm.py    |   4 +
 tests/models/tapas/test_tokenization_tapas.py |   4 +
 tests/models/udop/test_tokenization_udop.py   |   4 +
 tests/test_tokenization_common.py             | 129 ++++++++++++++++++
 8 files changed, 265 insertions(+), 8 deletions(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 6d2e7f502e0089..434eaa2fac8ba2 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1697,6 +1697,7 @@ def apply_chat_template(
         max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_dict: bool = False,
+        return_assistant_tokens_mask: bool = False,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
@@ -1747,6 +1748,10 @@ def apply_chat_template(
             return_dict (`bool`, defaults to `False`):
                 Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
             tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
+            return_assistant_tokens_mask (`bool`, defaults to `False`):
+                Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
+                the mask will contain 1. For user and system tokens, the mask will contain 0.
+                This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
             **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.

         Returns:
@@ -1761,6 +1766,9 @@ def apply_chat_template(
                 "of tokenizer outputs to return."
             )

+        if return_assistant_tokens_mask and not return_dict:
+            raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
+
         if tokenizer_kwargs is None:
             tokenizer_kwargs = {}

@@ -1813,6 +1821,11 @@ def apply_chat_template(
                 "then to ensure that this model continues working without issues."
             )

+        if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
+            logger.warning_once(
+                "`return_assistant_tokens_mask=True` but the chat template does not contain the `{% generation %}` keyword."
+            )
+
         # Compilation function uses a cache to avoid recompiling the same template
         compiled_template = self._compile_jinja_template(chat_template)

@@ -1847,18 +1860,30 @@ def apply_chat_template(
                     raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")

         rendered = []
+        all_generation_indices = []
         template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
         for chat in conversations:
             if hasattr(chat, "messages"):
                 # Indicates it's a Conversation object
                 chat = chat.messages
-            rendered_chat = compiled_template.render(
-                messages=chat,
-                tools=tool_schemas,
-                documents=documents,
-                add_generation_prompt=add_generation_prompt,
-                **template_kwargs,
-            )
+            if return_assistant_tokens_mask:
+                rendered_chat, generation_indices = self._render_with_assistant_indices(
+                    compiled_template=compiled_template,
+                    messages=chat,
+                    tools=tool_schemas,
+                    documents=documents,
+                    add_generation_prompt=add_generation_prompt,
+                    **template_kwargs,
+                )
+                all_generation_indices.append(generation_indices)
+            else:
+                rendered_chat = compiled_template.render(
+                    messages=chat,
+                    tools=tool_schemas,
+                    documents=documents,
+                    add_generation_prompt=add_generation_prompt,
+                    **template_kwargs,
+                )
             rendered.append(rendered_chat)

         if not is_batched:
@@ -1875,17 +1900,54 @@ def apply_chat_template(
                 **tokenizer_kwargs,
             )
             if return_dict:
+                if return_assistant_tokens_mask:
+                    assistant_masks = []
+                    if is_batched or return_tensors:
+                        input_ids = out["input_ids"]
+                    else:
+                        input_ids = [out["input_ids"]]
+                    for i in range(len(input_ids)):
+                        current_mask = [0] * len(input_ids[i])
+                        for assistant_start_char, assistant_end_char in all_generation_indices[i]:
+                            start_token = out.char_to_token(i, assistant_start_char)
+                            end_token = out.char_to_token(i, assistant_end_char - 1)
+                            if start_token is None:
+                                # start_token may be out of bounds due to truncation.
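+                                # Spans are tracked in render order, so every later span was truncated away as well.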
+                                break
+                            for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
+                                current_mask[token_id] = 1
+                        assistant_masks.append(current_mask)
+                    out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
                 return out
             else:
                 return out["input_ids"]
         else:
             return rendered

+    def _render_with_assistant_indices(
+        self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
+    ):
+        rendered_blocks = []
+        generation_indices = []
+        with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
+            for block in compiled_template.generate(
+                messages=messages,
+                tools=tools,
+                documents=documents,
+                add_generation_prompt=add_generation_prompt,
+                **template_kwargs,
+            ):
+                rendered_blocks.append(block)
+            rendered_chat = "".join(rendered_blocks)
+        return rendered_chat, generation_indices
+
     @lru_cache
     def _compile_jinja_template(self, chat_template):
         try:
             import jinja2
+            from jinja2 import nodes
             from jinja2.exceptions import TemplateError
+            from jinja2.ext import Extension
             from jinja2.sandbox import ImmutableSandboxedEnvironment
         except ImportError:
             raise ImportError("apply_chat_template requires jinja2 to be installed.")
@@ -1903,7 +1965,49 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
             # We also expose some options like custom indents and separators
             return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

-        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+        class AssistantTracker(Extension):
+            # This extension is used to track the indices of assistant-generated tokens in the rendered chat
+            tags = {"generation"}
+
+            def __init__(self, environment: ImmutableSandboxedEnvironment):
+                # The class is only instantiated by jinja.
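+                # (jinja constructs the instance when the extension is registered on the environment below)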
+                super().__init__(environment)
+                environment.extend(activate_tracker=self.activate_tracker)
+                self._rendered_blocks = None
+                self._generation_indices = None
+
+            def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
+                lineno = next(parser.stream).lineno
+                body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
+                return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
+
+            @jinja2.pass_eval_context
+            def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
+                rv = caller()
+                if self.is_active():
+                    # Only track generation indices if the tracker is active
+                    start_index = len("".join(self._rendered_blocks))
+                    end_index = start_index + len(rv)
+                    self._generation_indices.append((start_index, end_index))
+                return rv
+
+            def is_active(self) -> bool:
+                return self._rendered_blocks or self._generation_indices
+
+            @contextmanager
+            def activate_tracker(self, rendered_blocks: list[str], generation_indices: list[tuple[int, int]]):
+                try:
+                    if self.is_active():
+                        raise ValueError("AssistantTracker should not be reused before it is closed")
+                    self._rendered_blocks = rendered_blocks
+                    self._generation_indices = generation_indices
+
+                    yield
+                finally:
+                    self._rendered_blocks = None
+                    self._generation_indices = None
+
+        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
         jinja_env.filters["tojson"] = tojson
         jinja_env.globals["raise_exception"] = raise_exception
         return jinja_env.from_string(chat_template)
diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
index 0dbeef0c4176c7..bb526e140e5740 100644
--- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -2483,3 +2483,7 @@ def test_np_encode_plus_sent_to_model(self):
     @unittest.skip(reason="Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip(reason="Chat is not supported")
+    def test_chat_template_return_assistant_tokens_mask(self):
+        pass
diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
index e478e0ac62cb5c..5ea384f0b26422 100644
--- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -2436,3 +2436,7 @@ def test_tf_encode_plus_sent_to_model(self):
     @unittest.skip(reason="Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip(reason="Chat is not supported")
+    def test_chat_template_return_assistant_tokens_mask(self):
+        pass
diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
index 2f8b19a662ab73..c0e44fcb30491f 100644
--- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1977,3 +1977,7 @@ def test_sentencepiece_tokenize_and_decode(self):
     @unittest.skip(reason="Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip(reason="Chat is not supported")
+    def test_chat_template_return_assistant_tokens_mask(self):
+        pass
diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py
index b2c0d20bdb2434..458df94ec2fbcc 100644
--- a/tests/models/markuplm/test_tokenization_markuplm.py
+++ b/tests/models/markuplm/test_tokenization_markuplm.py
@@ -2316,3 +2316,7 @@ def test_chat_template(self):
     @unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
     def test_added_tokens_serialization(self):
         pass
+
+    @unittest.skip(reason="Chat is not supported")
+    def test_chat_template_return_assistant_tokens_mask(self):
+        pass
diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
index 8fe65438d5cac7..a9b8e9a0c77fa6 100644
--- a/tests/models/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -1277,3 +1277,7 @@ def test_np_encode_plus_sent_to_model(self):
     @unittest.skip(reason="Chat is not supported")
     def test_chat_template(self):
         pass
+
+    @unittest.skip(reason="Chat is not supported")
+    def test_chat_template_return_assistant_tokens_mask(self):
+        pass
diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py
index 151695c1c126fc..78153172f2c729 100644
--- a/tests/models/udop/test_tokenization_udop.py
+++ b/tests/models/udop/test_tokenization_udop.py
@@ -1157,6 +1157,10 @@ def test_offsets_mapping(self):
     def test_chat_template(self):
         pass

+    @unittest.skip(reason="Chat template tests don't play well with table/layout models.")
+    def test_chat_template_return_assistant_tokens_mask(self):
+        pass
+
     @unittest.skip(reason="Chat template tests don't play well with table/layout models.")
     def test_chat_template_batched(self):
         pass
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 867ca859ebc109..a1fb5124a457f2 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -1153,6 +1153,135 @@ def test_chat_template_batched(self):
             dummy_conversations, chat_template=dummy_template, tokenize=True
         )  # Check that no error raised

+    @require_jinja
+    def test_chat_template_return_assistant_tokens_mask(self):
+        dummy_template = (
+            "{% for message in messages %}"
+            "{% if (message['role'] != 'assistant') %}"
+            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
+            "{% elif (message['role'] == 'assistant')%}"
+            "{{'<|im_start|>' + message['role'] + '\n'}}"
+            "{% generation %}"
+            "{{message['content'] + '<|im_end|>'}}"
+            "{% endgeneration %}"
+            "{{'\n'}}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        conversations = [
+            [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
+                {"role": "user", "content": "user message 2"},
+                {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
+            ],
+            [
+                {"role": "system", "content": "system message 3"},
+                {"role": "user", "content": "user message 3"},
+                {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
+                {"role": "user", "content": "user message 4"},
+                {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
+            ],
+        ]
+
+        # These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
+        # in the entire chat string, and then find the corresponding tokens in the tokenized output.
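+        # Each inner list corresponds to one conversation above, with one (prefix, suffix) pair per assistant turn.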
+        assistant_prefix_suffix = [
+            [("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
+            [("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
+        ]
+        for tokenizer, pretrained_name, _ in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                if not self.test_rust_tokenizer:
+                    self.skipTest(reason="No fast tokenizer defined")
+
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
+
+                # check batched
+                output = tokenizer_r.apply_chat_template(
+                    conversations,
+                    chat_template=dummy_template,
+                    tokenize=True,
+                    return_assistant_tokens_mask=True,
+                    return_dict=True,
+                )
+                for i, conv in enumerate(conversations):
+                    chat_string = tokenizer_r.apply_chat_template(
+                        conversations[i], tokenize=False, chat_template=dummy_template
+                    )
+                    assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
+                    assistant_end = output.char_to_token(
+                        i,
+                        chat_string.index(assistant_prefix_suffix[i][0][1])
+                        + len(assistant_prefix_suffix[i][0][1])
+                        - 1,
+                    )
+
+                    assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
+                    assistant_end2 = output.char_to_token(
+                        i,
+                        chat_string.index(assistant_prefix_suffix[i][1][1])
+                        + len(assistant_prefix_suffix[i][1][1])
+                        - 1,
+                    )
+
+                    # assert 1 in first assistant message
+                    self.assertEqual(
+                        output["assistant_masks"][i][assistant_start : assistant_end + 1],
+                        [1] * (assistant_end - assistant_start + 1),
+                    )
+                    # assert 1 in second assistant message
+                    self.assertEqual(
+                        output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
+                        [1] * (assistant_end2 - assistant_start2 + 1),
+                    )
+
+                    # assert 0 in user/system indices
+                    self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
+                    self.assertEqual(
+                        output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
+                        [0] * (assistant_start2 - assistant_end - 1),
+                    )
+
+                # check not batched
+                output = tokenizer_r.apply_chat_template(
+                    conversations[0],
+                    chat_template=dummy_template,
+                    tokenize=True,
+                    return_assistant_tokens_mask=True,
+                    return_dict=True,
+                )
+
+                chat_string = tokenizer_r.apply_chat_template(
+                    conversations[0], tokenize=False, chat_template=dummy_template
+                )
+                assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
+                assistant_end = output.char_to_token(
+                    0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
+                )
+                assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
+                assistant_end2 = output.char_to_token(
+                    0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
+                )
+
+                # assert 1 in assistant indices
+                self.assertEqual(
+                    output["assistant_masks"][assistant_start : assistant_end + 1],
+                    [1] * (assistant_end - assistant_start + 1),
+                )
+                self.assertEqual(
+                    output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
+                    [1] * (assistant_end2 - assistant_start2 + 1),
+                )
+
+                # assert 0 in user/system indices
+                self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
+                self.assertEqual(
+                    output["assistant_masks"][assistant_end + 1 : assistant_start2],
+                    [0] * (assistant_start2 - assistant_end - 1),
+                )
+
     @require_jinja
     def test_chat_template_dict(self):
         dummy_template_1 = "{{'a'}}"