
Normalizer "replace" is quadratic in sequence length (impacts Llama 2 tokenizer) #1449

Closed
dlwh opened this issue Feb 6, 2024 · 6 comments

dlwh commented Feb 6, 2024

(Meaning the Llama 2 tokenizer is quadratic in sequence length.)

`len(str)` vs. time to tokenize (seconds, total over 5 timeit runs):

GPT-2's tokenizer is roughly linear as you would expect:

gpt2:
```
10000	0.012285999953746796
100000	0.13678820803761482
200000	0.2882207091897726
300000	0.4590047914534807
400000	0.626229832880199
500000	0.8339609587565064
800000	1.3014380000531673
1000000	1.728166500106454
10000000	20.250024332664907
```

But Llama 2 is quadratic:

meta-llama/Llama-2-7b-hf
```
10000	0.019068040885031223
100000	1.3336323332041502
200000	5.238070583902299
300000	11.626898417249322
400000	21.934583541937172
```

(I killed the script at that point, but you get the idea.)
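A quick way to read these tables: divide time by input length. For a linear tokenizer, seconds per character stays roughly flat; for a quadratic one, it grows in proportion to the input size. A small helper (hypothetical, not part of the original script) applied to the Llama 2 numbers above:

```python
# Hypothetical helper: turn the (chars, seconds) pairs above into
# microseconds per character. Roughly constant => linear scaling;
# growing in proportion to input size => quadratic.
def per_char(rows):
    for chars, secs in rows:
        print(f"{chars}\t{secs / chars * 1e6:.2f} us/char")

# Llama 2 numbers from above: us/char grows roughly linearly with input size.
per_char([(10_000, 0.0191), (100_000, 1.334), (200_000, 5.238),
          (300_000, 11.627), (400_000, 21.935)])
```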

The culprit is the normalizer (verified by copying it onto GPT-2: `tokenizer_gpt2.backend_tokenizer.normalizer = tokenizer.backend_tokenizer.normalizer`):

gpt2 with the Llama normalizer:

```
10000	0.03563724923878908
100000	1.5297077922150493
200000	5.390813915990293
300000	12.077063124626875
```

If we give GPT-2 just the `Replace` normalizer (one of the normalizers in the Llama tokenizer's sequence), we get the same quadratic behavior:

```python
tokenizer_gpt2.backend_tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(" ", "▁")])
```

```
10000	0.022305666469037533
100000	1.3800452090799809
200000	5.384156417101622
300000	12.147614290937781
400000	22.582430874928832
500000	34.3297392912209
```
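For intuition about where the quadratic cost can come from, here is a pure-Python sketch of the failure mode (not the actual tokenizers internals): patching each match into the string in place copies O(n) characters per edit, so n matches cost O(n²) total, while a single-pass rebuild stays O(n).

```python
# Sketch only: why per-match in-place edits go quadratic.
# Assumes `old` is non-empty.
def replace_in_place(s: str, old: str, new: str) -> str:
    i = 0
    while True:
        i = s.find(old, i)
        if i == -1:
            return s
        # Rebuilding the string for every match copies O(len(s)) characters,
        # so n matches cost O(n^2) overall.
        s = s[:i] + new + s[i + len(old):]
        i += len(new)

def replace_single_pass(s: str, old: str, new: str) -> str:
    # One pass over the input, O(n) overall (equivalent to s.replace).
    return new.join(s.split(old))
```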
<details>
<summary>Script</summary>

```python
import timeit

from transformers import AutoTokenizer
from tokenizers import normalizers

lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."""

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer_gpt2 = AutoTokenizer.from_pretrained("gpt2")

# Variants used for the experiments above -- enable one at a time:
# tokenizer_gpt2.backend_tokenizer.normalizer = normalizers.Sequence([normalizers.Prepend("▁")])
# tokenizer_gpt2.backend_tokenizer.normalizer = tokenizer.backend_tokenizer.normalizer
# tokenizer_gpt2.backend_tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(" ", "▁")])
# tokenizer.backend_tokenizer.pre_tokenizer = tokenizer_gpt2.backend_tokenizer.pre_tokenizer
# tokenizer.backend_tokenizer.post_processor = tokenizer_gpt2.backend_tokenizer.post_processor
# tokenizer.backend_tokenizer.normalizer = tokenizer_gpt2.backend_tokenizer.normalizer
# delattr(tokenizer.backend_tokenizer, "normalizer")  # attempt to remove the normalizer entirely

assert tokenizer.is_fast

MAX_LEN = 10_000_000

data = [lorem * (MAX_LEN // len(lorem))]

total_len = 0
for d in data:
    total_len += len(d)

print(total_len / len(data))
print(total_len)

for chars_to_tokenize in [10_000, 100_000, 200_000, 300_000, 400_000, 500_000, 800_000, 1_000_000, 10_000_000]:
    data_to_tokenize = [d[:chars_to_tokenize] for d in data]
    out = timeit.timeit(
        lambda: tokenizer(data_to_tokenize, return_tensors="np", max_length=None, truncation=False).input_ids.shape,
        number=5,
    )
    print(f"{chars_to_tokenize}\t{out}")
```

</details>
ArthurZucker (Collaborator) commented

Thanks, #1413 should fix it, and #1357 as well. The normalizer should not be used for Llama.
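A stopgap in that spirit, for anyone hitting this before the fixes land (a sketch, not the official fix): drop the Llama normalizer and let a `Metaspace` pre-tokenizer handle the space-to-`▁` mapping instead. Verify that the resulting token IDs match your expectations before relying on it.

```python
from tokenizers import normalizers, pre_tokenizers
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Replace the quadratic Prepend/Replace normalizer sequence with a no-op;
# some tokenizers versions don't accept `None` here, so an empty Sequence
# is the portable way to disable normalization.
tok.backend_tokenizer.normalizer = normalizers.Sequence([])

# Let Metaspace do the " " -> "▁" mapping during pre-tokenization instead
# (its defaults use "▁" as the replacement character).
tok.backend_tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
```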

Ivan-Zhou commented

@ArthurZucker Thanks for following up. Is there any estimate of when #1413 will be merged?

ArthurZucker (Collaborator) commented

I'd recommend using huggingface/transformers#28881, but it's not polished yet.

ArthurZucker (Collaborator) commented

Either that, or I'll fix the normalizer for the special-tokens cases.

github-actions bot commented Mar 8, 2024

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.


dlwh commented Mar 8, 2024

Looks like it's fixed! With tokenizers 0.15.2 and transformers 4.38.2:

```
10000	0.010763333993963897
100000	0.11263033305294812
200000	0.2581357080489397
300000	0.4427691249875352
400000	0.5360061249812134
500000	0.7575962499831803
800000	1.2085898750228807
1000000	1.7549052499816753
10000000	23.24868266703561
```

Thanks!
