Normalizer "replace" is quadratic in sequence length (impacts Llama 2 tokenizer) #1449
Comments
@ArthurZucker Thanks for following up. Is there any estimate of when #1413 is going to be merged?
I'd recommend using huggingface/transformers#28881, but it's not polished yet.
Either that, or I'll fix the normalizer for the special-tokens cases.
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.
Looks like it's fixed! With tokenizers 0.15.2 and transformers 4.38.2:
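For reference, here is a minimal sketch of the kind of re-check being described (not the snippet from this comment; it assumes access to the gated meta-llama/Llama-2-7b-hf checkpoint). With the fixed versions, the timings should grow roughly linearly with input length:

```python
# Hypothetical re-check, not the original poster's snippet: time the Llama 2
# tokenizer on increasingly long inputs and confirm roughly linear growth.
import timeit

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
assert tokenizer.is_fast

text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 20_000

for n_chars in [100_000, 200_000, 400_000, 800_000]:
    chunk = text[:n_chars]
    elapsed = timeit.timeit(lambda: tokenizer(chunk, truncation=False), number=3)
    print(f"{n_chars}\t{elapsed:.2f}s")
```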
Thanks!
(Meaning the Llama 2 tokenizer is quadratic in sequence length.)
len(str) vs. time to tokenize:
GPT-2's tokenizer is roughly linear, as you would expect: [plot omitted]
But Llama 2 is quadratic: [plot omitted]
(I killed the script before it finished, but you get the idea.)
The culprit is the normalizer (found by copying it across tokenizers with `tokenizer_gpt2.backend_tokenizer.normalizer = tokenizer.backend_tokenizer.normalizer`). If we give GPT-2 just the `Replace` normalizer, which is one of the normalizers in the Llama tokenizer, we get the same quadratic behavior:

`tokenizer_gpt2.backend_tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(" ", "_")])`
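To isolate the normalizer from the rest of the pipeline, one can also time `Replace` on its own through `normalize_str`. The following is a rough sketch, not part of the original report; on affected tokenizers versions the timings should grow faster than linearly:

```python
# Sketch (not from the original report): time the Replace normalizer by itself,
# with no model or pre-tokenizer involved, to show how normalization scales.
import timeit

from tokenizers import normalizers

replace = normalizers.Replace(" ", "_")
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 2_000

for n_chars in [10_000, 20_000, 40_000, 80_000]:
    chunk = text[:n_chars]
    # normalize_str applies the normalizer to a plain Python string.
    elapsed = timeit.timeit(lambda: replace.normalize_str(chunk), number=5)
    print(f"{n_chars}\t{elapsed:.3f}s")
```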
Script
```python
import json
import timeit

from transformers import AutoTokenizer
from tokenizers import pre_tokenizers, normalizers

lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."""

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer_gpt2 = AutoTokenizer.from_pretrained("gpt2")

# Experiments for isolating the culprit (un-comment one at a time; as written,
# the script times the stock Llama 2 tokenizer):
# tokenizer_gpt2.backend_tokenizer.normalizer = normalizers.Sequence([normalizers.Prepend("_")])
# tokenizer_gpt2.backend_tokenizer.normalizer = tokenizer.backend_tokenizer.normalizer
# tokenizer.backend_tokenizer.pre_tokenizer = tokenizer_gpt2.backend_tokenizer.pre_tokenizer
# tokenizer.backend_tokenizer.post_processor = tokenizer_gpt2.backend_tokenizer.post_processor
# tokenizer.backend_tokenizer.normalizer = tokenizer_gpt2.backend_tokenizer.normalizer
# delattr(tokenizer.backend_tokenizer, "normalizer")  # attempt to drop the normalizer entirely

assert tokenizer.is_fast

MAX_LEN = 10_000_000
data = [lorem * (MAX_LEN // len(lorem))]

# Print the average and total length of the documents being timed.
total_len = 0
for d in data:
    total_len += len(d)
print(total_len / len(data))
print(total_len)

# Time tokenization of progressively longer prefixes of each document.
for chars_to_tokenize in [10_000, 100_000, 200_000, 300_000, 400_000, 500_000, 800_000, 1_000_000, 10_000_000]:
    data_to_tokenize = [d[:chars_to_tokenize] for d in data]
    out = timeit.timeit(lambda: tokenizer(data_to_tokenize, return_tensors="np", max_length=None, truncation=False).input_ids.shape, number=5)
    print(f"{chars_to_tokenize}\t{out}")
```