From 3bea9b6ec85c7770fd2b954cdd655d3ee6b6e0af Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 29 May 2024 11:02:28 +0000
Subject: [PATCH] add more details to the compilation docs

---
 docs/source/en/llm_optims.md | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index ddb16d44893af3..3b18955bf85c55 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -42,6 +42,9 @@ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto

 There are two ways you can configure the model to use a static kv-cache. For a 7B model on an A100, both methods get a ~4x speed up in the forward pass. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. If you're using the [`~GenerationMixin.generate`] method, the speed up will be smaller -- the forward pass is only a part of the whole [`~GenerationMixin.generate`] code.

+> [!TIP]
+> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
+
@@ -54,28 +57,25 @@ model.generation_config.cache_implementation = "static"

 Call `torch.compile` on the model to compile the forward pass with the static kv-cache.

 ```py
-compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

 input_text = "The theory of special relativity states "
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

-outputs = compiled_model.generate(**input_ids)
+outputs = model.generate(**input_ids)
 tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
 ```

-Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation.
+Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation.

 > [!WARNING]
-> For a more advanced usage of the static cache, such as compiling the entire [`~GenerationMixin.generate`] function or manually prefilling the cache, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`].
+> For a more advanced usage of the static cache, such as multi-turn conversations or compiling the entire [`~GenerationMixin.generate`] function, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`].

 A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.

-> [!TIP]
-> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
-
 ```py
 from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
 import torch
@@ -85,7 +85,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
 tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
 model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

-compiled_generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
+model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
 input_text = "The theory of special relativity states "
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 prompt_length = input_ids.input_ids.shape[1]
@@ -94,17 +94,20 @@ model.generation_config.max_new_tokens = 16

 past_key_values = StaticCache(
     config=model.config,
     max_batch_size=1,
+    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+model.generation_config.max_new_tokens,
     device=model.device,
     dtype=model.dtype
 )
-outputs = compiled_generate(**input_ids, past_key_values=past_key_values)
+outputs = model.generate(**input_ids, past_key_values=past_key_values)
 tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
 ```

 Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. For instance, all parameterization has to be done through `generation_config`. It can, however, be compiled without graph breaks.

+> [!TIP]
+> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls.

 If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.

@@ -181,9 +184,6 @@ text


-> [!TIP]
-> Regardless of the strategy used with `torch.compile`, through `forward` or [`~GenerationMixin.generate`], you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
-
 ## Speculative decoding

 > [!TIP]
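
The left-padding tip added at the top of this patch is easier to apply with a concrete snippet, so here is a minimal sketch (not part of the patch itself) of how `pad_to_multiple_of` might be used. The checkpoint matches the one in the doc; the prompts, the multiple of 64, and the pad-token fallback are illustrative assumptions.

```py
from transformers import AutoTokenizer

# Left padding keeps the prompt tokens at the end of the sequence, which is what
# decoder-only LLMs expect when generating from padded batches.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # fallback for models without a pad token

prompts = ["The theory of special relativity states ", "Gravity is "]
# Padding every batch up to a multiple of 64 limits the set of input shapes,
# so `torch.compile` only sees a handful of shapes instead of one per prompt length.
inputs = tokenizer(prompts, padding=True, pad_to_multiple_of=64, return_tensors="pt")
print(inputs["input_ids"].shape)  # sequence length is a multiple of 64
```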
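Likewise, the tip about resetting a reused [`StaticCache`] can be hard to picture without code. The sketch below continues from the `tokenizer`, `model` and `past_key_values` objects created in the compiled-`generate` example in the patch; the new prompt is an illustrative assumption and has to fit within the `max_cache_len` chosen when the cache was built.

```py
# Clear the contents of the already-allocated cache before moving to an unrelated
# prompt; without the reset, leftover values from the previous generation could
# leak into the new one.
past_key_values.reset()

# Hand the same cache object back to generate, as the tip describes.
new_input = tokenizer("Gravity is ", return_tensors="pt").to("cuda")
outputs = model.generate(**new_input, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```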