
Inconsistent behaviour between LlamaTokenizer and LlamaTokenizerFast when lstrip=True #29626

Closed · x54-729 opened this issue Mar 13, 2024 · 5 comments · Fixed by #28881

x54-729 commented Mar 13, 2024

System Info

ubuntu
transformers==4.39.0dev

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Refer to this tokenization test:

special_token = tokenizer.all_special_tokens[0]
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token
toks_before_adding = tokenizer.tokenize(text) # toks before adding new_toks
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
added = tokenizer.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])
self.assertIn(added, [2, 4])
toks_after_adding = tokenizer.tokenize(text)
toks_after_adding2 = tokenizer.tokenize(text2)
self.assertEqual(len(toks_after_adding), len(toks_after_adding2)) # Length should still be the same
self.assertNotEqual(
    toks_after_adding[1], toks_after_adding2[1]
)  # But at least the first non-special tokens should differ

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama2", use_fast=False)

special_token = tokenizer.all_special_tokens[0]

text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

toks_before_adding = tokenizer.tokenize(text)  # toks before adding new_toks

new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]

added = tokenizer.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

toks_after_adding = tokenizer.tokenize(text)
toks_after_adding2 = tokenizer.tokenize(text2)

print(toks_after_adding, toks_after_adding2)

If use_fast=False, the output is:

['<s>', 'aaaaa bbbbbb', '▁low', 'cccccccccdddddddd', '▁l', '', '<s>'] ['<s>', 'AAAAA BBBBBB', '▁low', 'CCCCCCCCCDDDDDDDD', '▁l', '', '<s>']

Otherwise (use_fast=True), the output is:

['<s>', '', '▁aaaaa▁bbbbbb', '▁low', '▁cccccccccdddddddd', '▁l', '', '<s>'] ['<s>', '', '▁AAAAA▁BBBBBB', '▁low', '▁CCCCCCCCCDDDDDDDD', '▁l', '', '<s>']

Expected behavior

It seems that lstrip=True did not take effect. With use_fast=True, the output should be the same as with use_fast=False.
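
A minimal, self-contained sketch of the consistency check this implies (same placeholder path as in the reproduction above; with the current behaviour the assertion fails):

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

slow_tk = AutoTokenizer.from_pretrained("/path/to/llama2", use_fast=False)
fast_tk = AutoTokenizer.from_pretrained("/path/to/llama2", use_fast=True)

special_token = slow_tk.all_special_tokens[0]
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token

new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
for tk in (slow_tk, fast_tk):
    tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

# Expected: both tokenizers agree on how the added tokens are split out.
assert slow_tk.tokenize(text) == fast_tk.tokenize(text)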

amyeroberts (Collaborator) commented

cc @ArthurZucker


Trangle commented Mar 22, 2024

I am also seeing this inconsistency. Here is my case:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

model_name_or_path = "mistralai/Mistral-7B-v0.1"
fast_tk = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
    padding_side="left",
)
slow_tk = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=False,
    padding_side="left",
)

# Add the new tokens to both tokenizers so they can be compared.
new_toks = ["<|im_start|>", "<|im_end|>"]
fast_tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])
slow_tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

test_text = "<|im_start|>Who are you?<|im_end|>"
print(fast_tk(test_text))
print(slow_tk(test_text))

With the fast tokenizer (use_fast=True), the output is:

{'input_ids': [32006, 11447, 460, 368, 28804, 28789, 28766, 321, 28730, 416, 28766, 28767], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

With the slow tokenizer (use_fast=False), the output is:

{'input_ids': [32006, 6526, 460, 368, 28804, 32007], 'attention_mask': [1, 1, 1, 1, 1, 1]}

The result {'input_ids': [32006, 6526, 460, 368, 28804, 32007], 'attention_mask': [1, 1, 1, 1, 1, 1]} is the expected one.
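
One way to make the difference easier to read, continuing the snippet above (a quick sketch; the ids quoted above are as posted and not re-run here):

# Show how each tokenizer actually split the text into tokens.
print(fast_tk.convert_ids_to_tokens(fast_tk(test_text)["input_ids"]))
print(slow_tk.convert_ids_to_tokens(slow_tk(test_text)["input_ids"]))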

ArthurZucker (Collaborator) commented

You need to add the tokens with normalized=False. This is not a bug; it has been filed as an issue many, many times 😉
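
For reference, a minimal sketch of that workaround applied to the Mistral example above (a sketch only; the exact ids are not re-verified here):

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

fast_tk = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)

# normalized=False: match the added tokens against the raw text instead of
# the normalized text, mirroring the slow tokenizer's behaviour.
new_toks = ["<|im_start|>", "<|im_end|>"]
fast_tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True, normalized=False) for tok in new_toks])

print(fast_tk("<|im_start|>Who are you?<|im_end|>"))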

ArthurZucker (Collaborator) commented

This will also be fixed by #28881

x54-729 (Author) commented Mar 25, 2024

This will also be fixed by #28881

Thank you so much!
