Quantized KV Cache (#30483)
* clean-up

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Arthur <[email protected]>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
3 people authored and Ita Zaporozhets committed May 24, 2024
1 parent bd263c1 commit f715225
Showing 19 changed files with 652 additions and 28 deletions.
3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir gguf
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes

# Some tests require quanto
RUN python3 -m pip install --no-cache-dir quanto

# For `dinat` model
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
37 changes: 37 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -174,6 +174,43 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te
```


## KV Cache Quantization

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings, see the paper.

To enable quantization of the key-value cache, set `cache_implementation="quantized"` in the `generation_config`.
Quantization-related arguments should be passed to the `generation_config` either as a `dict` or as an instance of the [`QuantizedCacheConfig`] class; both options are shown below.
The quantization backend to use must be indicated in the [`QuantizedCacheConfig`]; the default is `quanto`.

<Tip warning={true}>

Cache quantization can be detrimental to latency if the context length is short and there is enough GPU VRAM available to run without cache quantization.

</Tip>


```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
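
The example above passes the quantization arguments as a plain `dict`. As a minimal sketch of the alternative, the same arguments can be wrapped in a [`QuantizedCacheConfig`] instance (assuming here that `cache_config` accepts the config object directly; only `backend` and `nbits` are set, leaving the remaining options at their defaults):

```python
>>> from transformers import QuantizedCacheConfig

>>> # Same settings as the `dict` above, wrapped in a config object
>>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config=cache_config)
```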

## Watermarking

The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
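
A minimal sketch, reusing the model and inputs from the quantization example above (the [`WatermarkingConfig`] field names `bias` and `seeding_scheme` are assumptions about the watermarking API, not taken from this page):

```python
>>> from transformers import WatermarkingConfig

>>> # Assumed fields: `bias` strengthens the preference for "green" tokens,
>>> # `seeding_scheme` selects how the green/red vocabulary split is seeded
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, watermarking_config=watermarking_config)
```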
16 changes: 15 additions & 1 deletion docs/source/en/internal/generation_utils.md
@@ -360,13 +360,27 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] Cache
- update

[[autodoc]] CacheConfig
- update

[[autodoc]] QuantizedCacheConfig
- validate

[[autodoc]] DynamicCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantizedCache
- update
- get_seq_length

[[autodoc]] QuantoQuantizedCache

[[autodoc]] HQQQuantizedCache

[[autodoc]] SinkCache
- update
- get_seq_length
@@ -375,7 +389,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache
- reset


## Watermark Utils
24 changes: 22 additions & 2 deletions src/transformers/__init__.py
@@ -1182,7 +1182,17 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["cache_utils"] = [
"Cache",
"CacheConfig",
"DynamicCache",
"HQQQuantizedCache",
"QuantizedCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5792,7 +5802,17 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
HQQQuantizedCache,
QuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,