
Fast tokenizer breaks added tokens #27132

Closed · geronimi73 opened this issue Oct 29, 2023 · 7 comments · Fixed by #27313

Comments

@geronimi73

System Info

  • transformers version: 4.35.0.dev0
  • Platform: Linux-6.2.0-35-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.25.0.dev0
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("models/llama2-7b", use_fast=True)

# add tokens for chatml
tokenizer.add_tokens(["<|im_start|>"])
tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})

messages = [ {"role": "user", "content": "question"},
  {"role": "assistant", "content": "answer"} ]

# https://huggingface.co/docs/transformers/main/chat_templating
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

chat = tokenizer.apply_chat_template(messages, tokenize=False)
chat_tokenized = tokenizer(chat, add_special_tokens=False)["input_ids"]

for token in chat_tokenized:
	print(f"{token} - \"{tokenizer.decode(token)}\"")

Output: the first occurrence of <|im_start|> is tokenized correctly, but the second one is split:

32000 - "<|im_start|>"
1792 - "user"
13 - "
"
12470 - "question"
32001 - "<|im_end|>"
29871 - ""
13 - "
"
29966 - "<"
29989 - "|"
326 - "im"
29918 - "_"
2962 - "start"
29989 - "|"
29958 - ">"
465 - "ass"
22137 - "istant"
13 - "
"
12011 - "answer"
32001 - "<|im_end|>"
29871 - ""
13 - "
"

Expected behavior

32000 - "<|im_start|>"
1404 - "user"
13 - "<0x0A>"
12470 - "question"
32001 - "<|im_end|>"
29871 - ""
13 - "<0x0A>"
32000 - "<|im_start|>"
20255 - "assistant"
13 - "<0x0A>"
12011 - "answer"
32001 - "<|im_end|>"
29871 - ""
13 - "<0x0A>"

This is the correct output, produced by the slow tokenizer: AutoTokenizer.from_pretrained("models/llama2-7b", use_fast=False).

  1. Why does this happen with the fast tokenizer but not the slow one?
  2. Is there any other solution than not using the fast tokenizer?

I guess this is already known; sorry if I missed it in the existing issues.

@amyeroberts
Collaborator

Hi @geronimi73, thanks for raising an issue!

@ArthurZucker is off this week and is the main person who knows and works on the tokenizers, so you might have to wait until he's back for a full answer.

@Rocketknight1 any chance you know what's happening?

@Rocketknight1
Member

Hi @geronimi73, I'll wait for @ArthurZucker to return to give a full answer here, but in the meantime I think the issue is that when you add a normal token, the tokenizer may split it. If you want to preserve an important control token like <|im_start|> you should make it a special token. Try doing this instead:

tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_start|>"]})
tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})

@ArthurZucker
Collaborator

Well, it's partly true, partly wrong 😅
When you add a token, if it is not special, it will be normalized by default. I'll add the add_tokens function back to the doc; it seems it was removed. But anyway, the Llama normalizer adds a SPIECE_UNDERLINE at the beginning of the added token's content, which thus becomes a different token. AddedTokens (special or not) should never be split, but the content of the added tokens is affected by the normalizer.
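
A small sketch of the difference this makes (model path as in the repro; exact ids depend on the checkpoint):

from transformers import AutoTokenizer, AddedToken

text = "question\n<|im_start|>assistant"

# default: the added token's content goes through the Llama normalizer
tok_default = AutoTokenizer.from_pretrained("models/llama2-7b")
tok_default.add_tokens(["<|im_start|>"])
print(tok_default(text, add_special_tokens=False)["input_ids"])  # <|im_start|> may be split here

# normalized=False: the token is matched against the raw text and stays intact
tok_raw = AutoTokenizer.from_pretrained("models/llama2-7b")
tok_raw.add_tokens(AddedToken("<|im_start|>", normalized=False))
print(tok_raw(text, add_special_tokens=False)["input_ids"])  # a single id for <|im_start|>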

@geronimi73
Author

ok, thanks!

@geronimi73
Author

it's somehow working now.

just to sum this up for others who are struggling with this too:

  • I raised this issue because the fast tokenizer breaks the ChatML tag <|im_start|> into several tokens even though it was added with tokenizer.add_tokens(["<|im_start|>"]); the slow tokenizer works fine.
  • As @ArthurZucker explains above, the Llama normalizer adds a SPIECE_UNDERLINE; indeed, the fast tokenizer encodes <|im_start|> correctly when the token is added with
tokenizer.add_tokens(
	AddedToken("<|im_start|>", normalized=False)
)
  • But there's a new problem: decoding now adds a space after added tokens. Example:
from transformers import AutoTokenizer, AddedToken

tokenizer = AutoTokenizer.from_pretrained("../models/llama2-7b", use_fast=True, legacy=False)

tokenizer.add_tokens(
	AddedToken("<|im_start|>",normalized=False, rstrip=True, lstrip=False)
)
tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})

# https://huggingface.co/docs/transformers/main/chat_templating
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

messages=[
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Nice to meet you!"}
]

chat = tokenizer.apply_chat_template(messages, tokenize=False)
chat_tokenized = tokenizer(chat, add_special_tokens=False)["input_ids"]

print("INPUT")
print(chat)
print("-"*30)
print("DECODE(ENCODE(INPUT))")
print(tokenizer.decode(chat_tokenized))

# INPUT
# <|im_start|>user
# Hi there!<|im_end|>
# <|im_start|>assistant
# Nice to meet you!<|im_end|>

# ------------------------------
# DECODE(ENCODE(INPUT))
# <|im_start|> user
# Hi there!<|im_end|> 
# <|im_start|> assistant
# Nice to meet you!<|im_end|> 
  • Fix for all of the above: use the slow tokenizer (use_fast=False, legacy=False), add tokens with tokenizer.add_tokens(["<|im_start|>"]), and decode with spaces_between_special_tokens=False, like this (full runnable sketch after this list):
tokenizer = AutoTokenizer.from_pretrained("../models/llama2-7b", use_fast=False, legacy=False)
tokenizer.add_tokens(["<|im_start|>"])
...
chat_tokenized = tokenizer(chat, add_special_tokens=False)["input_ids"]
print(tokenizer.decode(chat_tokenized, spaces_between_special_tokens=False))
  • using transformers 4.35.0 btw
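
For completeness, a runnable version of that last fix, following the snippets above (the template below is a simplified version of the ChatML template from the repro, without the add_generation_prompt branch):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("../models/llama2-7b", use_fast=False, legacy=False)
tokenizer.add_tokens(["<|im_start|>"])
tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"

messages = [
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Nice to meet you!"},
]

chat = tokenizer.apply_chat_template(messages, tokenize=False)
chat_tokenized = tokenizer(chat, add_special_tokens=False)["input_ids"]

# spaces_between_special_tokens=False avoids the extra space after added tokens when decoding
print(tokenizer.decode(chat_tokenized, spaces_between_special_tokens=False))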

@ArthurZucker
Collaborator

ArthurZucker commented Nov 13, 2023

Thanks for the great explanation!
Regarding the space added after added tokens, this PR will fix it: huggingface/tokenizers#1357 😉 I'll have to change the Llama paradigm a little bit to make sure it's compatible

@ArthurZucker
Collaborator

feel free to play with #26678 as well 🤗
