Return assistant generated tokens mask in apply_chat_template #30650

Merged
120 changes: 112 additions & 8 deletions src/transformers/tokenization_utils_base.py
@@ -1697,6 +1697,7 @@ def apply_chat_template(
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        return_assistant_tokens_mask: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
@@ -1747,6 +1748,10 @@ def apply_chat_template(
            return_dict (`bool`, defaults to `False`):
                Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
            tokenizer_kwargs (`Dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
            return_assistant_tokens_mask (`bool`, defaults to `False`):
                Whether to return a mask of the assistant-generated tokens. For tokens generated by the assistant,
                the mask will contain 1. For user and system tokens, the mask will contain 0.
                This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
            **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.

        Returns:
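(Editor's note: a minimal usage sketch of the new flag, not part of the diff. The checkpoint name is a placeholder; a fast tokenizer is required, since the mask is built via `char_to_token`. The template is a trimmed ChatML-style example like the one in the tests below.)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("org/some-fast-tokenizer")  # placeholder checkpoint
    template = (
        "{% for message in messages %}"
        "{% if message['role'] != 'assistant' %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n'}}"
        "{% else %}"
        "{{'<|im_start|>assistant\n'}}"
        "{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{{'\n'}}"
        "{% endif %}"
        "{% endfor %}"
    )
    out = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello!"},
        ],
        chat_template=template,
        tokenize=True,
        return_dict=True,  # required when return_assistant_tokens_mask=True
        return_assistant_tokens_mask=True,
    )
    # out["assistant_masks"][t] is 1 for tokens rendered inside {% generation %}, else 0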
@@ -1761,6 +1766,9 @@
"of tokenizer outputs to return."
)

if return_assistant_tokens_mask and not return_dict:
raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")

if tokenizer_kwargs is None:
tokenizer_kwargs = {}

@@ -1813,6 +1821,11 @@
"then to ensure that this model continues working without issues."
)

if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
logger.warning_once(
"return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
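(Editor's note: the `-?` parts of the pattern above also catch Jinja's whitespace-control spellings of the tag. A quick self-contained check of that regex:)

    import re

    pattern = r"\{\%-?\s*generation\s*-?\%\}"
    for tpl in ("{% generation %}", "{%- generation -%}", "{%generation%}"):
        assert re.search(pattern, tpl)  # all three spellings are detected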

@@ -1847,18 +1860,30 @@
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")

        rendered = []
        all_generation_indices = []
        template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Indicates it's a Conversation object
                chat = chat.messages
            rendered_chat = compiled_template.render(
                messages=chat,
                tools=tool_schemas,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                **template_kwargs,
            )
            if return_assistant_tokens_mask:
                rendered_chat, generation_indices = self._render_with_assistant_indices(
                    compiled_template=compiled_template,
                    messages=chat,
                    tools=tool_schemas,
                    documents=documents,
                    add_generation_prompt=add_generation_prompt,
                    **template_kwargs,
                )
                all_generation_indices.append(generation_indices)
            else:
                rendered_chat = compiled_template.render(
                    messages=chat,
                    tools=tool_schemas,
                    documents=documents,
                    add_generation_prompt=add_generation_prompt,
                    **template_kwargs,
                )
            rendered.append(rendered_chat)

        if not is_batched:
@@ -1875,17 +1900,54 @@
                **tokenizer_kwargs,
            )
            if return_dict:
                if return_assistant_tokens_mask:
                    assistant_masks = []
                    if is_batched or return_tensors:
                        input_ids = out["input_ids"]
                    else:
                        input_ids = [out["input_ids"]]
                    for i in range(len(input_ids)):
                        current_mask = [0] * len(input_ids[i])
                        for assistant_start_char, assistant_end_char in all_generation_indices[i]:
                            start_token = out.char_to_token(i, assistant_start_char)
                            end_token = out.char_to_token(i, assistant_end_char - 1)
                            if start_token is None:
                                # start_token is out of bounds, possibly due to truncation
                                break
                            for token_id in range(start_token, end_token + 1 if end_token is not None else len(input_ids[i])):
                                current_mask[token_id] = 1
                        assistant_masks.append(current_mask)
                    out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
                return out
            else:
                return out["input_ids"]
        else:
            return rendered
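(Editor's note: the span-to-mask conversion above, isolated as a sketch. `char_to_token` here is a stand-in for the fast tokenizer's `BatchEncoding.char_to_token`, which maps a character offset in the rendered string to a token index and returns `None` for characters that were truncated away.)

    def spans_to_mask(spans, char_to_token, num_tokens):
        # spans: (start_char, end_char) pairs recorded by the {% generation %} tracker,
        # with end_char exclusive; char_to_token maps a char offset to a token index or None
        mask = [0] * num_tokens
        for start_char, end_char in spans:
            start_token = char_to_token(start_char)
            if start_token is None:  # span lies entirely past the truncation point;
                break                # later spans start even further right, so stop
            end_token = char_to_token(end_char - 1)
            stop = end_token + 1 if end_token is not None else num_tokens
            mask[start_token:stop] = [1] * (stop - start_token)
        return mask

    # toy check with a 1-char-per-token mapping and 8 tokens:
    assert spans_to_mask([(2, 5)], lambda c: c if c < 8 else None, 8) == [0, 0, 1, 1, 1, 0, 0, 0]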

    def _render_with_assistant_indices(
        self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
    ):
        rendered_blocks = []
        generation_indices = []
        with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
            for block in compiled_template.generate(
                messages=messages,
                tools=tools,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                **template_kwargs,
            ):
                rendered_blocks.append(block)
            rendered_chat = "".join(rendered_blocks)
        return rendered_chat, generation_indices
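(Editor's note: `generate()` is used here instead of `render()` because Jinja streams the output chunk by chunk, which lets the tracker keep running character offsets as each `{% generation %}` block is emitted. A minimal illustration of the streaming behaviour:)

    import jinja2

    template = jinja2.Environment().from_string("{{ a }} and {{ b }}")
    offset = 0
    for chunk in template.generate(a="x", b="y"):  # yields the rendered text piece by piece
        print(offset, repr(chunk))
        offset += len(chunk)
    # this running offset is exactly what AssistantTracker uses to record the
    # (start_index, end_index) character span of each {% generation %} block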

    @lru_cache
    def _compile_jinja_template(self, chat_template):
        try:
            import jinja2
            from jinja2 import nodes
            from jinja2.exceptions import TemplateError
            from jinja2.ext import Extension
            from jinja2.sandbox import ImmutableSandboxedEnvironment
        except ImportError:
            raise ImportError("apply_chat_template requires jinja2 to be installed.")
@@ -1903,7 +1965,49 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
            # We also expose some options like custom indents and separators
            return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)

        class AssistantTracker(Extension):
            # This extension is used to track the indices of assistant-generated tokens in the rendered chat
            tags = {"generation"}

            def __init__(self, environment: ImmutableSandboxedEnvironment):
                # The class is only instantiated by jinja.
                super().__init__(environment)
                environment.extend(activate_tracker=self.activate_tracker)
                self._rendered_blocks = None
                self._generation_indices = None

            def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
                lineno = next(parser.stream).lineno
                body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
                return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

            @jinja2.pass_eval_context
            def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
                rv = caller()
                if self.is_active():
                    # Only track generation indices if the tracker is active
                    start_index = len("".join(self._rendered_blocks))
                    end_index = start_index + len(rv)
                    self._generation_indices.append((start_index, end_index))
                return rv

            def is_active(self) -> bool:
                return self._rendered_blocks or self._generation_indices

            @contextmanager
            def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
harupy (Contributor) commented on Jul 23, 2024:

@amyeroberts @yonigottesman

In mlflow/mlflow#12757, we found this line throws in Python 3.8:

https://github.com/mlflow/mlflow/actions/runs/10056412801/job/27795200814?pr=12757#step:12:1016

    class AssistantTracker(Extension):
        # This extension is used to track the indices of assistant-generated tokens in the rendered chat
        tags = {"generation"}
    
        def __init__(self, environment: ImmutableSandboxedEnvironment):
            # The class is only initiated by jinja.
            super().__init__(environment)
            environment.extend(activate_tracker=self.activate_tracker)
            self._rendered_blocks = None
            self._generation_indices = None
    
        def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
    
        @jinja2.pass_eval_context
        def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
            rv = caller()
            if self.is_active():
                # Only track generation indices if the tracker is active
                start_index = len("".join(self._rendered_blocks))
                end_index = start_index + len(rv)
                self._generation_indices.append((start_index, end_index))
            return rv
    
        def is_active(self) -> bool:
            return self._rendered_blocks or self._generation_indices
    
        @contextmanager
>       def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
E       TypeError: 'type' object is not subscriptable

__init__   = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.__init__ at 0x7f013dc78940>
__module__ = 'transformers.tokenization_utils_base'
__qualname__ = 'PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker'
_generation_support = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker._generation_support at 0x7f013dc78790>
is_active  = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.is_active at 0x7f013dc785e0>
parse      = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.parse at 0x7f013dc78820>
tags       = {'generation'}

A follow-up comment proposed:

Suggested change:
-    def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
+    def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):

or `from __future__ import annotations` needs to be added.

amyeroberts (Collaborator) replied on Jul 23, 2024:
Thanks for flagging! Opening a PR to fix

https://github.com/huggingface/transformers/pull/32155/files

                try:
                    if self.is_active():
                        raise ValueError("AssistantTracker should not be reused before closed")
                    self._rendered_blocks = rendered_blocks
                    self._generation_indices = generation_indices

                    yield
                finally:
                    self._rendered_blocks = None
                    self._generation_indices = None

        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
        jinja_env.filters["tojson"] = tojson
        jinja_env.globals["raise_exception"] = raise_exception
        return jinja_env.from_string(chat_template)
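(Editor's note on the Python 3.8 failure flagged in the thread above: subscripting built-in types in annotations, e.g. `list[int]`, only works from Python 3.9 (PEP 585); on 3.8 the annotation is evaluated eagerly when the function is defined and raises `TypeError: 'type' object is not subscriptable`. Both fixes mentioned in the thread are sketched below; the element types are also corrected to what the tracker actually stores, so this is an illustration rather than the exact follow-up diff.)

    # Fix 1: postpone annotation evaluation (PEP 563); must come before any other statement
    from __future__ import annotations

    # Fix 2: use typing generics, which are subscriptable on Python 3.8
    from typing import List, Tuple

    def activate_tracker(rendered_blocks: List[str], generation_indices: List[Tuple[int, int]]):
        ...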
4 changes: 4 additions & 0 deletions tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -2483,3 +2483,7 @@ def test_np_encode_plus_sent_to_model(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -2436,3 +2436,7 @@ def test_tf_encode_plus_sent_to_model(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1977,3 +1977,7 @@ def test_sentencepiece_tokenize_and_decode(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/markuplm/test_tokenization_markuplm.py
@@ -2316,3 +2316,7 @@ def test_chat_template(self):
@unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
def test_added_tokens_serialization(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/tapas/test_tokenization_tapas.py
@@ -1277,3 +1277,7 @@ def test_np_encode_plus_sent_to_model(self):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass

@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/udop/test_tokenization_udop.py
@@ -1157,6 +1157,10 @@ def test_offsets_mapping(self):
    def test_chat_template(self):
        pass

    @unittest.skip(reason="Chat template tests don't play well with table/layout models.")
    def test_chat_template_return_assistant_tokens_mask(self):
        pass

    @unittest.skip(reason="Chat template tests don't play well with table/layout models.")
    def test_chat_template_batched(self):
        pass
129 changes: 129 additions & 0 deletions tests/test_tokenization_common.py
@@ -1153,6 +1153,135 @@ def test_chat_template_batched(self):
            dummy_conversations, chat_template=dummy_template, tokenize=True
        )  # Check that no error raised

    @require_jinja
    def test_chat_template_return_assistant_tokens_mask(self):
        dummy_template = (
            "{% for message in messages %}"
            "{% if (message['role'] != 'assistant') %}"
            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
            "{% elif (message['role'] == 'assistant')%}"
            "{{'<|im_start|>' + message['role'] + '\n'}}"
            "{% generation %}"
            "{{message['content'] + '<|im_end|>'}}"
            "{% endgeneration %}"
            "{{'\n'}}"
            "{% endif %}"
            "{% endfor %}"
        )
        conversations = [
            [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "user message"},
                {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
                {"role": "user", "content": "user message 2"},
                {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
            ],
            [
                {"role": "system", "content": "system message 3"},
                {"role": "user", "content": "user message 3"},
                {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
                {"role": "user", "content": "user message 4"},
                {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
            ],
        ]

        # These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
        # in the entire chat string, and then find the corresponding tokens in the tokenized output.
        assistant_prefix_suffix = [
            [("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
            [("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
        ]
        for tokenizer, pretrained_name, _ in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                if not self.test_rust_tokenizer:
                    self.skipTest(reason="No fast tokenizer defined")

                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)

                # check batched
                output = tokenizer_r.apply_chat_template(
                    conversations,
                    chat_template=dummy_template,
                    tokenize=True,
                    return_assistant_tokens_mask=True,
                    return_dict=True,
                )
                for i, conv in enumerate(conversations):
                    chat_string = tokenizer_r.apply_chat_template(
                        conversations[i], tokenize=False, chat_template=dummy_template
                    )
                    assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
                    assistant_end = output.char_to_token(
                        i,
                        chat_string.index(assistant_prefix_suffix[i][0][1])
                        + len(assistant_prefix_suffix[i][0][1])
                        - 1,
                    )

                    assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
                    assistant_end2 = output.char_to_token(
                        i,
                        chat_string.index(assistant_prefix_suffix[i][1][1])
                        + len(assistant_prefix_suffix[i][1][1])
                        - 1,
                    )

                    # assert 1 in first assistant message
                    self.assertEqual(
                        output["assistant_masks"][i][assistant_start : assistant_end + 1],
                        [1] * (assistant_end - assistant_start + 1),
                    )
                    # assert 1 in second assistant message
                    self.assertEqual(
                        output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
                        [1] * (assistant_end2 - assistant_start2 + 1),
                    )

                    # assert 0 in user/system indices
                    self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
                    self.assertEqual(
                        output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
                        [0] * (assistant_start2 - assistant_end - 1),
                    )

                # check not batched
                output = tokenizer_r.apply_chat_template(
                    conversations[0],
                    chat_template=dummy_template,
                    tokenize=True,
                    return_assistant_tokens_mask=True,
                    return_dict=True,
                )

                chat_string = tokenizer_r.apply_chat_template(
                    conversations[0], tokenize=False, chat_template=dummy_template
                )
                assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
                assistant_end = output.char_to_token(
                    0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
                )
                assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
                assistant_end2 = output.char_to_token(
                    0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
                )

                # assert 1 in assistant indices
                self.assertEqual(
                    output["assistant_masks"][assistant_start : assistant_end + 1],
                    [1] * (assistant_end - assistant_start + 1),
                )
                self.assertEqual(
                    output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
                    [1] * (assistant_end2 - assistant_start2 + 1),
                )

                # assert 0 in user/system indices
                self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
                self.assertEqual(
                    output["assistant_masks"][assistant_end + 1 : assistant_start2],
                    [0] * (assistant_start2 - assistant_end - 1),
                )

    @require_jinja
    def test_chat_template_dict(self):
        dummy_template_1 = "{{'a'}}"