Quantized KV Cache #30483

Merged · 21 commits · May 23, 2024
Changes from 18 commits
3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir gguf
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes

# Some tests require quanto
RUN python3 -m pip install --no-cache-dir quanto

# For `dinat` model
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
37 changes: 37 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -174,6 +174,43 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten
```


## KV Cache Quantization

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computation. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.
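As a rough back-of-the-envelope illustration (a sketch only — the shapes below assume a Llama-2-7B-style model and ignore quantization overhead such as per-group scales and any full-precision residual window):

```python
>>> # Assumed shapes: 32 layers, 32 KV heads, head_dim=128, fp16 (2 bytes) — roughly Llama-2-7B
>>> bytes_per_token = 2 * 32 * 32 * 128 * 2  # keys + values, across all layers
>>> print(f"fp16 cache for 4096 tokens: {bytes_per_token * 4096 / 2**30:.1f} GiB")
fp16 cache for 4096 tokens: 2.0 GiB
>>> print(f"~4-bit cache for 4096 tokens: {bytes_per_token * 4096 / 4 / 2**30:.1f} GiB")
~4-bit cache for 4096 tokens: 0.5 GiB
```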
KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.

To enable quantization of the key-value cache, pass `cache_implementation="quantized"` in the `generation_config`.
Quantization-related arguments should be passed to the `generation_config` either as a `dict` or as an instance of the [`QuantizedCacheConfig`] class.
The quantization backend to use is indicated in the [`QuantizedCacheConfig`]; the default is `quanto`.

<Tip warning={true}>

Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization.

</Tip>


```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
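The same arguments can also be supplied as a [`QuantizedCacheConfig`] instance instead of a plain `dict`. A minimal sketch, assuming the constructor accepts `backend` and `nbits` keyword arguments:

```python
>>> from transformers import QuantizedCacheConfig

>>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config=cache_config)
```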

## Watermarking

The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
16 changes: 15 additions & 1 deletion docs/source/en/internal/generation_utils.md
@@ -360,13 +360,27 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] Cache
- update

[[autodoc]] CacheConfig
- update

[[autodoc]] QuantizedCacheConfig
- validate

[[autodoc]] DynamicCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantizedCache
- update
- get_seq_length

[[autodoc]] QuantoQuantizedCache

[[autodoc]] HQQQuantizedCache

[[autodoc]] SinkCache
- update
- get_seq_length
@@ -375,7 +389,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache
- reset


## Watermark Utils
20 changes: 18 additions & 2 deletions src/transformers/__init__.py
@@ -1182,7 +1182,15 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["cache_utils"] = [
"Cache",
"CacheConfig",
"DynamicCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5778,7 +5786,15 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,