return mask of user messages when calling tokenizer.apply_chat_template(c,tokenize=True) #28950

Closed
yonigottesman opened this issue Feb 10, 2024 · 11 comments
Labels
Feature request Request for a new feature

Comments

@yonigottesman
Contributor

yonigottesman commented Feb 10, 2024

Feature request

When training a chat model, I want to ignore the labels of "user"-generated tokens and compute the loss only on the "assistant" messages. tokenizer.apply_chat_template(c,tokenize=True) should also return a list of 0s and 1s, with 1 marking tokens that come from a "user" message. I can then create the labels for this input by setting all user-generated tokens to -100.

This is similar to the behavior of DataCollatorForCompletionOnlyLM, but that class searches for the instruction_template, which is not easy to find in a multi-message conversation.
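A minimal sketch of the label construction described above, assuming a hypothetical user_mask list (1 for user tokens, 0 otherwise) is returned alongside input_ids. The helper name and the token ids are made up for illustration; -100 is the value that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def labels_from_user_mask(input_ids, user_mask):
    """Copy input_ids, replacing every user-marked position with IGNORE_INDEX."""
    if len(input_ids) != len(user_mask):
        raise ValueError("input_ids and user_mask must have the same length")
    return [IGNORE_INDEX if is_user else tok for tok, is_user in zip(input_ids, user_mask)]


# Made-up token ids: the first four tokens belong to the user turn.
input_ids = [1, 733, 16289, 28793, 22557, 733, 28748, 2]
user_mask = [1, 1, 1, 1, 0, 0, 0, 0]
labels = labels_from_user_mask(input_ids, user_mask)
```

Only the assistant positions keep their token ids, so only they contribute to the loss.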

Motivation

Anyone training a conversational model should probably do this, and it's hard to do together with apply_chat_template. In most cases people manually construct the chat string with -100 labels (see fastchat llama).

Your contribution

If the proposal is accepted, I will work on this and submit a PR.

@yonigottesman changed the title from "return mask of human user messages when calling tokenizer.apply_chat_template(c,tokenize=True)" to "return mask of user messages when calling tokenizer.apply_chat_template(c,tokenize=True)" on Feb 10, 2024
@amyeroberts
Collaborator

cc @Rocketknight1 @ArthurZucker

@amyeroberts added the "Feature request" label on Feb 12, 2024
@Rocketknight1
Member

Hi @yonigottesman - this would be a useful feature, but how do you plan to implement it?

@geronimi73

Hi @yonigottesman - this would be a useful feature, but how do you plan to implement it?

indeed, great feature!

Possible approach: in apply_chat_template, how about looping through the messages and calling compiled_template.render for each message separately? Since we know which messages are user and which are not, we can build a 0/1 mask along the way and return it from apply_chat_template.

@Rocketknight1
Member

@geronimi73 something like that could work, but there are several edge cases! Firstly, some tokenizers introduce additional spaces, in which case the outputs might be slightly different if you loop through messages separately. Secondly, some tokenizers like LLaMA insert the system message into the first user message, which means that we can't safely assume that the ordering of messages in the input list will match the ordering of tokens in the output.

@geronimi73

tokenizers like LLaMA insert the system message into the first user message

From what I've seen, the order of the messages is always preserved. Also, LLaMA inserts <<SYS>> whenever role == "system"; see https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L324
https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-create-a-chat-template

Do you have an example of a jinja template where the order of the messages is changed? I haven't found one (in the few examples I looked at).

in which case the outputs might be slightly different if you loop through messages separately

True, but I think this is exactly how people help themselves right now: loop through the messages, tokenize each message separately, and set the labels according to the role of the message just tokenized. So, while not ideal, it would not be worse than how it is done currently. The advantage would be that the tokenizer takes care of this, instead of everyone coding this error-prone part themselves.
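The per-message loop described above can be sketched roughly as follows. This is an illustration only: toy_tokenize is a hypothetical stand-in for a real tokenizer, build_example is a made-up helper name, and the sketch deliberately leaves out the template's formatting tokens, which is exactly the part that is easy to get wrong by hand.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

vocab = {}


def toy_tokenize(text):
    """Hypothetical stand-in for a real tokenizer: one id per whitespace-separated word."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


def build_example(messages, tokenize):
    """Tokenize each message on its own and keep labels only for assistant messages."""
    input_ids, labels = [], []
    for message in messages:
        ids = tokenize(message["content"])
        input_ids.extend(ids)
        if message["role"] == "assistant":
            labels.extend(ids)
        else:
            labels.extend([IGNORE_INDEX] * len(ids))
    return input_ids, labels


messages = [
    {"role": "user", "content": "hi there"},
    {"role": "assistant", "content": "hello there"},
]
input_ids, labels = build_example(messages, toy_tokenize)
```

With a real tokenizer, concatenating per-message ids may not equal the ids of the full rendered chat, which is the edge case raised earlier in the thread.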

@Rocketknight1
Member

Hi @geronimi73 - the example in the docs you linked is not actually the full LLaMA template! We simplified it for that document. Here's the full LLaMA 2 template, with linebreaks/indentation added. Note that the system message is actually injected into the middle of the first user message!

{% if messages[0]['role'] == 'system' %}
    {% set loop_messages = messages[1:] %}
    {% set system_message = messages[0]['content'] %}
{% else %}
    {% set loop_messages = messages %}
    {% set system_message = false %}
{% endif %}
{% for message in loop_messages %}
    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {% endif %}
    {% if loop.index0 == 0 and system_message != false %}
        {% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
    {% else %}
        {% set content = message['content'] %}
    {% endif %}
    {% if message['role'] == 'user' %}
        {{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
    {% elif message['role'] == 'assistant' %}
        {{ ' '  + content.strip() + ' ' + eos_token }}
    {% endif %}
{% endfor %}
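To see the injection concretely, here is a small sketch that renders a condensed copy of the template above with jinja2. The role-alternation check is omitted, and the messages and the bos/eos strings ("<s>", "</s>") are assumptions chosen for illustration.

```python
from jinja2.sandbox import ImmutableSandboxedEnvironment

# Condensed copy of the LLaMA 2 template quoted above (alternation check omitted).
LLAMA2_TEMPLATE = (
    "{% if messages[0]['role'] == 'system' %}"
    "{% set loop_messages = messages[1:] %}"
    "{% set system_message = messages[0]['content'] %}"
    "{% else %}"
    "{% set loop_messages = messages %}"
    "{% set system_message = false %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if loop.index0 == 0 and system_message != false %}"
    "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
    "{% else %}"
    "{% set content = message['content'] %}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ ' ' + content.strip() + ' ' + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)

messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "Fine."},
]

env = ImmutableSandboxedEnvironment()
rendered = env.from_string(LLAMA2_TEMPLATE).render(
    messages=messages, bos_token="<s>", eos_token="</s>"
)
print(rendered)
```

The system text ends up between [INST] and the user's own text, so the rendered output does not map one-to-one onto the input messages.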

@yonigottesman
Contributor Author

yonigottesman commented Apr 29, 2024

I think a different approach is needed here. I know this is a bit of a hack, but bear with me a second, as I think this is really important :)

We can introduce a new keyword, {% generation %}, into the chat templates, and tokenizers' templates can wrap the assistant-generated part with it. We can start by changing the popular templates, and others will follow.
Here is an example of how this should work. In this example I'm wrapping the assistant-generated parts, but we could instead decide to wrap the user/system parts with a {% train_ignore %} keyword.

Here is how the new chat template of phi3 would look:

chat_template = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if (message['role'] in ['user', 'system']) %}"
    "{{'<|user|>' + '' + message['content'] + '<|end|>' + '' + '<|assistant|>' + ''}}"
    "{% elif message['role'] == 'assistant' %}"
    "{% generation %}"
    "{{message['content'] + '<|end|>' + ''}}"
    "{% endgeneration %}"
    "{% endif %}"
    "{% endfor %}"
)

Here is the jinja2 extension that adds the generation keyword, together with the two variables it needs access to, blocks and generation_indices (probably via a closure):

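from jinja2 import nodes
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment

# Shared state: text chunks rendered so far, and (start, end) character
# offsets of every generation block in the final rendered string.
blocks = []
generation_indices = []


class GenerationTracker(Extension):
    tags = {"generation"}

    def __init__(self, environment):
        super().__init__(environment)

    def parse(self, parser):
        # Consume the "generation" tag, then parse the body up to "endgeneration".
        lineno = next(parser.stream).lineno
        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

    def _generation_support(self, caller):
        # caller() renders the wrapped body; everything rendered before it is already in blocks.
        rv = caller()
        start_index = len("".join(blocks))
        end_index = start_index + len(rv)
        generation_indices.append((start_index, end_index))
        return rv


jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[GenerationTracker])
compiled_template = jinja_env.from_string(chat_template)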

The rendering is done with compiled_template.generate, so I can incrementally append text chunks to the blocks list, and every time a generation block is reached it knows its character offset by computing the length of the string rendered so far with len("".join(blocks)):


for i in compiled_template.generate(messages=messages, **tokenizer.special_tokens_map):
    blocks.append(i)
chat_text = "".join(blocks)

When this is done, generation_indices contains the start and end offsets of all assistant-generated text in chat_text. Now I use the tokenizer's char_to_token to map the characters in the text to token indices:

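# Tokenize the rendered chat. add_special_tokens=False is an assumption here:
# the template already emits the special tokens itself.
tokenized = tokenizer(chat_text, add_special_tokens=False)
generation_mask = [0] * len(tokenized["input_ids"])
for start, end in generation_indices:
    for i in range(start, end):
        token_index = tokenized.char_to_token(i)
        # char_to_token returns None for characters that belong to no token.
        if token_index is not None:
            generation_mask[token_index] = 1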

Now generation_mask contains 1 for all tokens of the text that was wrapped in generation.
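For reference, the whole flow can be condensed into one self-contained sketch. It keeps the state in a closure instead of module-level lists, uses a simplified made-up template rather than the phi3 one, and replaces the real tokenizer with toy_offsets, a hypothetical regex-based stand-in for a fast tokenizer's offset mapping. All helper names here are illustrative assumptions, not part of any library.

```python
import re

from jinja2 import nodes
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment

# Simplified, made-up template: only assistant content is wrapped in generation.
TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "<|assistant|>{% generation %}{{ message['content'] }}<|end|>{% endgeneration %}"
    "{% else %}"
    "<|user|>{{ message['content'] }}<|end|>"
    "{% endif %}"
    "{% endfor %}"
)


def render_with_spans(template, messages):
    """Render the template; return the text and the (start, end) spans of generation blocks."""
    blocks, spans = [], []

    class GenerationTracker(Extension):
        tags = {"generation"}

        def parse(self, parser):
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            call = self.call_method("_generation_support")
            return nodes.CallBlock(call, [], [], body).set_lineno(lineno)

        def _generation_support(self, caller):
            rv = caller()
            # Everything rendered before this block has already been appended to blocks.
            start = len("".join(blocks))
            spans.append((start, start + len(rv)))
            return rv

    env = ImmutableSandboxedEnvironment(extensions=[GenerationTracker])
    for chunk in env.from_string(template).generate(messages=messages):
        blocks.append(chunk)
    return "".join(blocks), spans


def toy_offsets(text):
    """Toy offset mapping: one token per special marker, word, or punctuation mark."""
    return [m.span() for m in re.finditer(r"<\|\w+\|>|\w+|[^\w\s]", text)]


def generation_mask_from_offsets(offsets, spans):
    """Mark a token with 1 if its character range overlaps any generation span."""
    return [
        int(any(start < tok_end and tok_start < end for start, end in spans))
        for tok_start, tok_end in offsets
    ]


messages = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello!"},
]
text, spans = render_with_spans(TEMPLATE, messages)
mask = generation_mask_from_offsets(toy_offsets(text), spans)
```

Working from offsets rather than per-character char_to_token lookups is a design choice of this sketch; either route should yield the same mask when every character in a span maps to a token.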

I know this is not such a trivial solution, but given that we are not going to swap jinja for something else, I think it's not so bad.
I have seen tons of training code examples that just ignore this issue and either train on all tokens or rewrite the tokenization without apply_chat_template, which might lead to inconsistency between training and inference.

@Rocketknight1 what do you think?

P.S. These are the messages I used to check that this code works:

messages = [
    {
        "role": "user",
        "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
    },
    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
    {
        "role": "assistant",
        "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.",
    },
    {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
    {"role": "assistant", "content": "Sure! the answer is bla bla bla"},
]

@Rocketknight1
Member

@yonigottesman this is really cool! We'll definitely have to do some testing and iterate on it, but the way you've gotten the template to track the information we need is really nice.

@lewtun - This should let us keep track of which roles generated which blocks of text in the rendered output. Are there other things you wanted added to chat templates in a similar vein that we might include in this PR?

@xenova - how do you think this would interact with huggingface/jinja? Or minijinja?

@yonigottesman
Contributor Author

Like I said, I am willing to work on this PR.
I would also change trl's ConstantLengthDataset, which already returns labels, to optionally set the labels of user tokens to -100:

yield {
    "input_ids": torch.LongTensor(example),
    "labels": torch.LongTensor(example),
}

but that would be work in that repo...

@Rocketknight1
Member

@yonigottesman we're very happy to have you open this PR! We'll definitely need to check with the maintainers of other libraries to make sure that we can support it there (or at least ignore it without breaking the templates). However, I think opening a PR quickly is a good start - let us know whenever you're ready!

@yonigottesman
Contributor Author

Let's continue this conversation in #30650

4 participants