
Efficient Replace normalizer #1413

Merged: 7 commits into huggingface:main on Feb 6, 2024
Conversation

@rlrs (Contributor) commented Dec 12, 2023

The existing Replace normalizer, used for example in the Llama and Mistral tokenizers, is implemented very inefficiently.
This results in normalization taking orders of magnitude longer than it should, making it very time consuming to tokenize long sequences. I've seen a few issues that probably refer to this, for example huggingface/transformers#25873.

This PR replaces the existing implementation -- which seems to scale quadratically with sequence length and number of matches -- with an implementation that scales linearly, while (hopefully) retaining the exact same semantics.
In my benchmarks with real long-sequence data, tokenizing with the Llama tokenizer is more than two orders of magnitude faster.
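
As an aside, a minimal Python sketch of the difference (the real change is in the Rust Replace normalizer and also has to keep the NormalizedString offset bookkeeping intact, so this is only a conceptual illustration with hypothetical helper names, not the actual implementation):

import re

def replace_quadratic(text: str, pattern: str, content: str) -> str:
    # Splice the replacement in one match at a time. Every splice copies the
    # remaining tail of the string, so total work grows roughly quadratically
    # with the number of matches.
    out = text
    offset = 0
    for m in re.finditer(pattern, text):
        start, end = m.start() + offset, m.end() + offset
        out = out[:start] + content + out[end:]
        offset += len(content) - (m.end() - m.start())
    return out

def replace_linear(text: str, pattern: str, content: str) -> str:
    # Single pass: collect the untouched slices and the replacements, then
    # join everything once at the end.
    pieces = []
    last = 0
    for m in re.finditer(pattern, text):
        pieces.append(text[last:m.start()])
        pieces.append(content)
        last = m.end()
    pieces.append(text[last:])
    return "".join(pieces)

assert replace_linear("hello world", " ", "▁") == "hello▁world"
assert replace_quadratic("hello world", " ", "▁") == replace_linear("hello world", " ", "▁")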

@ArthurZucker (Collaborator) commented:

I'll have a look, sounds really interesting!

@HuggingFaceDocBuilderDev (bot) commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) commented:

(I am planning on removing the normalizer of the Llama and Mistral families in favor of a pre_tokenizer, but will still check this out!)

@rlrs (Contributor, Author) commented Dec 20, 2023

(I am planning on removing the normalizer of the Llama and Mistral families in favor of a pre_tokenizer, but will still check this out!)

That's fair; until then, this fix will let people tokenize sequences as long as they want.
The code should work with any non-overlapping matches (I'm not sure overlapping matches were ever supported), so it might help with other tokenizers as well, but I don't know if any others use the feature.
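
For context, a small example of the pieces being discussed, using the tokenizers Python API (the Metaspace pre_tokenizer below is an assumption about what the planned pre_tokenizer-based setup might use, not something stated in the thread):

from tokenizers import normalizers, pre_tokenizers

# Today: Llama/Mistral-style fast tokenizers rewrite spaces to "▁" with the
# Replace normalizer, which is the code path this PR speeds up.
replace = normalizers.Replace(" ", "▁")
print(replace.normalize_str("hello world"))  # hello▁world

# A possible pre_tokenizer-based alternative (assumption, not the confirmed
# plan): do the metaspace substitution during pre-tokenization instead.
metaspace = pre_tokenizers.Metaspace()
print(metaspace.pre_tokenize_str("hello world"))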

@rlrs (Contributor, Author) commented Jan 25, 2024

I see this got closed due to inactivity. Do you need anything from me in order to merge this? Profiling, documentation?

@ArthurZucker (Collaborator) commented:

Sorry, I'll take some time to review!

@ArthurZucker (Collaborator) commented:

This will be my priority review!

@Narsil (Collaborator) left a review comment:


LGTM. Sorry I've only just seen this; it seems like quite an important change indeed.

@Narsil (Collaborator) commented Feb 6, 2024

I just got pinged internally on this.

Before: [benchmark plot, prior to this PR]
After: [benchmark plot, with this PR]

from matplotlib import pyplot as plt
import time

import numpy as np
from tqdm import tqdm
from tokenizers import Tokenizer
from transformers import AutoTokenizer

# Sequence lengths (in characters) to benchmark.
lens = np.arange(0, 100000, 100)

with open("data/big.txt") as f:
    TEXT = f.read()

# Mistral with the fast (tokenizers) backend, i.e. the Replace normalizer touched by this PR.
times_fast = []
tokenizer = Tokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
for ll in tqdm(lens):
    text = TEXT[:ll]
    start = time.perf_counter()
    tokenizer.encode(text)
    times_fast.append(time.perf_counter() - start)

# GPT-2 with the fast backend, as a baseline that has no Replace normalizer.
times_gpt2 = []
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
for ll in tqdm(lens):
    text = TEXT[:ll]
    start = time.perf_counter()
    tokenizer.encode(text)
    times_gpt2.append(time.perf_counter() - start)

# Mistral with the slow (sentencepiece) tokenizer, for reference.
times = []
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
for ll in tqdm(lens):
    text = TEXT[:ll]
    start = time.perf_counter()
    tokenizer.encode(text)
    times.append(time.perf_counter() - start)

plt.plot(lens, times_fast)
plt.plot(lens, times_gpt2)
plt.plot(lens, times)
plt.legend(["mistral (tokenizers)", "gpt2 (tokenizers)", "mistral (spm)"])
plt.xlabel("Length in chars")
plt.ylabel("tokenization time (seconds)")
plt.show()

@Narsil merged commit c893204 into huggingface:main on Feb 6, 2024 (12 checks passed).
@ArthurZucker (Collaborator) commented:

Thanks a lot @rlrs! Really sorry about the delay, and shoutout to you for this clean piece of work!
