
Commit

add more details to the compilation docs
gante committed May 29, 2024
1 parent 3593971 commit 3bea9b6
Showing 1 changed file with 12 additions and 12 deletions.
docs/source/en/llm_optims.md (24 changes: 12 additions & 12 deletions)
@@ -42,6 +42,9 @@ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto

There are two ways you can configure the model to use a static kv-cache. For a 7B model on an A100, both methods get a ~4x speed up in the forward pass. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. If you're using the [`~GenerationMixin.generate`] method, the speed up will be smaller -- the forward pass is only a part of the whole [`~GenerationMixin.generate`] code.

+> [!TIP]
+> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
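For example, a minimal sketch of such padding (the prompts and the multiple of 64 are arbitrary illustrative choices):

```py
# Left-pad every batch so the padded length only ever lands on a multiple of 64,
# keeping the set of distinct input shapes (and therefore recompilations) small.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", padding_side="left")
prompts = ["The theory of special relativity states ", "Gravity is "]
inputs = tokenizer(prompts, padding=True, pad_to_multiple_of=64, return_tensors="pt")
print(inputs.input_ids.shape)  # e.g. torch.Size([2, 64]); the padded length is a multiple of 64
```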
<hfoptions id="static-kv">
<hfoption id="generation_config (basic usage)">

@@ -54,28 +57,25 @@ model.generation_config.cache_implementation = "static"
Call `torch.compile` on the model to compile the forward pass with the static kv-cache.

```py
-compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

-outputs = compiled_model.generate(**input_ids)
+outputs = model.generate(**input_ids)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```

-Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation.
+Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation.
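For intuition, a sketch of when the compiled graph is reused versus rebuilt, assuming the `model` and `tokenizer` set up in the snippet above:

```py
inputs = tokenizer("The theory of special relativity states ", return_tensors="pt").to("cuda")

model.generate(**inputs, max_new_tokens=16)   # first call: the static cache is created and the forward pass is compiled (slow)
model.generate(**inputs, max_new_tokens=16)   # same batch size and output length: the cache and compiled graph are reused (fast)
model.generate(**inputs, max_new_tokens=128)  # larger maximum output length: the cache is reinitialized, triggering a new compilation
```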

> [!WARNING]
-> For a more advanced usage of the static cache, such as compiling the entire [`~GenerationMixin.generate`] function or manually prefilling the cache, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`].
+> For a more advanced usage of the static cache, such as multi-turn conversations or compiling the entire [`~GenerationMixin.generate`] function, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`].
</hfoption>
<hfoption id="Static Cache (advanced usage)">

A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.

-> [!TIP]
-> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
```py
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
@@ -85,7 +85,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

-compiled_generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
+model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]
@@ -94,17 +94,20 @@ model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
+    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length+model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype
)
-outputs = compiled_generate(**input_ids, past_key_values=past_key_values)
+outputs = model.generate(**input_ids, past_key_values=past_key_values)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
```

Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. For instance, all parameterization has to be done through `generation_config`. It can, however, be compiled without graph breaks.
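For instance, a sketch of what that parameterization looks like, assuming sampling is wanted (the values are arbitrary):

```py
# With the whole `generate` function compiled, set decoding options on
# `generation_config` ahead of time instead of passing them per call.
model.generation_config.do_sample = True
model.generation_config.temperature = 0.7
model.generation_config.max_new_tokens = 16
```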

+> [!TIP]
+> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
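Concretely, a small sketch of that reuse, relying on the `model`, `tokenizer`, and `past_key_values` objects from the snippet above (the second prompt is made up, and the `max_cache_len` chosen earlier must also be large enough for it):

```py
past_key_values.reset()  # clear the cached keys/values left over from the previous prompt
new_inputs = tokenizer("The second law of thermodynamics states ", return_tensors="pt").to("cuda")
outputs = model.generate(**new_inputs, past_key_values=past_key_values)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
```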
If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token, the current position, and the cache position of previously generated tokens.
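A rough sketch of what such a function can look like (the name `decode_one_token` and the greedy `argmax` decoding are illustrative assumptions, not a prescribed API; `torch` is already imported in the snippet above):

```py
def decode_one_token(model, cur_token, cache_position, past_key_values):
    # One forward pass at an explicit cache position, then pick the next token greedily.
    logits = model(
        cur_token,
        cache_position=cache_position,
        past_key_values=past_key_values,
        use_cache=True,
    ).logits
    return torch.argmax(logits[:, -1], dim=-1)[:, None]
```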

@@ -181,9 +184,6 @@ text
</hfoption>
</hfoptions>

-> [!TIP]
-> Regardless of the strategy used with `torch.compile`, through `forward` or [`~GenerationMixin.generate`], you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
## Speculative decoding

> [!TIP]
