Skip to content

Commit

Permalink
Return assistant generated tokens mask in apply_chat_template (#30650)
Browse files Browse the repository at this point in the history
return assistant generated tokens mask in apply_chat_template
  • Loading branch information
yonigottesman authored and itazap committed Jul 25, 2024
1 parent 8eb922c commit ecf597c
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 8 deletions.
120 changes: 112 additions & 8 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,7 @@ def apply_chat_template(
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_dict: bool = False,
return_assistant_tokens_mask: bool = False,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
Expand Down Expand Up @@ -1747,6 +1748,10 @@ def apply_chat_template(
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
return_assistant_tokens_mask (`bool`, defaults to `False`):
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
the mask will contain 1. For user and system tokens, the mask will contain 0.
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
Returns:
Expand All @@ -1761,6 +1766,9 @@ def apply_chat_template(
"of tokenizer outputs to return."
)

if return_assistant_tokens_mask and not return_dict:
raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")

if tokenizer_kwargs is None:
tokenizer_kwargs = {}

Expand Down Expand Up @@ -1813,6 +1821,11 @@ def apply_chat_template(
"then to ensure that this model continues working without issues."
)

if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
logger.warning_once(
"return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)

Expand Down Expand Up @@ -1847,18 +1860,30 @@ def apply_chat_template(
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")

rendered = []
all_generation_indices = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
for chat in conversations:
if hasattr(chat, "messages"):
# Indicates it's a Conversation object
chat = chat.messages
rendered_chat = compiled_template.render(
messages=chat,
tools=tool_schemas,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
)
if return_assistant_tokens_mask:
rendered_chat, generation_indices = self._render_with_assistant_indices(
compiled_template=compiled_template,
messages=chat,
tools=tool_schemas,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
)
all_generation_indices.append(generation_indices)
else:
rendered_chat = compiled_template.render(
messages=chat,
tools=tool_schemas,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
)
rendered.append(rendered_chat)

if not is_batched:
Expand All @@ -1875,17 +1900,54 @@ def apply_chat_template(
**tokenizer_kwargs,
)
if return_dict:
if return_assistant_tokens_mask:
assistant_masks = []
if is_batched or return_tensors:
input_ids = out["input_ids"]
else:
input_ids = [out["input_ids"]]
for i in range(len(input_ids)):
current_mask = [0] * len(input_ids[i])
for assistant_start_char, assistant_end_char in all_generation_indices[i]:
start_token = out.char_to_token(i, assistant_start_char)
end_token = out.char_to_token(i, assistant_end_char - 1)
if start_token is None:
# start_token is out of bounds maybe due to truncation.
break
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids)):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
return out
else:
return out["input_ids"]
else:
return rendered

def _render_with_assistant_indices(
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices

@lru_cache
def _compile_jinja_template(self, chat_template):
try:
import jinja2
from jinja2 import nodes
from jinja2.exceptions import TemplateError
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")
Expand All @@ -1903,7 +1965,49 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None

jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)
Expand Down
4 changes: 4 additions & 0 deletions tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,3 +2483,7 @@ def test_np_encode_plus_sent_to_model(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2436,3 +2436,7 @@ def test_tf_encode_plus_sent_to_model(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/layoutxlm/test_tokenization_layoutxlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1977,3 +1977,7 @@ def test_sentencepiece_tokenize_and_decode(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/markuplm/test_tokenization_markuplm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2316,3 +2316,7 @@ def test_chat_template(self):
@unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
def test_added_tokens_serialization(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/tapas/test_tokenization_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,3 +1277,7 @@ def test_np_encode_plus_sent_to_model(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/udop/test_tokenization_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,10 @@ def test_offsets_mapping(self):
def test_chat_template(self):
pass

@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
def test_chat_template_return_assistant_tokens_mask(self):
pass

@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
Expand Down
129 changes: 129 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,135 @@ def test_chat_template_batched(self):
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised

@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|im_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'] + '<|im_end|>'}}"
"{% endgeneration %}"
"{{'\n'}}"
"{% endif %}"
"{% endfor %}"
)
conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
],
[
{"role": "system", "content": "system message 3"},
{"role": "user", "content": "user message 3"},
{"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
{"role": "user", "content": "user message 4"},
{"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
],
]

# These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
# in the entire chat string, and then find the corresponding tokens in the tokenized output.
assistant_prefix_suffix = [
[("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
[("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
]
for tokenizer, pretrained_name, _ in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
if not self.test_rust_tokenizer:
self.skipTest(reason="No fast tokenizer defined")

tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)

# check batched
output = tokenizer_r.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
for i, conv in enumerate(conversations):
chat_string = tokenizer_r.apply_chat_template(
conversations[i], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
assistant_end = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][0][1])
+ len(assistant_prefix_suffix[i][0][1])
- 1,
)

assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
assistant_end2 = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][1][1])
+ len(assistant_prefix_suffix[i][1][1])
- 1,
)

# assert 1 in first assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
# assert 1 second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)

# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)

# check not batched
output = tokenizer_r.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)

chat_string = tokenizer_r.apply_chat_template(
conversations[0], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
assistant_end = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
)
assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
assistant_end2 = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
)

# assert 1 in assistant indices
self.assertEqual(
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)

# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)

@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
Expand Down

0 comments on commit ecf597c

Please sign in to comment.