diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 664d51a6713c29..ddb16d44893af3 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -103,7 +103,7 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
 ```
 
-Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. It can, however, be compiled without graph breaks.
+Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. For instance, all parameterization has to be done through `generation_config`. It can, however, be compiled without graph breaks.
 
 If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 98b825e78afea4..ec7c41fcb4b3cd 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1195,13 +1195,12 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
         """Performs validation related to the resulting generated length"""
 
+        # Can't throw warnings/exceptions during compilation
+        if is_torchdynamo_compiling():
+            return
+
         # 1. Max length warnings related to poor parameterization
-        if (
-            not is_torchdynamo_compiling()
-            and has_default_max_length
-            and generation_config.max_new_tokens is None
-            and generation_config.max_length == 20
-        ):
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
             # 20 is the default max_length of the generation config
             warnings.warn(
                 f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "