-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Core generation
] Adds support for static KV cache
#27931
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Core genration
] Adds support for static KV cacheCore generation
] Adds support for static KV cache
If I understand correctly, this PR should close the existing gap between inference with transformers + AutoGPTQ and inference with ExLlama, as the VRAM usage would become much more controlled. I'm rooting for it :) |
Thanks! 🤗 |
Exciting PR! |
Hi @ArthurZucker It seems that the increase in VRAM could potentially lead to out of memory (OOM) comment1 comment2, as pointed out in this PR by @danielhanchen
Could you please take a look into it? |
Yes! And @paulcx I'm sorry this broke for you |
Does this work for llava? from my testing it doesnt |
No, you have to update the |
We can also just remove it, but then you need to allocate the causal_mask at each forward pass |
@ArthurZucker If I understand correctly, my thought is to lower the max_position_embedding, such as 4096, because it only affects the initial position embedding and caching for 200K. But during inference, lengths approaching 200K will still be calculated, just slower. This workaround, it can ensure normal training and inference for non-200K cases. Is my understanding correct? |
Thank you and great work on new release @ArthurZucker. Would you mind clarifying the use case of "static / compile cache" in release note? I'm not sure if I understand correctly. |
It is mostly this: https://gist.github.com/ArthurZucker/af34221def212259b43d55a2811d2dbb, you can get x4 generation speed in transformers with torch compile and static cache! |
Is this expected to work with llava-next? |
I believe so yes, if not we can add support for it |
Feel free to open an issue if it doesnt work |
I have tried and it dont work because the vision tower changes the shape of inputs after encoding to patches. Also, it doesnt work for bnb 4 bits |
bnb is a different issue, torch.compile might not support this (int8 yes). |
we are using NF4 for bnb |
@ArthurZucker Just checking have you added this to |
Co-authored-by: fxmarty <[email protected]> Co-authored-by: Younes Belkada <[email protected]> Co-authored-by: Joao Gante <[email protected]>
~4x speedups with cuda graphs! 🥳
Currently getting ~4x speedups compare to dynamic cache with torch.compile for a single forward pass (agnostic to batch but faster for smaller batch)
Forward is very very very fast, but materializing the input costs a bit!
~10ms / forward is what we get to!
past_key_values
mask_attn_utils
😄max_length
from the generation config (taking max_new_tokens) into accountBenchmark using af097af
Use it in generate:
Use this: EDIT: TO COME
Failing test left for @gante
Related to the fact that I don't return
past_key_values
/ is None so thetest_new_cache_format
fails. I don't want to dive in this.fixes #28075 , fixes #28610, fixes #28190