Commit
return assistant generated tokens mask in apply_chat_template
yonigottesman committed May 5, 2024
1 parent 91d155e commit c14a7f6
Showing 8 changed files with 179 additions and 6 deletions.
81 changes: 75 additions & 6 deletions src/transformers/tokenization_utils_base.py
@@ -1692,6 +1692,7 @@ def apply_chat_template(
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        return_assistant_mask: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
@@ -1743,6 +1744,9 @@ def apply_chat_template(
                "of tokenizer outputs to return."
            )

        if return_assistant_mask and not return_dict:
            raise ValueError("`return_assistant_mask=True` is incompatible with `return_dict=False`")

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
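
For context, a minimal usage sketch of the new flag (the checkpoint name is hypothetical; any tokenizer whose chat template wraps assistant turns in `{% generation %}` blocks would do):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-chat-model")  # hypothetical checkpoint
chat = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]
out = tokenizer.apply_chat_template(
    chat, tokenize=True, return_dict=True, return_assistant_mask=True
)
# out["assistant_mask"] has one entry per token: 1 for assistant-generated tokens, else 0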

@@ -1804,15 +1808,20 @@ def apply_chat_template(
            is_batched = False

        rendered = []
        all_generation_indices = []
        template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Indicates it's a Conversation object
                chat = chat.messages
            rendered_chat, generation_indices = self._render_with_assistant_indices(
                compiled_template=compiled_template,
                messages=chat,
                add_generation_prompt=add_generation_prompt,
                **template_kwargs,
            )
            rendered.append(rendered_chat)
            all_generation_indices.append(generation_indices)

        if not is_batched:
            rendered = rendered[0]
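
For intuition, each entry of `all_generation_indices` is a list of `(start, end)` character offsets into the rendered chat string, one pair per `{% generation %}` block. A toy illustration with a ChatML-style template (exact offsets are template-dependent):

rendered = "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n"
generation_indices = [(53, 69)]  # span covering the assistant's "Hello!<|im_end|>"
assert rendered[53:69] == "Hello!<|im_end|>"
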
@@ -1828,17 +1837,45 @@ def apply_chat_template(
                **tokenizer_kwargs,
            )
            if return_dict:
                if return_assistant_mask:
                    assistant_mask = []
                    if is_batched or return_tensors:
                        input_ids = out["input_ids"]
                    else:
                        input_ids = [out["input_ids"]]
                    for i in range(len(input_ids)):
                        current_assistant_mask = [0] * len(input_ids[i])
                        for assistant_start_char, assistant_end_char in all_generation_indices[i]:
                            for char_index in range(assistant_start_char, assistant_end_char):
                                token_index = out.char_to_token(i, char_index)
                                if token_index is None:
                                    continue  # will happen on spaces in BPE tokenizers
                                current_assistant_mask[token_index] = 1
                        assistant_mask.append(current_assistant_mask)
                    out["assistant_mask"] = assistant_mask if is_batched else assistant_mask[0]
                return out
            else:
                return out["input_ids"]
        else:
            return rendered
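
Note that `BatchEncoding.char_to_token` is only available on fast (Rust-backed) tokenizers, which is why the new test below only exercises `tokenizer_r`. A common downstream use of the mask — a sketch, not part of this diff — is hiding non-assistant tokens from the training loss:

out = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, return_assistant_mask=True)
labels = [
    tok if mask else -100  # -100 is ignored by PyTorch's CrossEntropyLoss
    for tok, mask in zip(out["input_ids"], out["assistant_mask"])
]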

    def _render_with_assistant_indices(self, compiled_template, messages, add_generation_prompt, **template_kwargs):
        rendered_blocks, generation_indices = compiled_template.environment.new_generation_trackers()
        for block in compiled_template.generate(
            messages=messages, add_generation_prompt=add_generation_prompt, **template_kwargs
        ):
            rendered_blocks.append(block)
        rendered_chat = "".join(rendered_blocks)
        # Copy for safety, as the list is mutable and still referenced in the environment
        return rendered_chat, generation_indices.copy()

    @lru_cache
    def _compile_jinja_template(self, chat_template):
        try:
            import jinja2
            from jinja2 import nodes
            from jinja2.exceptions import TemplateError
            from jinja2.ext import Extension
            from jinja2.sandbox import ImmutableSandboxedEnvironment
        except ImportError:
            raise ImportError("apply_chat_template requires jinja2 to be installed.")
@@ -1851,7 +1888,34 @@ def _compile_jinja_template(self, chat_template):
        def raise_exception(message):
            raise TemplateError(message)

        class AssistantTracker(Extension):
            # Defines the {% generation %} ... {% endgeneration %} tag and records the
            # start/end character indices of everything rendered inside it.
            tags = {"generation"}

            def __init__(self, environment):
                super().__init__(environment)
                environment.extend(new_generation_trackers=self.new_generation_trackers)
                self.rendered_blocks = []
                self.generation_indices = []

            def new_generation_trackers(self):
                self.rendered_blocks = []
                self.generation_indices = []
                return self.rendered_blocks, self.generation_indices

            def parse(self, parser):
                lineno = next(parser.stream).lineno
                body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
                return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

            @jinja2.pass_eval_context
            def _generation_support(self, context, caller):
                rv = caller()
                start_index = len("".join(self.rendered_blocks))
                end_index = start_index + len(rv)
                self.generation_indices.append((start_index, end_index))
                return rv

        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
        jinja_env.globals["raise_exception"] = raise_exception
        return jinja_env.from_string(chat_template)
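
To see the extension mechanism in isolation, here is a stripped-down, runnable sketch (plain jinja2, no transformers) of the same idea: a tag extension whose CallBlock wrapper observes what each `{% generation %}` block renders. The `Tracker` class and `tracker_spans` attribute are illustrative names, not part of this diff:

import jinja2
from jinja2 import nodes
from jinja2.ext import Extension

class Tracker(Extension):
    tags = {"generation"}

    def __init__(self, environment):
        super().__init__(environment)
        self.spans = []
        environment.extend(tracker_spans=self.spans)  # expose the collected spans on the env

    def parse(self, parser):
        lineno = next(parser.stream).lineno
        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        return nodes.CallBlock(self.call_method("_track"), [], [], body).set_lineno(lineno)

    def _track(self, caller):
        rv = caller()  # render the block body
        self.spans.append(rv)
        return rv

env = jinja2.Environment(extensions=[Tracker])
print(env.from_string("A{% generation %}B{% endgeneration %}C").render())  # -> ABC
print(env.tracker_spans)  # -> ['B']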

@@ -1863,11 +1927,16 @@ def default_chat_template(self):
        """
        return (
            "{% for message in messages %}"
            "{% if (message['role'] != 'assistant') %}"
            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
            "{% elif (message['role'] == 'assistant')%}"
            "{{'<|im_start|>' + message['role'] + '\n'}}"
            "{% generation %}"
            "{{message['content'] + '<|im_end|>'}}"
            "{% endgeneration %}"
            "{{'\n'}}"
            "{% endif %}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "{{ '<|im_start|>assistant\n' }}"
            "{% endif %}"
        )
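
Rendered against a short chat, this template emits standard ChatML; the only text inside the generation block — and therefore the only text whose tokens receive mask value 1 — is each assistant message plus its closing `<|im_end|>`:

<|im_start|>user
Hi!<|im_end|>
<|im_start|>assistant
Hello!<|im_end|>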

    @classmethod
4 changes: 4 additions & 0 deletions tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -2483,3 +2483,7 @@ def test_np_encode_plus_sent_to_model(self):
    @unittest.skip("Chat is not supported")
    def test_chat_template(self):
        pass

    @unittest.skip("Chat is not supported")
    def test_chat_template_return_mask(self):
        pass
4 changes: 4 additions & 0 deletions tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -2436,3 +2436,7 @@ def test_tf_encode_plus_sent_to_model(self):
    @unittest.skip("Chat is not supported")
    def test_chat_template(self):
        pass

    @unittest.skip("Chat is not supported")
    def test_chat_template_return_mask(self):
        pass
4 changes: 4 additions & 0 deletions tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1954,3 +1954,7 @@ def test_sentencepiece_tokenize_and_decode(self):
    @unittest.skip("Chat is not supported")
    def test_chat_template(self):
        pass

    @unittest.skip("Chat is not supported")
    def test_chat_template_return_mask(self):
        pass
4 changes: 4 additions & 0 deletions tests/models/markuplm/test_tokenization_markuplm.py
@@ -2316,3 +2316,7 @@ def test_chat_template(self):
    @unittest.skip("The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
    def test_added_tokens_serialization(self):
        pass

    @unittest.skip("Chat is not supported")
    def test_chat_template_return_mask(self):
        pass
4 changes: 4 additions & 0 deletions tests/models/tapas/test_tokenization_tapas.py
@@ -1275,3 +1275,7 @@ def test_np_encode_plus_sent_to_model(self):
    @unittest.skip("Chat is not supported")
    def test_chat_template(self):
        pass

    @unittest.skip("Chat is not supported")
    def test_chat_template_return_mask(self):
        pass
4 changes: 4 additions & 0 deletions tests/models/udop/test_tokenization_udop.py
@@ -1157,6 +1157,10 @@ def test_offsets_mapping(self):
    def test_chat_template(self):
        pass

    @unittest.skip("Chat template tests don't play well with table/layout models.")
    def test_chat_template_return_mask(self):
        pass

    @unittest.skip("Chat template tests don't play well with table/layout models.")
    def test_chat_template_batched(self):
        pass
80 changes: 80 additions & 0 deletions tests/test_tokenization_common.py
@@ -1154,6 +1154,86 @@ def test_chat_template_batched(self):
            dummy_conversations, chat_template=dummy_template, tokenize=True
        )  # Check that no error raised

    @require_jinja
    def test_chat_template_return_mask(self):
        template = (
            "{% for message in messages %}"
            "{% if (message['role'] != 'assistant') %}"
            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
            "{% elif (message['role'] == 'assistant')%}"
            "{{'<|im_start|>' + message['role'] + '\n'}}"
            "{% generation %}"
            "{{message['content'] + '<|im_end|>'}}"
            "{% endgeneration %}"
            "{{'\n'}}"
            "{% endif %}"
            "{% endfor %}"
        )
        conversations = [
            [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "user message"},
                {"role": "assistant", "content": "assistant message"},
                {"role": "user", "content": "user message 2"},
                {"role": "assistant", "content": "assistant message 2"},
            ],
            [
                {"role": "system", "content": "system message 3"},
                {"role": "user", "content": "user message 3"},
                {"role": "assistant", "content": "assistant message 3"},
                {"role": "user", "content": "user message 4"},
                {"role": "assistant", "content": "assistant message 4"},
            ],
        ]
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                if self.test_rust_tokenizer:
                    tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                        pretrained_name, padding_side="left", **kwargs
                    )

                    # check batched
                    output = tokenizer_r.apply_chat_template(
                        conversations,
                        chat_template=template,
                        tokenize=True,
                        return_assistant_mask=True,
                        return_dict=True,
                    )
                    for i, conv in enumerate(conversations):
                        labels = [
                            output["input_ids"][i][index] if mask == 1 else -100
                            for index, mask in enumerate(output["assistant_mask"][i])
                        ]
                        expected = [
                            tokenizer_r(m["content"] + "<|im_end|>", add_special_tokens=False)["input_ids"]
                            for m in conv
                            if m["role"] == "assistant"
                        ]
                        expected = expected[0] + expected[1]
                        self.assertEqual([t for t in labels if t != -100], expected)

                    # check not batched
                    output = tokenizer_r.apply_chat_template(
                        conversations[0],
                        chat_template=template,
                        tokenize=True,
                        return_assistant_mask=True,
                        return_dict=True,
                    )

                    labels = [
                        output["input_ids"][index] if mask == 1 else -100
                        for index, mask in enumerate(output["assistant_mask"])
                    ]
                    expected = [
                        tokenizer_r(m["content"] + "<|im_end|>", add_special_tokens=False)["input_ids"]
                        for m in conversations[0]
                        if m["role"] == "assistant"
                    ]
                    expected = expected[0] + expected[1]
                    self.assertEqual([t for t in labels if t != -100], expected)

    @require_jinja
    def test_chat_template_dict(self):
        dummy_template_1 = "{{'a'}}"
