Inconsistent behaviour between LlamaTokenizer and LlamaTokenizerFast when lstrip=True #29626
Closed
Comments
There is also this inconsistency issue. Here is my case:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

model_name_or_path = "mistralai/Mistral-7B-v0.1"
fast_tk = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
    padding_side="left",
)
slow_tk = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=False,
    padding_side="left",
)

new_toks = ["<|im_start|>", "<|im_end|>"]
# Add the new tokens to both the fast and the slow tokenizer.
for tk in (fast_tk, slow_tk):
    tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

test_text = "<|im_start|>Who are you?<|im_end|>"
print(fast_tk(test_text))
print(slow_tk(test_text))

The fast and the slow tokenizer return different encodings for this input.
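Continuing from the snippet above, one quick way to see where the two encodings diverge is to compare the token strings; this inspection code is illustrative, not from the original comment:

fast_ids = fast_tk(test_text)["input_ids"]
slow_ids = slow_tk(test_text)["input_ids"]
# The token-level view makes the whitespace handling around the added
# tokens visible.
print(fast_tk.convert_ids_to_tokens(fast_ids))
print(slow_tk.convert_ids_to_tokens(slow_ids))
print("identical:", fast_ids == slow_ids)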
You need to add the tokens with …
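A minimal sketch of one way to register such tokens on the fast tokenizer; the exact flags the reply intended are not visible here, so normalized=False is an assumption, not the confirmed fix from this thread:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

tk = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
tk.add_tokens(
    # normalized=False is an assumed flag for illustration; lstrip/rstrip
    # follow the snippet earlier in the thread.
    [AddedToken(t, lstrip=True, rstrip=True, normalized=False)
     for t in ["<|im_start|>", "<|im_end|>"]]
)
print(tk("<|im_start|>Who are you?<|im_end|>")["input_ids"])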
This will also be fixed by #28881
Thank you so much!
System Info
Ubuntu
transformers==4.39.0dev
Who can help?
No response
Information

Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Refer to the tokenization test in transformers/tests/test_tokenization_common.py, lines 864 to 881 at commit 9acce7d.

If use_fast=False, the output is: …
Otherwise: …
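A minimal reproduction sketch modeled on that test; the checkpoint (hf-internal-testing/llama-tokenizer) and the token name are assumptions for illustration:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

for use_fast in (False, True):
    tk = AutoTokenizer.from_pretrained(
        "hf-internal-testing/llama-tokenizer", use_fast=use_fast
    )
    tk.add_tokens([AddedToken("<new_tok>", lstrip=True)])
    # With lstrip=True the space before the token should be absorbed by
    # the token match, so both settings should yield the same ids.
    print(use_fast, tk("hello <new_tok>")["input_ids"])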
Expected behavior
It seems that lstrip=True did not take effect. When use_fast=True, the output should be the same as with use_fast=False.
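For reference, a short sketch of the behavior lstrip=True is expected to produce, using the slow tokenizer and the model from the comment above; the printed comparison is illustrative, not output captured from this issue:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

slow_tk = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1", use_fast=False
)
slow_tk.add_tokens([AddedToken("<|im_end|>", lstrip=True, rstrip=True)])

# lstrip=True should make the added token consume whitespace on its left,
# so these two inputs are expected to encode identically.
print(slow_tk("Who are you? <|im_end|>")["input_ids"])
print(slow_tk("Who are you?<|im_end|>")["input_ids"])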