Understanding why TorchInductor cannot speed-up huggingface transformer inference #59
Comments
I think HF llama does not have a static kv cache, since its cache is dynamically increased during generation. Here is the relevant code: https://github.com/huggingface/transformers/blob/38611086d293ea4a5809bcd7fadd8081d55cb74e/src/transformers/models/llama/modeling_llama.py#L1014C37-L1014C37
This should solve the problem 😄
Yes! Static KV cache is not supported yet, but coming soon!
@learning-chip @ArthurZucker
Closing since the core issue in huggingface was a dynamic KV cache, which was made static.
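For reference, a minimal sketch of what the fix looks like on the HF side once a static cache is available. This assumes a transformers release that ships `cache_implementation="static"` (newer than the version used in this issue) and a CUDA GPU; the checkpoint and prompt are placeholders, not taken from the issue.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"   # placeholder checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# Pre-allocate the KV cache so its shape no longer changes during decoding,
# then compile the forward pass.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tok("Hello, my name is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, do_sample=False, max_new_tokens=134)
print(tok.decode(out[0], skip_special_tokens=True))
```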
Problem
`torch.compile()` shows an impressive ~2x speed-up for this code repo, but when applied to huggingface transformers there is barely any speed-up. I want to understand why, and then figure out how TorchInductor can also benefit HF models (related issue #9).

Comparing HF's `model.generate()` vs gpt-fast under the same settings (same prompt, output length, sampling, data type, ...), I found that (on RTX 4090):

- Without `compile()`, HF `generate()` (39.4 token/s) is faster than gpt-fast (28 token/s)
- With `compile()`, HF `generate()` has almost no speed-up (still 39.4 token/s); gpt-fast gets much faster (68.5 token/s)

The blog mentions statically allocating the KV cache, but isn't this also implemented in the HF llama model?
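A minimal toy sketch (not from the issue) of why a cache that grows every decoding step works poorly with `torch.compile(mode="reduce-overhead")`: each new sequence length changes tensor shapes, which can force recompilation or prevent CUDA-graph reuse, so the compiled path gains little.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

@torch.compile(mode="reduce-overhead")
def attend(cache):
    # stand-in for attention over the whole cache
    return torch.softmax(cache @ cache.transpose(-1, -2), dim=-1)

# Toy stand-in for a KV cache that grows by one position per decoding step.
cache = torch.randn(1, 8, 64, device=device)
for _ in range(4):
    out = attend(cache)  # a new cache length each call can trigger re-capture / recompilation
    new_entry = torch.randn(1, 1, 64, device=device)
    cache = torch.cat([cache, new_entry], dim=1)  # dynamic growth, as in HF's llama at the time
```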
Benchmark code
GPT-fast
`--max_new_tokens 134` is to match HF's output length, as this gpt-fast repo continues to generate text even when hitting the end token `</s>`.
HuggingFace
Run the script below. The default sampling settings are the same as this repo's generate.py.
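A rough sketch of what the HF side of this benchmark could look like (the checkpoint, prompt, and sampling values below are assumptions meant to mirror gpt-fast's defaults, not the author's exact script; a CUDA GPU is assumed):

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"   # assumed checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
inputs = tok("Hello, my name is", return_tensors="pt").to("cuda")

def tokens_per_second(m, max_new_tokens=134):
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = m.generate(**inputs, do_sample=True, temperature=0.8, top_k=200,
                     max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    new_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
    return new_tokens / (time.perf_counter() - start)

print("eager    :", tokens_per_second(model), "tok/s")
model.forward = torch.compile(model.forward)   # TorchInductor backend, default mode
tokens_per_second(model)                       # warm-up call so compile time isn't measured
print("compiled :", tokens_per_second(model), "tok/s")
```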
Output results
gpt-fast:
For eager, the output texts are the same as Huggingface's, although the random seed settings differ from the HF script.
With Inductor, the output texts become different (not sure whether due to random seeds or floating-point issues), although still sensible.
Huggingface:
Environment
Torch installed from the nightly cu121 index, which grabs https://download.pytorch.org/whl/nightly/cu121/torch-2.3.0.dev20231217%2Bcu121-cp310-cp310-linux_x86_64.whl
Similar results with torch 2.1.2+cu121 (#46 (comment)).