-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Quantized KV Cache #30483
Quantized KV Cache #30483
Conversation
As we discussed quantized cache can be started to be integrated to the library, given the results we got so far. All the possible speed optimizations/pre-fill stage optimizations can be done further, as we will be getting feedback from the community. So, I would like to get a review on the PR :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API wise looks really great ! I did not spotted anything critical here that needs to be addressed (and I will let joao give a deeper review on the cache file changes) - except for guarding quanto imports (also I would say safer to make local imports whenever possible - e.g. at QuantCache
init)
You raised a concern about switching between cache implementations - I made an attempt while ago: #29030 that got stale (😅 ) maybe that PR might solve your concern?
Maybe we could also track models that support quant cache with a private attribute _supports_quant_cache
in xxxPreTrainedModel
- what do you think?
Thanks for the comments!
Okey noted!
I love the generalized cache implementation idea. Not sure how this will work on overall API level, given that Joao and Arthur are working on changing cache thing. I'll let Joao to decide about that
Hmm, Actually quant cache should be supported abywhere dynamicCache is, that means everything except for old models like bart/t5. Yeah I think we can add it for explicitness, until the cache API is refactored to be same everywhere |
Thanks !
Ok that's great if that's the case then, i would say no need for that ! |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need a rebase due to #30476, but I love this POC -- in fact, I've reviewed it as if it was not a POC 😉
After removing the extra .py
files and adding some docs, I believe it is ready to be launched! And I also think it deserves a blog post :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments for you to work on, but let's gather the benchmarks first :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @zucchini-nlp ! 🚀 I only left nits and one open question with respect to tests, otherwise it looks really great !
@gante added benchmark results on the PR description. Right now int4 has almost same performance as fp16, sometimes a bit better. Also added some comparison with the KIVI paper. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 🙌 Thank you for iterating on this very cool project!
(CI needs fixing -- possibly a simple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very interesting work! Having both cache and quantizing on the fly when needed is very interesting!
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work ! Left one nit about tests !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! Last few nits and should be good to go!
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Arthur <[email protected]>
@ArthurZucker @ydshieh: "torch.compile with quanto is only supported for 8 bits quantization for now" (from @SunMarc, on a related conversation on slack) |
I made the KV cache work with HQQ as a backend. It can be simply plugged in if a user writes their own "CacheClass". I am not planning to add it now as it needs more evaluation and experiments, but wanted to show how anyone can add more backends. Do you think I should continue experimenting with HQQ or we can simply put the below code as example for users? BTW, if we were to actually support more cache quant classes in the library, maybe we'll need to change the current QuantCache API a bit to be more versatile. from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from hqq.core.quantize import Quantizer as HQQQuantizer
class HQQQuantizedCache(DynamicCache):
def __init__(
self,
nbits: int = 4,
axis: int = 0,
q_group_size: int = 64,
residual_length: int = 128,
compute_dtype: torch.dtype = torch.float16,
device: str = "cpu",
) -> None:
if nbits not in [2, 4, 8]:
raise ValueError(f"`nbits` has to be one of [`2`, `4`, `8`] but got {nbits}")
if axis not in [0, 1]:
raise ValueError(f"`axis` has to be one of [`1`, `2`] but got {axis}")
self._quantized_key_cache: List[Tuple[torch.Tensor, Dict]] = []
self._quantized_value_cache: List[Tuple[torch.Tensor, Dict]] = []
self.nbits = nbits
self.axis = axis
self.residual_length = residual_length
self.q_group_size = q_group_size
self.compute_dtype = compute_dtype
self.quantizer = HQQQuantizer
self.device = device
super().__init__()
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx:
q_key, meta_key = self._quantize(key_states.contiguous())
self._quantized_key_cache.append((q_key, meta_key))
q_value, meta_value = self._quantize(value_states.contiguous())
self._quantized_value_cache.append((q_value, meta_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
quant_key, meta_key = self._quantized_key_cache[layer_idx]
dequant_key = self.quantizer.dequantize(quant_key, meta_key)
quant_value, meta_value = self._quantized_value_cache[layer_idx]
dequant_value = self.quantizer.dequantize(quant_value, meta_value)
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
keys_to_return = torch.cat(keys_to_return, dim=-2)
values_to_return = torch.cat(values_to_return, dim=-2)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
q_key, meta_key = self._quantize(keys_to_return.contiguous())
self._quantized_key_cache[layer_idx] = (q_key, meta_key)
q_value, meta_value = self._quantize(values_to_return.contiguous())
self._quantized_key_cache[layer_idx] = (q_value, meta_value)
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return keys_to_return, values_to_return
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
# updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
def _quantize(self, tensor):
qtensor, meta = self.quantizer.quantize(
tensor,
axis=self.axis,
device=self.device,
compute_dtype=self.compute_dtype,
nbits=self.nbits,
group_size=self.q_group_size,
)
meta["compute_dtype"] = self.compute_dtype
return qtensor, meta
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager", device_map = "auto")
inputs = tokenizer("I like rock music because" return_tensors="pt").to(model.device)
out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=50,
past_key_values=HQQQuantizedCache(
nbits=2,
axis=1, # 2bit with axis=0 generates garbage
compute_dtype=torch.float16,
device=model.device
),
)
print(f"text with HQQ backend: {tokenizer.batch_decode(out)}") |
I think that making the cache class versatile is great to have people build on top of it, without necessarily including anythinig in |
@ArthurZucker yes, making a versatile cache class will go on another PR. In that case we can leave |
sounds good |
@ArthurZucker @gante I made a few changes from the last review:
I added a new usage ex in the description and will rework a bit the blogpost, given that now support HQQ. This PR is ready for the second review! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💛 💛 💛
Cool, merging 🤞🏻 Ran slow tests in quantization and generation locally, everything is passing. |
I am wondering if we can have this works together #30862. If so, we can probably get further more speedup! @zucchini-nlp Could you share the simplest code snippet that you use for this PR to measure the runtime (latency)? I can try to incorporate this with #30862 🙏 |
OK. Thanks for sharing, so this PR is more about memory instead of speed. |
* clean-up * Update src/transformers/cache_utils.py Co-authored-by: Arthur <[email protected]> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <[email protected]> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <[email protected]> * fixup * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: Younes Belkada <[email protected]> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Arthur <[email protected]> * more suggestions * mapping if torch available * run tests & add 'support_quantized' flag * fix jamba test * revert, will be fixed by another PR * codestyle * HQQ and versatile cache classes * final update * typo * make tests happy --------- Co-authored-by: Arthur <[email protected]> Co-authored-by: Younes Belkada <[email protected]>
* clean-up * Update src/transformers/cache_utils.py Co-authored-by: Arthur <[email protected]> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <[email protected]> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <[email protected]> * fixup * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: Younes Belkada <[email protected]> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Arthur <[email protected]> * more suggestions * mapping if torch available * run tests & add 'support_quantized' flag * fix jamba test * revert, will be fixed by another PR * codestyle * HQQ and versatile cache classes * final update * typo * make tests happy --------- Co-authored-by: Arthur <[email protected]> Co-authored-by: Younes Belkada <[email protected]>
What does this PR do?
An implementation of quantized cache with
quanto
library. Introduces a newCacheConfig
to store cache related arguments and a new cache classQuantoQuantizedCache
. The implementation is based partially on the KIVI paper, but in this case we do a per-token quantization for both: keys and values.PR for HF blogpost here
Example usage:
Perplexity plots
Here the results are different from what we got earlier because I was calculating perplexity in one forward pass, by quantizing and then dequantizing all keys and values. The new script uses cache object and calculates pplx per new token.Eval on LongBench (scripts taken from LongBench repo)
This is to compare with the KIVI method, since they did the same evals on all datasets from LongBench.I cannot find KIVI results on all of the LongBench, so here will be only
transformers
version.Memory vs Latency plots
Same old plots showing memory consumption and latency for differeny cache types:UPDATE:
Latest commit has added possibility to choose HQQ or quanto as backend. Usage: