
Commit

PR suggestions
gante committed May 29, 2024
1 parent ee9a498 commit bb27eb4
Showing 2 changed files with 6 additions and 7 deletions.
docs/source/en/llm_optims.md (2 changes: 1 addition & 1 deletion)

````diff
@@ -103,7 +103,7 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
 ```
 
-Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. It can, however, be compiled without graph breaks.
+Be mindful that full [`~GenerationMixin.generate`] compilation has severe feature limitations, and is still under development. For instance, all parameterization has to be done through `generation_config`. It can, however, be compiled without graph breaks.
 
 
 If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
````
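
For context on the doc change above, here is a minimal sketch of the end-to-end compiled path it describes, with every option routed through `generation_config` as the new sentence requires. This is not part of the commit; the checkpoint and option values are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative checkpoint; any decoder-only model with static cache support works similarly.
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# All parameterization goes through `generation_config`: no loose kwargs
# like `max_new_tokens=...` on the compiled call.
generation_config = GenerationConfig(
    max_new_tokens=32,
    do_sample=False,
    cache_implementation="static",  # fixed shapes, so the traced graph can be reused
)

# Compile `generate` end to end; `fullgraph=True` errors out on any graph break.
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

inputs = tokenizer("The theory of special relativity states", return_tensors="pt").to("cuda")
outputs = compiled_generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```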
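
The last context line mentions driving the model's forward pass yourself with a [`StaticCache`] under `past_key_values`. A sketch of such a hand-rolled greedy decode loop follows, also not from this commit; it assumes a model whose forward accepts `cache_position`, and that the `StaticCache` keyword arguments match the library around this release.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "google/gemma-2b"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

input_ids = tokenizer("The theory of special relativity states", return_tensors="pt").input_ids.to("cuda")
batch_size, seq_len = input_ids.shape
max_new_tokens = 32

# Pre-allocated cache, handed to forward under `past_key_values`.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=seq_len + max_new_tokens,
    device=model.device,
    dtype=model.dtype,
)

# Prefill: run the whole prompt once to populate the cache.
cache_position = torch.arange(seq_len, device=model.device)
with torch.no_grad():
    logits = model(input_ids, cache_position=cache_position, past_key_values=past_key_values, use_cache=True).logits
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

# Decode: feed one token at a time, advancing only the cache position.
generated = [next_token]
for _ in range(max_new_tokens - 1):
    cache_position = cache_position[-1:] + 1
    with torch.no_grad():
        logits = model(next_token, cache_position=cache_position, past_key_values=past_key_values, use_cache=True).logits
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    generated.append(next_token)

print(tokenizer.batch_decode(torch.cat([input_ids, *generated], dim=-1), skip_special_tokens=True)[0])
```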

src/transformers/generation/utils.py (11 changes: 5 additions & 6 deletions)

```diff
@@ -1195,13 +1195,12 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
         """Performs validation related to the resulting generated length"""
 
+        # Can't throw warnings/exceptions during compilation
+        if is_torchdynamo_compiling():
+            return
+
         # 1. Max length warnings related to poor parameterization
-        if (
-            not is_torchdynamo_compiling()
-            and has_default_max_length
-            and generation_config.max_new_tokens is None
-            and generation_config.max_length == 20
-        ):
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
             # 20 is the default max_length of the generation config
             warnings.warn(
                 f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
```
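
The early return added above works because warnings are Python side effects that Dynamo cannot trace, so hitting them mid-compile would force a graph break (or an error under `fullgraph=True`). Below is a standalone sketch of the same guard pattern; it uses `torch._dynamo.is_compiling()`, which `is_torchdynamo_compiling` wraps around this release, and the function itself is illustrative rather than library code.

```python
import warnings

import torch

def validate_max_length(max_length: int) -> None:
    # Skip data-dependent Python side effects (warnings, raises) while
    # Dynamo is tracing, so compiling the caller does not graph-break here.
    if torch._dynamo.is_compiling():
        return
    if max_length == 20:  # mirrors the generation config default checked above
        warnings.warn(
            f"Using the model-agnostic default `max_length` (={max_length}); "
            "consider setting `max_new_tokens` instead."
        )
```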
