diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 5e49f0e1ebd3ab..8e7e9c54d42a42 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -18,59 +18,109 @@ Basic inference is slow because LLMs have to be called repeatedly to generate th This guide will show you how to use the optimization techniques available in Transformers to accelerate LLM inference. > [!TIP] -> Hugging Face also provides [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a library dedicated to deploying and serving highly optimized LLMs for inference. It includes more optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference. +> Hugging Face also provides [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a library dedicated to deploying and serving highly optimized LLMs for inference. It includes deployment-oriented optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference. -## Static kv-cache and torch.compile +## Static kv-cache and `torch.compile` During decoding, a LLM computes the key-value (kv) values for each input token and since it is autoregressive, it computes the same kv values each time because the generated output becomes part of the input now. This is not very efficient because you're recomputing the same kv values each time. -To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [torch.compile](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels. +To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels. -The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up. +The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. > [!WARNING] -> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list. +> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and `torch.compile`. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list. -For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model. +There are three flavors of static kv-cache usage, depending on the complexity of your task: +1. Basic usage: simply set a flag in `generation_config` (recommended); +2. Advanced usage: handle a cache object for multi-turn generation or a custom generation loop; +3. 
Advanced usage: compile the entire `generate` function into a single graph, if having a single graph is relevant for you. + +Select the correct tab below for further instructions on each of these flavors. + +> [!TIP] +> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend! + + + + +For this example, let's use the [Gemma](https://hf.co/google/gemma-2b) model. All we need to do is to: +1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static"; +2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache. + +And that's it! ```py from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :) tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") -model = AutoModelForCausalLM.from_pretrained( - "google/gemma-2b", device_map="auto" -) +model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto") + +model.generation_config.cache_implementation = "static" + +model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) +input_text = "The theory of special relativity states " +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +outputs = model.generate(**input_ids) +print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) +['The theory of special relativity states 1. The speed of light is constant in all inertial reference'] ``` -There are two ways you can configure the model to use a static kv-cache. For a 7B model on an A100, both methods get a 4x speed up in the forward pass. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. If you're using the [`~GenerationMixin.generate`] method, the speed up is ~3x. The forward pass (which still gets 4x speed up) is only a part of the whole [`~GenerationMixin.generate`] code. +Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of `torch.compile`, and you should be aware of the following: +1. If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation; +2. The first couple of calls of the compiled function are slower, as the function is being compiled. - - +> [!WARNING] +> For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`]. See the advanced usage tab. + + + -Access the model's `generation_config` attribute and set the `cache_implementation` to "static". +A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache. 
```py -model.generation_config.cache_implementation = "static" -``` +from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache +import torch +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :) -Call torch.compile on the model to compile the forward pass with the static kv-cache. +tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") +model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto") -```py -compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True) +model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) input_text = "The theory of special relativity states " input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") +prompt_length = input_ids.input_ids.shape[1] +model.generation_config.max_new_tokens = 16 + +past_key_values = StaticCache( + config=model.config, + max_batch_size=1, + # If you plan to reuse the cache, make sure the cache length is large enough for all cases + max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2), + device=model.device, + dtype=model.dtype +) +outputs = model.generate(**input_ids, past_key_values=past_key_values) +print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) +['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2'] -outputs = compiled_model.generate(**input_ids) -tokenizer.batch_decode(outputs, skip_special_tokens=True) -['The theory of special relativity states 1. The speed of light is constant in all inertial reference'] +# pass in the generated text and the same cache object to continue generation from where it left off. Optionally, in a +# multi-turn conversation, append the new user input to the generated text. +new_input_ids = outputs +outputs = model.generate(new_input_ids, past_key_values=past_key_values) +print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) +['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2. The speed of light is constant in all inertial reference frames. 3.'] ``` -Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation. - - - +> [!TIP] +> If you want to reuse the same [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls -A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache. +If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. 
 ```py
 from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -102,12 +152,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
     return new_token
 ```
 
-There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` method:
-
+There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
 1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
-
-2. Call torch.compile on the model to compile the forward pass with the static kv-cache.
-
+2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
 3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
 
 ```py
@@ -142,8 +189,34 @@ text
 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
 ```
 
-> [!TIP]
-> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method
+
+
+
+Compiling the entire `generate` function, in terms of code, is even simpler than in the basic usage: call `torch.compile` on `generate` to compile the entire function. There is no need to specify the use of the static cache: although it is compatible, the dynamic cache (the default) was faster in our benchmarks.
+
+```py
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
+
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
+
+model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
+input_text = "The theory of special relativity states "
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+
+outputs = model.generate(**input_ids)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
+```
+
+As a result, we compile not only the model forward pass, but also all input preparation, logit processor operations, and so on. The result should be a slightly faster `generate` call, compared to the basic usage example, and the compiled graph may be better suited to more exotic hardware devices or use cases. However, there are severe drawbacks to using this approach:
+1. Compilation is much slower;
+2. All parameterization of `generate` must be done through `generation_config`;
+3. Many warnings and exceptions are suppressed -- we suggest testing with its uncompiled form first;
+4. Although we are working on it, it is heavily feature restricted (for instance, at the time of writing, generation does not stop if an EOS token is selected).
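+
+To make the second point concrete, here is a minimal sketch of how a compiled `generate` can be parameterized. It assumes the same Gemma checkpoint as the examples above, and the generation flags are placeholders rather than recommendations: every option is bundled into a [`GenerationConfig`] object instead of being passed as a keyword argument.
+
+```py
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+import torch
+
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
+model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
+
+# All generation options live in a `GenerationConfig` object, which is then passed to the
+# compiled `generate` -- individual generation keyword arguments are not supported in this mode.
+generation_config = GenerationConfig(max_new_tokens=16, do_sample=False)
+
+input_ids = tokenizer("The theory of special relativity states ", return_tensors="pt").to("cuda")
+outputs = model.generate(**input_ids, generation_config=generation_config)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+```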
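+
+Regardless of the flavor you pick, the tip at the top of this section still applies: left-padding prompts to a limited set of lengths keeps the number of distinct input shapes (and therefore recompilations) down. The sketch below reuses the objects from the previous snippet; the multiple of 8 is an arbitrary choice, not a tuned value.
+
+```py
+# Prompts of different lengths are left-padded to a multiple of 8 tokens, so the compiled
+# code only ever sees a handful of input shapes.
+tokenizer.padding_side = "left"
+prompts = ["The theory of special relativity states ", "Cats are"]
+inputs = tokenizer(prompts, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(model.device)
+outputs = model.generate(**inputs, generation_config=generation_config)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+```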
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0c03ea2735db3f..9664ea49cb8f74 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -9,7 +9,7 @@ from packaging import version from .configuration_utils import PretrainedConfig -from .utils import is_hqq_available, is_quanto_available, logging +from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging if is_quanto_available(): @@ -398,7 +398,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens def crop(self, max_length: int): """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - # In case it is negative if max_length < 0: max_length = self.get_seq_length() - abs(max_length) @@ -821,11 +820,13 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) for _ in range(config.num_hidden_layers): # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. + # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case + # it is not needed anyway) new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b7bfeaf40d8c89..33d2bae9e4537c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1151,7 +1151,7 @@ def _validate_model_class(self): Confirms that the model class is compatible with generation. If not, raises an exception that points to the right class to use. """ - if not self.can_generate(): + if not is_torchdynamo_compiling() and not self.can_generate(): generate_compatible_mappings = [ MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, @@ -1254,6 +1254,10 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): """Performs validation related to the resulting generated length""" + # Can't throw warnings/exceptions during compilation + if is_torchdynamo_compiling(): + return + # 1. Max length warnings related to poor parameterization if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: # 20 is the default max_length of the generation config @@ -1383,20 +1387,12 @@ def _prepare_generation_config( self.generation_config = new_generation_config using_model_generation_config = True generation_config = self.generation_config + using_model_generation_config = True # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` - # will mutate the object with `.update`. 
As such, passing these arguments through `kwargs` is disabled. - if is_torchdynamo_compiling(): - model_kwargs = kwargs - generate_attributes_in_kwargs = [ - key for key, value in kwargs.items() if getattr(generation_config, key, None) != value - ] - if len(generate_attributes_in_kwargs) > 0: - raise ValueError( - "`torch.compile` exception: all generation configuration attributes must be passed within a " - f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})." - ) - else: + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an + # exception will be raised in `_validate_model_kwargs` + if not is_torchdynamo_compiling(): generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model @@ -1409,30 +1405,40 @@ def _prepare_generation_config( generation_config.pad_token_id = self.generation_config.pad_token_id if generation_config.decoder_start_token_id is None: generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id + else: + model_kwargs = kwargs return generation_config, model_kwargs def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" + # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` + if "inputs_embeds" in model_kwargs: + cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + else: + cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + past_length = 0 if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] + past_length = 0 if not isinstance(cache, Cache): past_length = cache[0][0].shape[2] elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - else: - cur_len = input_ids.shape[-1] - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) + # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, + # end-to-end compilation will yield bad results because `cache_position` will be incorrect. + if not is_torchdynamo_compiling(): + cache_position = cache_position[past_length:] + + model_kwargs["cache_position"] = cache_position return model_kwargs def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a - new `generate` call requires a larger cache. + new `generate` call requires a larger cache or uses a different batch size. Returns the resulting cache object. """ @@ -1465,7 +1471,14 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l if hasattr(self.config, "_pre_quantization_dtype"): cache_dtype = self.config._pre_quantization_dtype else: - cache_dtype = self.dtype + if not is_torchdynamo_compiling(): + cache_dtype = self.dtype + else: + # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`. 
+ # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative + # models. May cause trobles with non-text modalities. + cache_dtype = self.lm_head.weight.dtype + cache_kwargs = { "config": self.config, "max_batch_size": max_batch_size, @@ -1542,27 +1555,29 @@ def _tensor_or_none(token, device=None): pad_token_tensor = eos_token_tensor[0] logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") - # we can't infer attn mask if pad token is set to be eos token in model's generation config - if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor: - if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: - logger.warning_once( - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token." - "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` " - "to obtain reliable results." - ) - # Sanity checks/warnings if self.config.is_encoder_decoder and decoder_start_token_tensor is None: raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) - if eos_token_tensor is not None and ( - torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any() - ): - logger.warning( - f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not " - "stop until the maximum length is reached. Depending on other flags, it may even crash." - ) + if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow + if ( + eos_token_tensor is not None + and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any() + ): + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning_once( + "The attention mask is not set and cannot be inferred from input because pad token is same as " + "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's " + "`attention_mask` to obtain reliable results." + ) + if eos_token_tensor is not None and ( + torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any() + ): + logger.warning( + f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation " + "will not stop until the maximum length is reached. Depending on other flags, it may even crash." + ) # Update generation config with the updated special tokens tensors # NOTE: this must be written into a different attribute name than the one holding the original special tokens @@ -1771,6 +1786,12 @@ def generate( cache_name = "cache_params" else: cache_name = "past_key_values" + if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling(): + raise ValueError( + "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you " + "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` " + "input argument." + ) if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): raise ValueError( f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " @@ -1847,7 +1868,7 @@ def generate( "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." 
) - if self.device.type != input_ids.device.type: + if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: warnings.warn( "You are calling .generate() with the `input_ids` being on a device type different" f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" @@ -2144,23 +2165,36 @@ def typeerror(): result.past_key_values = result.past_key_values.to_legacy_cache() return result - def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: + def _has_unfinished_sequences( + self, + this_peer_finished: bool, + synced_gpus: bool, + device: torch.device, + cur_len: Optional[int] = None, + max_length: Optional[int] = None, + ) -> bool: """ Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is fed through `this_peer_finished`. ZeRO stage 3-friendly. """ - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: + # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile, + # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) + # TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html) + if is_torchdynamo_compiling(): + return cur_len < max_length + else: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + elif this_peer_finished: return False - elif this_peer_finished: - return False - return True + return True def heal_tokens( self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None @@ -2893,6 +2927,7 @@ def _sample( output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): @@ -2916,12 +2951,14 @@ def _sample( ) # keep track of which sequences are already finished - batch_size = input_ids.shape[0] + batch_size, cur_len = input_ids.shape this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2967,6 +3004,7 @@ def _sample( # token selection if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) @@ -2987,6 +3025,7 @@ def _sample( unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 # This is needed to properly delete outputs.logits which may be very large for first iteration # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration diff --git a/src/transformers/models/dbrx/configuration_dbrx.py b/src/transformers/models/dbrx/configuration_dbrx.py index 91f4fc3a4b1c9f..a040a649d9d8e3 100644 --- a/src/transformers/models/dbrx/configuration_dbrx.py +++ b/src/transformers/models/dbrx/configuration_dbrx.py @@ -249,6 +249,7 @@ def __init__( self.use_cache = use_cache self.initializer_range = initializer_range self.output_router_logits = output_router_logits + self.num_key_value_heads = self.attn_config.kv_n_heads tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) if tie_word_embeddings: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index edfc9519963bee..608e278ecfe808 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -513,7 +513,10 @@ def require_read_token(fn): @wraps(fn) def _inner(*args, **kwargs): - with patch("huggingface_hub.utils._headers.get_token", return_value=token): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return fn(*args, **kwargs) + else: # Allow running locally with the default token env variable return fn(*args, **kwargs) return _inner diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c52da62c1de8e5..168d8b5d9c98af 100755 --- 
a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -670,18 +670,20 @@ def is_torch_compile_available(): def is_torchdynamo_compiling(): if not is_torch_available(): return False + + # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) + # hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3) try: - # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible. - if version.parse(_torch_version) >= version.parse("2.3.0"): - import torch + import torch - return torch.compiler.is_compiling() - else: + return torch.compiler.is_compiling() + except AttributeError: + try: import torch._dynamo as dynamo # noqa: F401 return dynamo.is_compiling() - except Exception: - return False + except Exception: + return False def is_torch_tensorrt_fx_available(): diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2c440bbd71ae66..4a3e9ce872de7f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1803,6 +1803,58 @@ def test_generate_with_quant_cache(self): with self.assertRaises(ValueError): model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + @require_torch_gpu + @slow + @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky + def test_generate_compile_fullgraph(self): + """ + Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. + ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! 
⚠️ + """ + for model_class in self.all_generative_model_classes: + if not model_class._supports_static_cache: + self.skipTest("This model doesn't support static cache") + # TODO (joao) -- fix and enable me :) + if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): + self.skipTest("whisper model end-to-end generate compile not yet supported") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # TODO (joao) -- fix and enable me :) + if config.is_encoder_decoder: + self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") + + model = model_class(config).to(torch_device) + input_ids = inputs_dict["input_ids"].to(torch_device) + # creates two sets of *different* inputs with the same shape + half_batch_size = input_ids.shape[0] // 2 + input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]] + self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape) + + generation_kwargs = { + "do_sample": False, + "max_new_tokens": 10, + } + + for model_inputs in input_ids_sets: + # dynamic cache + output_dynamic = model.generate(model_inputs, **generation_kwargs) + + # eager static cache + torch.compiler.reset() + model.generation_config.cache_implementation = "static" + output_static = model.generate(model_inputs, **generation_kwargs) + self.assertListEqual(output_dynamic.tolist(), output_static.tolist()) + + # compiled static cache (removes the cache initialized in the previous check, to confirm we can + # initialize the cache in full compiled mode) + model._cache = None + torch.compiler.reset() + generation_config = copy.deepcopy(model.generation_config) + generation_config.update(**generation_kwargs) + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + output_compiled = compiled_generate(model_inputs, generation_config=generation_config) + self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 4e685411a04144..16e0a548e6dc47 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -370,6 +370,11 @@ def test_flash_attn_2_generate_padding_right(self): def test_batching_equivalence(self): pass + # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. 
dynamic control flow + @unittest.skip("Chameleon is not compatible with end-to-end generation compilation") + def test_generate_compile_fullgraph(self): + pass + @require_torch class ChameleonIntegrationTest(unittest.TestCase): diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py index 06c82c949cb3d1..d38a479ab36e42 100644 --- a/tests/models/dbrx/test_modeling_dbrx.py +++ b/tests/models/dbrx/test_modeling_dbrx.py @@ -368,6 +368,10 @@ def test_disk_offload_safetensors(self): def test_disk_offload_bin(self): pass + @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.") + def test_generate_compile_fullgraph(self): + pass + @require_torch class DbrxModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index c2128b99e80a94..00d5189de47fdc 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -31,6 +31,7 @@ import transformers from transformers import WhisperConfig from transformers.testing_utils import ( + is_flaky, is_pt_flax_cross_test, require_flash_attn, require_torch, @@ -1785,6 +1786,7 @@ def test_constrained_beam_search_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"] ) + @is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue? def test_custom_4d_attention_mask(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 98db42cfebb194..77250739bb296d 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -143,7 +143,7 @@ def _random_kvs(config): mha_config = LlamaConfig(num_attention_heads=32) mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( - *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)} + *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) self.assertTrue(cached_keys.shape == (1, 32, 10, 128)) self.assertTrue(cached_values.shape == (1, 32, 10, 128)) @@ -151,7 +151,7 @@ def _random_kvs(config): gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( - *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} + *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) self.assertTrue(cached_keys.shape == (1, 4, 10, 128)) self.assertTrue(cached_values.shape == (1, 4, 10, 128)) @@ -159,7 +159,7 @@ def _random_kvs(config): mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( - *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} + *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) 
self.assertTrue(cached_values.shape == (1, 1, 10, 128))