diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 3b18955bf85c55..a3c55f6c1f54cb 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -26,53 +26,57 @@ During decoding, a LLM computes the key-value (kv) values for each input token a
 To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels.

-The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up.
+The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.

 > [!WARNING]
 > Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and `torch.compile`. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.

-For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.
-
-```py
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
-model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
-```
-
-There are two ways you can configure the model to use a static kv-cache. For a 7B model on an A100, both methods get a ~4x speed up in the forward pass. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. If you're using the [`~GenerationMixin.generate`] method, the speed up will be smaller -- the forward pass is only a part of the whole [`~GenerationMixin.generate`] code.
+There are three flavors of static kv-cache usage, depending on the complexity of your task:
+1. Basic usage: simply set a flag in `generation_config` (recommended);
+2. Advanced usage: handle a cache object for multi-turn generation or a custom generation loop;
+3. Advanced usage: compile the entire `generate` function into a single graph, if having a single graph is relevant for you.

 > [!TIP]
 > Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!

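To make the `pad_to_multiple_of` tip above concrete, here is a minimal sketch of left-padding a batch so that input shapes only take a handful of values. It is an illustration rather than part of the diff, and the prompts and the multiple of 64 are arbitrary choices:

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
tokenizer.padding_side = "left"  # decoder-only models are left-padded for generation

prompts = ["The theory of special relativity states ", "Photosynthesis is "]
inputs = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,           # pad the batch to a common length...
    pad_to_multiple_of=64,  # ...rounded up to a multiple of 64, so most batches share the same shape
)
print(inputs.input_ids.shape)  # the sequence dimension is now a multiple of 64
```

With a small, fixed set of possible sequence lengths, `torch.compile` can keep reusing previously compiled graphs instead of recompiling for every new prompt length.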
 <hfoptions id="static-kv">
-<hfoption id="generation_config">
+<hfoption id="basic usage: generation_config">

-Access the model's `generation_config` attribute and set the `cache_implementation` to "static".
+For this example, let's use the [Gemma](https://hf.co/google/gemma-2b) model. All we need to do is:
+1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
+2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
+
+And that's it!

 ```py
-model.generation_config.cache_implementation = "static"
-```
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)

-Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
+
+model.generation_config.cache_implementation = "static"

-```py
 model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
 input_text = "The theory of special relativity states "
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 outputs = model.generate(**input_ids)
-tokenizer.batch_decode(outputs, skip_special_tokens=True)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
 ```

-Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation.
+Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of `torch.compile`, and you should be aware of the following:
+1. If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation;
+2. The first couple of calls to the compiled function are slower, as the function is being compiled.

 > [!WARNING]
-> For a more advanced usage of the static cache, such as multi-turn conversations or compiling the entire [`~GenerationMixin.generate`] function, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`].
+> For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`]. See the advanced usage tab.

 </hfoption>
-<hfoption id="Static Cache">
+<hfoption id="advanced usage: control Static Cache">

 A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.

@@ -85,7 +89,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
 tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
 model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

-model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
+model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
 input_text = "The theory of special relativity states "
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 prompt_length = input_ids.input_ids.shape[1]
@@ -95,19 +99,24 @@ past_key_values = StaticCache(
     config=model.config,
     max_batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
-    max_cache_len=prompt_length+model.generation_config.max_new_tokens,
+    max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
     dtype=model.dtype
 )
 outputs = model.generate(**input_ids, past_key_values=past_key_values)
-tokenizer.batch_decode(outputs, skip_special_tokens=True)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
-```

-Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. For instance, all parameterization has to be done through `generation_config`. It can, however, be compiled without graph breaks.
+# Pass in the generated text and the same cache object to continue generation from where it left off. Optionally, in a
+# multi-turn conversation, append the new user input to the generated text.
+new_input_ids = outputs
+outputs = model.generate(new_input_ids, past_key_values=past_key_values)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2. The speed of light is constant in all inertial reference frames. 3.']
+```

 > [!TIP]
-> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
+> If you want to reuse the same [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls.

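To make the `.reset()` tip above concrete, here is a minimal sketch, an illustration rather than part of the diff, that continues the snippet above and reuses the same pre-allocated `past_key_values` for an unrelated prompt (the new prompt text is made up):

```py
# Clear the cache contents before starting a fresh conversation, otherwise keys and values
# from the previous prompt would still be present in the pre-allocated tensors.
past_key_values.reset()

# The cache length chosen above must still be large enough for this prompt plus the new tokens.
new_inputs = tokenizer("The theory of general relativity states ", return_tensors="pt").to("cuda")
outputs = model.generate(**new_inputs, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

Because the cache object keeps its pre-allocated shape, the expensive re-initialization described earlier is avoided.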
 If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.

@@ -142,11 +151,8 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
 ```

 There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
-
 1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
-
 2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
-
 3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.

 ```py
@@ -181,6 +187,39 @@ text
 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
 ```

+</hfoption>
+<hfoption id="advanced usage: end-to-end generate compilation">
+
+In terms of code, compiling the entire `generate` function is as simple as the basic usage:
+1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
+2. Call `torch.compile` on `generate` to compile the entire function with the static kv-cache.
+
+```py
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
+
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
+
+model.generation_config.cache_implementation = "static"
+
+model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
+input_text = "The theory of special relativity states "
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+
+outputs = model.generate(**input_ids)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
+```
+
+As a result, we compile not only the model forward pass, but also all input preparation, logit processor operations, and so on. The result should be a slightly faster `generate` call compared to the basic usage example, and the compiled graph may be better suited to more exotic hardware devices or use cases. However, there are severe drawbacks to using this approach:
+1. Compilation is much slower;
+2. All parameterization of `generate` must be done through `generation_config`;
+3. Many warnings and exceptions are suppressed -- we suggest testing with the uncompiled form first;
+4. Although we are working on it, it is heavily feature restricted (for instance, at the time of writing, generation does not stop if an EOS token is selected).
+
 </hfoption>
 </hfoptions>
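On drawback 2 above: because a fully compiled `generate` takes all of its parameterization from `generation_config`, generation options are set there before compiling. A minimal sketch, not part of the diff, as a variant of the Gemma setup in the example above, with arbitrary option values:

```py
# Configure generation up front; the compiled `generate` reads these from `generation_config`.
model.generation_config.cache_implementation = "static"
model.generation_config.max_new_tokens = 32
model.generation_config.do_sample = False

model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
outputs = model.generate(**input_ids)  # no per-call generation kwargs here
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```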
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 67d8e150d13c61..01d6f2532520b0 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -9,7 +9,7 @@
 from packaging import version

 from .configuration_utils import PretrainedConfig
-from .utils import is_hqq_available, is_quanto_available, logging
+from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging


 if is_quanto_available():
@@ -820,11 +820,13 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for _ in range(config.num_hidden_layers):
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache.
+            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+            # it is not needed anyway)
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+            if not is_torchdynamo_compiling():
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
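For context on what `mark_static_address` is applied to: a [`StaticCache`] pre-allocates one zero-filled key and one value tensor of shape `(max_batch_size, num_key_value_heads, max_cache_len, head_dim)` per layer, and those buffers keep the same data pointer for the lifetime of the cache. A minimal sketch, not part of the diff, with an arbitrary small cache length:

```py
import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("google/gemma-2b")
past_key_values = StaticCache(
    config=config, max_batch_size=1, max_cache_len=256, device="cpu", dtype=torch.float16
)

# One pre-allocated key/value tensor per layer; updates write into these buffers in place,
# which is why their addresses can be marked as static for cuda graphs.
print(len(past_key_values.key_cache), past_key_values.key_cache[0].shape)
```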
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 702a116dd60e6e..dd0090e280d014 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1451,7 +1451,14 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
             if hasattr(self.config, "_pre_quantization_dtype"):
                 cache_dtype = self.config._pre_quantization_dtype
             else:
-                cache_dtype = self.dtype
+                if not is_torchdynamo_compiling():
+                    cache_dtype = self.dtype
+                else:
+                    # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
+                    # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
+                    # models. May cause troubles with non-text modalities.
+                    cache_dtype = self.lm_head.weight.dtype
+
             cache_kwargs = {
                 "config": self.config,
                 "max_batch_size": max_batch_size,
@@ -1756,6 +1763,12 @@ def generate(
             )
         use_dynamic_cache_by_default = False

+        if model_kwargs.get("past_key_values") is not None and is_torchdynamo_compiling():
+            raise ValueError(
+                "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
+                "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
+                "input argument."
+            )
         if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
             raise ValueError(
                 "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 99057d0a257a44..db3f0f05185db6 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1760,7 +1760,9 @@ def test_generate_compile_fullgraph(self):
         output_static = model.generate(model_inputs, **generation_kwargs)
         self.assertListEqual(output_dynamic.tolist(), output_static.tolist())

-        # compiled static cache
+        # compiled static cache (removes the cache initialized in the previous check, to confirm we can
+        # initialize the cache in full compiled mode)
+        model._cache = None
         torch.compiler.reset()
         generation_config = copy.deepcopy(model.generation_config)
         generation_config.update(**generation_kwargs)
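A note on the pattern in the test above: clearing `model._cache` forces the next compiled `generate` call to initialize its static cache inside the compiled region (the code path the new `is_torchdynamo_compiling()` branches enable), and `torch.compiler.reset()` discards previously compiled graphs. A minimal sketch of that pattern, not taken from the repository; `model`, `model_inputs`, and `generation_config` are assumed to come from a setup like the test's, and `_cache` is a private attribute:

```py
# Drop the static cache kept on the model by the previous `generate` call, so the next
# compiled run has to initialize its cache from scratch (the behavior being exercised).
model._cache = None

# Clear torch.compile's cached graphs and guards so the next compilation starts fresh.
torch.compiler.reset()

compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
output = compiled_generate(model_inputs, generation_config=generation_config)
```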