
apply_chat_template method not working correctly for llama 3 tokenizer #33091

Closed
2 of 4 tasks
sirluk opened this issue Aug 23, 2024 · 10 comments

Comments

@sirluk

sirluk commented Aug 23, 2024

System Info

  • transformers version: 4.44.1
  • Platform: Linux-4.18.0-553.8.1.el8_10.x86_64-x86_64-with-glibc2.28
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker
I noticed that apply_chat_template in the PreTrainedTokenizerBase class does not work correctly when return_assistant_tokens_mask=True. For each example we would expect to get back a mask with one entry per token, where 1 indicates that the token is part of an assistant message and 0 otherwise. This works for the Llama 2 tokenizer, for example, but not for the Llama 3 tokenizer. Below is a minimal example that reproduces the issue.

Looking deeper into the apply_chat_template method, the issue seems to come from the char_to_token method of the tokenizers.Encoding class. It could be related to the fact that the Llama 3 tokenizer was built with tiktoken rather than sentencepiece.
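To make the suspected failure mode concrete, here is a minimal pure-Python sketch of a character-to-token lookup over an offset mapping (one (start, end) pair per token). It assumes a lookup succeeds only when some token's half-open span [start, end) contains the character; it illustrates the mechanism and is not the tokenizers implementation, and the offsets are made up. If a tokenizer reports offsets that do not cover a character (zero-length spans, for example), the lookup returns None, which is the symptom reported here.

```python
from typing import List, Optional, Tuple

Offsets = List[Tuple[int, int]]


def char_to_token(offsets: Offsets, char_pos: int) -> Optional[int]:
    """Return the index of the token whose [start, end) span contains char_pos, else None."""
    for token_index, (start, end) in enumerate(offsets):
        if start <= char_pos < end:
            return token_index
    return None


# Hypothetical offsets for the text "### assistant:\nHi"
good_offsets = [(0, 3), (3, 13), (13, 15), (15, 17)]
# Same tokens, but every span has zero length (start == end), so nothing is covered
broken_offsets = [(start, start) for start, _ in good_offsets]

print(char_to_token(good_offsets, 15))    # 3
print(char_to_token(broken_offsets, 15))  # None
```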

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer
from datasets import load_dataset

dataset_name = "m-a-p/Code-Feedback"

model_name = "meta-llama/Meta-Llama-3.1-8B" # apply_chat_template does not work correctly
#model_name = "meta-llama/Llama-2-7b-hf" # apply_chat_template works correctly

chat_template = """{% if messages[0]['role'] == 'system' %}
    {% set offset = 1 %}
{% else %}
    {% set offset = 0 %}
{% endif %}

{% for message in messages %}
    {% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {% endif %}

    {{ '### ' + message['role'] + ':\n'}}
    {% if (message['role'] == 'assistant') %}
        {% generation %} {{ message['content'] | trim + eos_token }} {% endgeneration %}
    {% else %}
        {{ message['content'] | trim + eos_token }}
    {% endif %}

{% endfor %}

{% if add_generation_prompt %}
    {{ '### ' + 'assistant' + ':\n' }}
{% endif %}"""

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.chat_template = chat_template
datasets = load_dataset(dataset_name, trust_remote_code=True)

# assistant_mask is all zeros for llama3 tokenizer
chat = tokenizer.apply_chat_template(
    datasets["train"][0]["messages"],
    add_generation_prompt=False,
    return_dict=True,
    tokenize=True,
    return_assistant_tokens_mask=True
)
print("assistant_masks", chat["assistant_masks"])
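The {% generation %} block is what lets apply_chat_template record where assistant text starts and ends in the rendered string. As a rough illustration (not the transformers implementation), the hypothetical helper render_with_assistant_spans below builds the same "### role:\n" layout, without the template's extra whitespace, and returns (start, end) character spans for each assistant message.

```python
from typing import Dict, List, Tuple


def render_with_assistant_spans(
    messages: List[Dict[str, str]], eos_token: str
) -> Tuple[str, List[Tuple[int, int]]]:
    """Render '### role:\\n' + content + eos per message and record assistant character spans."""
    parts: List[str] = []
    spans: List[Tuple[int, int]] = []
    length = 0
    for message in messages:
        header = "### " + message["role"] + ":\n"
        body = message["content"].strip() + eos_token
        parts.append(header)
        length += len(header)
        if message["role"] == "assistant":
            # Only the body (content + EOS) counts as assistant text, like {% generation %}
            spans.append((length, length + len(body)))
        parts.append(body)
        length += len(body)
    return "".join(parts), spans


chat, spans = render_with_assistant_spans(
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    eos_token="</s>",
)
print(spans)                             # [(31, 41)]
print([chat[s:e] for s, e in spans])     # ['Hello!</s>']
```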

Running the steps that apply_chat_template uses to build the assistant mask shows that the char_to_token method of the tokenizers.Encoding class does not seem to work correctly: it returns None for the first assistant character.

compiled_template = tokenizer._compile_jinja_template(chat_template)
template_kwargs = {**tokenizer.special_tokens_map}
rendered_chat, generation_indices = tokenizer._render_with_assistant_indices(
    compiled_template=compiled_template,
    messages=datasets["train"][0]["messages"],
    tools=[],
    documents=None,
    add_generation_prompt=False,
    **template_kwargs
)
out = tokenizer(
    rendered_chat,
    padding=False,
    truncation=False,
    max_length=None,
    add_special_tokens=False,
    return_tensors=None
)
first_assistant_start_char, first_assistant_end_char = generation_indices[0]
# returns None for llama3; BatchEncoding.char_to_token takes (batch_index, char_index)
print("char_to_token", out.char_to_token(0, first_assistant_start_char))
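For context, the mask is built by mapping each assistant character span to token indices through char_to_token. The function below is a simplified sketch of that step (an approximation for illustration, not the library source); lookup stands in for out.char_to_token. If the lookup returns None for a span's start, the span is skipped, so a tokenizer whose lookups always fail produces an all-zero mask, matching the reported output.

```python
from typing import Callable, List, Optional, Tuple


def assistant_mask(
    num_tokens: int,
    spans: List[Tuple[int, int]],
    lookup: Callable[[int], Optional[int]],
) -> List[int]:
    """Mark the tokens covering each assistant [start, end) character span with 1."""
    mask = [0] * num_tokens
    for start_char, end_char in spans:
        start_token = lookup(start_char)
        end_token = lookup(end_char - 1)  # last character inside the span
        if start_token is None:
            continue  # no token found for this span, so it stays 0
        if end_token is None:
            end_token = num_tokens - 1
        for token_index in range(start_token, end_token + 1):
            mask[token_index] = 1
    return mask


def working_lookup(char_pos: int) -> Optional[int]:
    # Four 4-character tokens: chars 0-3 -> 0, 4-7 -> 1, 8-11 -> 2, 12-15 -> 3
    return char_pos // 4 if 0 <= char_pos < 16 else None


def broken_lookup(char_pos: int) -> Optional[int]:
    # Simulates a tokenizer whose char_to_token always fails
    return None


print(assistant_mask(4, [(8, 16)], working_lookup))  # [0, 0, 1, 1]
print(assistant_mask(4, [(8, 16)], broken_lookup))   # [0, 0, 0, 0]
```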

Expected behavior

If we assume that the tokenized chat is 10 tokens long and the assistant tokens are at positions 4-6 and 8-9 (1-based), the expected output looks like this:
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0]
The actual output for the Llama 3 tokenizer is always all 0s:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
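Since the mask has one entry per token, the example reads as a 10-token chat with assistant tokens at 1-based, inclusive positions 4-6 and 8-9, i.e. 0-based indices 3-5 and 7-8. A small sketch that derives the expected mask (the helper name is just for illustration):

```python
from typing import List, Tuple


def mask_from_token_ranges(num_tokens: int, ranges: List[Tuple[int, int]]) -> List[int]:
    """Build a 0/1 mask from 1-based, inclusive (first, last) token ranges."""
    mask = [0] * num_tokens
    for first, last in ranges:
        for position in range(first, last + 1):
            mask[position - 1] = 1  # convert 1-based position to 0-based index
    return mask


print(mask_from_token_ranges(10, [(4, 6), (8, 9)]))
# [0, 0, 0, 1, 1, 1, 0, 1, 1, 0]
```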

@sirluk sirluk added the bug label Aug 23, 2024
@ArthurZucker
Collaborator

cc @Rocketknight1 our chat template expert

@Rocketknight1
Member

cc @yonigottesman who wrote the original PR at #30650, do you have any idea what the issue here could be? If you don't have time to investigate this right now, let me know and I'll take over.

@yonigottesman
Contributor

This is related to #30650 (comment) and #1620.
If I get time, I'll try to find a workaround for this tokenizer issue.

@Rocketknight1
Member

Got it. I wouldn't put a workaround in the template itself, because you'd need to remove it again once the underlying tokenizers issue is fixed.

@sirluk
Author

sirluk commented Sep 8, 2024

For anyone struggling with the same issue at the moment, I created a temporary workaround for my use case:

import re

import torch
from transformers import AutoTokenizer


class TokenizerCodeFeedbackHacky:

    PROMPT = (
        "Instruction:\nGiven a multi-turn dialogue related to a coding task, your role is to generate the assistant's next response."
        "\n\nDialogue:\n"
    )
    CHAT_TEMPLATE_PATH = "chat_template.jinja"

    def __init__(self, tokenizer_path):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Tokenize the instruction prompt once; it is prepended to every example
        self.tokenized_prompt = self.tokenizer(self.PROMPT, add_special_tokens=False, return_attention_mask=False)["input_ids"]
        self.tokenized_prompt_len = len(self.tokenized_prompt)

        with open(self.CHAT_TEMPLATE_PATH) as f:
            self.tokenizer.chat_template = f.read()
        # Header the chat template emits before each message
        self.chat_template_header = "### {role}:\n"

    def __call__(self, examples):
        chats = self.tokenizer.apply_chat_template(
            examples["messages"],
            add_generation_prompt=False,
            return_dict=False,
            tokenize=False
        )
        chats_tokenized = self.tokenizer(chats, add_special_tokens=False, return_attention_mask=False, return_length=True, return_offsets_mapping=True)
        assistant_mask = []
        for i in range(len(chats)):
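            # Start character offset of every token (only start offsets are used)
            s, _ = zip(*chats_tokenized[i].offsets)
            s = torch.tensor(s)
            # Assistant text begins right after each "### assistant:\n" header ...
            assistant_starts = [x.end()+1 for x in re.finditer(self.chat_template_header.format(role="assistant"), chats[i])]
            # ... and ends just before the next "### user:\n" header (assumes the chat starts with a user turn)
            assistant_ends = [x.start()-1 for x in re.finditer(self.chat_template_header.format(role="user"), chats[i])]
            assistant_ends = assistant_ends[1:] + [len(chats[i])]
            assistant_start_ids, assistant_end_ids = [], []
            for start, end in zip(assistant_starts, assistant_ends):
                # First token starting after the character, minus one = token containing that character
                assistant_start_ids.append((s > start).long().argmax().item() - 1)
                assistant_end_ids.append((s > end).long().argmax().item() - 1)
            # The last assistant turn runs to the end of the chat
            assistant_end_ids = assistant_end_ids[:-1] + [chats_tokenized["length"][i]-1]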
            mask = [0] * chats_tokenized["length"][i]
            for start_id, end_id in zip(assistant_start_ids, assistant_end_ids):
                mask[start_id:end_id] = [1] * (end_id-start_id)
            assistant_mask.append(mask)
        input_ids = [self.tokenized_prompt + x for x in chats_tokenized["input_ids"]]
        assistant_mask = [[0] * self.tokenized_prompt_len + x for x in assistant_mask]
        input_length = [x + self.tokenized_prompt_len for x in chats_tokenized["length"]]
        return {"input_ids": input_ids, "assistant_mask": assistant_mask, 'input_length': input_length}
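The key step in this workaround is (s > start).long().argmax().item() - 1, which finds the last token whose start offset is at or before a character position, using only start offsets. The same lookup can be sketched with the standard library's bisect module; this is an equivalent reformulation for illustration, not part of the original workaround. One difference: for a position inside the final token, no start offset exceeds it, so the argmax expression yields 0 - 1 = -1, whereas the bisect version returns the last token's index.

```python
import bisect
from typing import List


def token_index_at(starts: List[int], char_pos: int) -> int:
    """Index of the last token whose start offset is <= char_pos (-1 if none).

    `starts` must be sorted ascending, as token start offsets normally are.
    """
    return bisect.bisect_right(starts, char_pos) - 1


starts = [0, 3, 13, 15]  # hypothetical start offsets for "### assistant:\nHi"
print(token_index_at(starts, 14))  # 2
print(token_index_at(starts, 16))  # 3
```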


github-actions bot commented Oct 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker
Collaborator

ArthurZucker commented Oct 3, 2024

Closing as the issue was fixed by #1640!

@0seba

0seba commented Oct 3, 2024

Hi @ArthurZucker, could you point to the commit/PR in which it was fixed?

@ArthurZucker
Collaborator

Oops, sorry, it's a fix in tokenizers.

@ArthurZucker
Collaborator

Updated huggingface/tokenizers#1640
