cache is now init in generate; update docs

gante committed Jul 9, 2024
1 parent c58f24e commit cd88500
Showing 4 changed files with 92 additions and 36 deletions.
99 changes: 69 additions & 30 deletions docs/source/en/llm_optims.md
@@ -26,53 +26,57 @@ During decoding, an LLM computes the key-value (kv) values for each input token a

To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels.

The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up.
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.
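To make the shape problem concrete, here is a minimal, self-contained sketch (illustrative tensor shapes only, not code from the file): a dynamic cache changes shape at every decoding step, while a static cache keeps a fixed shape for the whole generation.

```py
import torch

batch, num_heads, head_dim, prompt_len, max_cache_len = 1, 8, 256, 10, 128

# Dynamic cache: the sequence axis grows by one position per generated token,
# so `torch.compile` sees a new shape at every step and keeps recompiling.
key = torch.zeros(batch, num_heads, prompt_len, head_dim)
key = torch.cat([key, torch.zeros(batch, num_heads, 1, head_dim)], dim=2)  # shape changed after one step

# Static cache: pre-allocate up to `max_cache_len` once; the shape never changes afterwards.
static_key = torch.zeros(batch, num_heads, max_cache_len, head_dim)
```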

> [!WARNING]
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and `torch.compile`. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.

```py
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
```

There are two ways you can configure the model to use a static kv-cache. For a 7B model on an A100, both methods get a ~4x speed up in the forward pass. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. If you're using the [`~GenerationMixin.generate`] method, the speed up will be smaller -- the forward pass is only a part of the whole [`~GenerationMixin.generate`] code.
There are three flavors of static kv-cache usage, depending on the complexity of your task:
1. Basic usage: simply set a flag in `generation_config` (recommended);
2. Advanced usage: handle a cache object for multi-turn generation or a custom generation loop;
3. Advanced usage: compile the entire `generate` function into a single graph, if having a single graph is relevant for you.

> [!TIP]
> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
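As a sketch of the tip above (reusing the tokenizer and model from the examples below; the bucket size of 64 is an arbitrary choice), left-padding to a multiple keeps the compiled input shapes within a small, repeating set:

```py
# Prompts of different lengths are all padded up to the next multiple of 64 tokens,
# so the compiled forward pass only ever sees a handful of distinct shapes.
tokenizer.padding_side = "left"
inputs = tokenizer(
    ["The theory of special relativity states "],
    return_tensors="pt",
    padding=True,
    pad_to_multiple_of=64,
).to(model.device)
```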
<hfoptions id="static-kv">
<hfoption id="generation_config (basic usage)">
<hfoption id="Basic usage -- generation_config">

Access the model's `generation_config` attribute and set the `cache_implementation` to "static".
For this example, let's use the [Gemma](https://hf.co/google/gemma-2b) model. All we need to do is:
1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.

And that's it!

```py
model.generation_config.cache_implementation = "static"
```
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)

Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model.generation_config.cache_implementation = "static"

```py
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```

Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation.
Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of `torch.compile`, and you should be aware of the following:
1. If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation;
2. The first couple of calls of the compiled function are slower, as the function is being compiled (see the sketch below).
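For instance, a small sketch of point 2 (reusing `model` and `input_ids` from the example above): warm the compiled forward pass up before measuring it, keeping the batch size and maximum output length fixed between calls.

```py
# The first couple of calls compile the graph and are slow.
for _ in range(2):
    model.generate(**input_ids, max_new_tokens=20)

# Later calls with the same batch size and maximum output length reuse the compiled graph.
outputs = model.generate(**input_ids, max_new_tokens=20)
```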

> [!WARNING]
> For a more advanced usage of the static cache, such as multi-turn conversations or compiling the entire [`~GenerationMixin.generate`] function, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`].
> For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`]. See the advanced usage tab.
</hfoption>
<hfoption id="Static Cache (advanced usage)">
<hfoption id="Advanced usage -- control Static Cache">

A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.

@@ -85,7 +89,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]
@@ -95,19 +99,24 @@ past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+model.generation_config.max_new_tokens,
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
```

Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. For instance, all parameterization has to be done through `generation_config`. It can, however, be compiled without graph breaks.
# pass in the generated text and the same cache object to continue generation from where it left off. Optionally, in a
# multi-turn conversation, append the new user input to the generated text.
new_input_ids = outputs
outputs = model.generate(new_input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2. The speed of light is constant in all inertial reference frames. 3.']
```

> [!TIP]
> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
> If you want to reuse the same [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls.
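A minimal sketch of that reset pattern, assuming the objects from the example above:

```py
# Start from a clean cache before generating for an unrelated prompt.
past_key_values.reset()
new_prompt = tokenizer("The theory of general relativity states ", return_tensors="pt").to("cuda")
outputs = model.generate(**new_prompt, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```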
If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. With this strategy, you can write your own function to decode the next token, given the current token and the position and cache position of previously generated tokens.

@@ -142,11 +151,8 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
```

There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:

1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.

2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.

3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.

```py
@@ -181,6 +187,39 @@ text
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
```
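Point 3 in the list above refers to a context manager along these lines. This is only a sketch: `decode_one_tokens`, `next_token`, `cache_position`, and `past_key_values` come from the (collapsed) block above, so the snippet is not runnable on its own.

```py
# Force the math (native C++) scaled dot product attention backend around the compiled decode step.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
```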

</hfoption>
<hfoption id="Advanced usage -- end-to-end generate compilation">

In terms of code, compiling the entire `generate` function is just as simple as the basic usage:
1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
2. Call `torch.compile` on `generate` to compile the entire function with the static kv-cache.

```py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model.generation_config.cache_implementation = "static"

model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```

As a result, we compile not only the model forward pass, but also all input preparation, logit processor operations, and so on. The result should be a slightly faster `generate` call, compared to the basic usage example, and the compiled graph may be better suited to more exotic hardware devices or use cases. However, there are severe drawbacks to using this approach:
1. Compilation is much slower;
2. All parameterization of `generate` must be done through `generation_config` (see the sketch after this list);
3. Many warnings and exceptions are suppressed -- we suggest testing with its uncompiled form first;
4. Although we are working on it, it is heavily feature restricted (for instance, at the time of writing, generation does not stop if an EOS token is selected).
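For point 2 in that list, a minimal sketch (reusing the names from the example above): with a compiled `generate`, options such as the number of new tokens are set through `generation_config` instead of being passed as call arguments.

```py
# Configure generation through `generation_config`...
model.generation_config.max_new_tokens = 16
outputs = model.generate(**input_ids)

# ...rather than as a direct argument, e.g. `model.generate(**input_ids, max_new_tokens=16)`.
```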

</hfoption>
</hfoptions>

10 changes: 6 additions & 4 deletions src/transformers/cache_utils.py
@@ -9,7 +9,7 @@
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging
from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging


if is_quanto_available():
@@ -820,11 +820,13 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for _ in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

15 changes: 14 additions & 1 deletion src/transformers/generation/utils.py
@@ -1451,7 +1451,14 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
if not is_torchdynamo_compiling():
cache_dtype = self.dtype
else:
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
# models. May cause trouble with non-text modalities.
cache_dtype = self.lm_head.weight.dtype

cache_kwargs = {
"config": self.config,
"max_batch_size": max_batch_size,
@@ -1756,6 +1763,12 @@ def generate(
)

use_dynamic_cache_by_default = False
if model_kwargs.get("past_key_values") is not None and is_torchdynamo_compiling():
raise ValueError(
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
"may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
"input argument."
)
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
4 changes: 3 additions & 1 deletion tests/generation/test_utils.py
@@ -1760,7 +1760,9 @@ def test_generate_compile_fullgraph(self):
output_static = model.generate(model_inputs, **generation_kwargs)
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())

# compiled static cache
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
torch.compiler.reset()
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
