diff --git a/docs/source/en/llm_tutorial.md b/docs/source/en/llm_tutorial.md index bb42fb4633bf69..8d6372e129cc47 100644 --- a/docs/source/en/llm_tutorial.md +++ b/docs/source/en/llm_tutorial.md @@ -250,7 +250,7 @@ While the autoregressive generation process is relatively straightforward, makin 1. [Guide](generation_strategies) on how to control different generation methods, how to set up the generation configuration file, and how to stream the output; 2. [Guide](chat_templating) on the prompt template for chat LLMs; 3. [Guide](tasks/prompting) on to get the most of prompt design; -4. API reference on [`~generation.GenerationConfig`], [`~generation.GenerationMixin.generate`], and [generate-related classes](internal/generation_utils). +4. API reference on [`~generation.GenerationConfig`], [`~generation.GenerationMixin.generate`], and [generate-related classes](internal/generation_utils). Most of the classes, including the logits processors, have usage examples! ### LLM leaderboards diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7a494bf3733a0e..4818ca8d97b7f1 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -63,6 +63,14 @@ class GenerationConfig(PushToHubMixin): You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). + + + A large number of these flags control the logits or the stopping criteria of the generation. Make sure you check + the [generate-related classes](https://huggingface.co/docs/transformers/internal/generation_utils) for a full + description of the possible manipulations, as well as examples of their usage. + + + Arg: > Parameters that control the length of the output diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index d5f65161ecd763..7f4415dd0dbe84 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -100,13 +100,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class MinLengthLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. + [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models + like most LLMs, the length includes the prompt. Args: min_length (`int`): The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + + Examples: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") + + >>> inputs = tokenizer("A number:", return_tensors="pt") + >>> gen_out = model.generate(**inputs) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + A number: one + + >>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact + >>> gen_out = model.generate(**inputs, min_length=3) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + A number: one + + >>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not + >>> # necessarily incorrect + >>> gen_out = model.generate(**inputs, min_length=10) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + A number: one thousand, nine hundred and ninety-four + ``` """ def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): @@ -133,9 +159,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class MinNewTokensLengthLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0. - Note that for decoder-only models, such as Llama2, `min_length` will compute the length of `prompt + newly - generated tokens` whereas for other models it will behave as `min_new_tokens`, that is, taking only into account - the newly generated ones. + Contrarily to [`MinLengthLogitsProcessor`], this processor ignores the prompt. Args: prompt_length_to_skip (`int`): @@ -149,29 +173,21 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): Examples: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> from transformers import AutoModelForCausalLM, AutoTokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> model.config.pad_token_id = model.config.eos_token_id - >>> inputs = tokenizer(["Hugging Face Company is"], return_tensors="pt") + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") - >>> # If the maximum length (default = 20) is smaller than the minimum length constraint, the latter is ignored! - >>> outputs = model.generate(**inputs, min_new_tokens=30) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - Hugging Face Company is a company that has been working on a new product for the past year. - - >>> # For testing purposes, let's set `eos_token` to `"company"`, the first generated token. This will make - >>> # generation end there. - >>> outputs = model.generate(**inputs, eos_token_id=1664) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - Hugging Face Company is a company - - >>> # Increasing `min_new_tokens` will make generation ignore occurences `"company"` (eos token) before the - >>> # minimum length condition is honored. - >>> outputs = model.generate(**inputs, min_new_tokens=2, eos_token_id=1664) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - Hugging Face Company is a new company + >>> inputs = tokenizer(["A number:"], return_tensors="pt") + >>> gen_out = model.generate(**inputs) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + A number: one + + >>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not + >>> # necessarily incorrect + >>> gen_out = model.generate(**inputs, min_new_tokens=2) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + A number: one thousand ``` """ @@ -205,7 +221,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class TemperatureLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means - that it can control the randomness of the predicted tokens. + that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and + [`TopKLogitsWarper`]. @@ -269,22 +286,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class RepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique - shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text - generation process, the probability distribution for the next token is determined using a formula that incorporates - token scores based on their occurrence in the generated sequence. Tokens with higher scores are more likely to be - selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the - paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition. - - This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce + [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at + most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt. + + In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around + 1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly. Args: penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated - tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + tokens. Between 0.0 and 1.0 rewards previously generated tokens. Examples: @@ -327,20 +340,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original - input. + [`LogitsProcessor`] that works similarly to [`RepetitionPenaltyLogitsProcessor`], but with an *inverse* penalty + that is applied to the tokens present in the prompt. In other words, a penalty above 1.0 increases the odds of + selecting tokens that were present in the prompt. - This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To - penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To - reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more - strongly. + It was designed to avoid hallucination in input-grounded tasks, like summarization. Although originally intended + for encoder-decoder models, it can also be used with decoder-only models like LLMs. Args: penalty (`float`): - The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between - 0.0 and 1.0 rewards hallucination. + The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 rewards prompt tokens. Between 0.0 + and 1.0 penalizes prompt tokens. encoder_input_ids (`torch.LongTensor`): The encoder_input_ids that should be repeated within the decoder ids. + + Examples: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") + + >>> inputs = tokenizer(["Alice and Bob. The third member's name was"], return_tensors="pt") + >>> gen_out = model.generate(**inputs) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + Alice and Bob. The third member's name was not mentioned. + + >>> # With the `encoder_repetition_penalty` argument we can trigger this logits processor in `generate`, which can + >>> # promote the use of prompt tokens ("Bob" in this example) + >>> gen_out = model.generate(**inputs, encoder_repetition_penalty=1.2) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + Alice and Bob. The third member's name was Bob. The third member's name was Bob. + ``` """ def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor): @@ -363,7 +395,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class TopPLogitsWarper(LogitsWarper): """ - [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often + used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. Args: top_p (`float`): @@ -375,6 +408,7 @@ class TopPLogitsWarper(LogitsWarper): Minimum number of tokens that cannot be filtered. Examples: + ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed @@ -426,7 +460,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class TopKLogitsWarper(LogitsWarper): r""" - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together + with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. Args: top_k (`int`): @@ -435,6 +470,29 @@ class TopKLogitsWarper(LogitsWarper): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): Minimum number of tokens that cannot be filtered. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed + + >>> set_seed(0) + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + + >>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="pt") + + >>> # With sampling, the output is unexpected -- sometimes too unexpected. + >>> outputs = model.generate(**inputs, do_sample=True) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + A sequence: A, B, C, D, G, H, I. A, M + + >>> # With `top_k` sampling, the output gets restricted the k most likely tokens. + >>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range. + >>> outputs = model.generate(**inputs, do_sample=True, top_k=2) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + A sequence: A, B, C, D, E, F, G, H, I + ``` """ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -455,8 +513,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class TypicalLogitsWarper(LogitsWarper): r""" - [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language - Generation](https://arxiv.org/abs/2202.00666) for more information. + [`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose + log probability is close to the entropy of the token probability distribution. This means that the most likely + tokens may be discarded in the process. + + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. Args: mass (`float`, *optional*, defaults to 0.9): @@ -465,6 +526,42 @@ class TypicalLogitsWarper(LogitsWarper): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): Minimum number of tokens that cannot be filtered. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed + + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + + >>> inputs = tokenizer("1, 2, 3", return_tensors="pt") + + >>> # We can see that greedy decoding produces a sequence of numbers + >>> outputs = model.generate(**inputs) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + + >>> # For this particular seed, we can see that sampling produces nearly the same low-information (= low entropy) + >>> # sequence + >>> set_seed(18) + >>> outputs = model.generate(**inputs, do_sample=True) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + 1, 2, 3, 4, 5, 6, 7, 8, 9 and 10 + + >>> # With `typical_p` set, the most obvious sequence is no longer produced, which may be good for your problem + >>> set_seed(18) + >>> outputs = model.generate( + ... **inputs, do_sample=True, typical_p=0.1, return_dict_in_generate=True, output_scores=True + ... ) + >>> print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]) + 1, 2, 3 and 5 + + >>> # We can see that the token corresponding to "4" (token 934) in the second position, the most likely token + >>> # as seen with greedy decoding, was entirely blocked out + >>> print(outputs.scores[1][0, 934]) + tensor(-inf) + ``` """ def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -721,7 +818,8 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor): sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation, avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens - from consideration when further processing the scores. + from consideration when further processing the scores. Note that, for decoder-only models like most LLMs, the + prompt is also considered to obtain the n-grams. [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). @@ -774,14 +872,40 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] that enforces no repetition of encoder input ids n-grams for the decoder ids. See - [ParlAI](https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350). + [`LogitsProcessor`] that works similarly to [`NoRepeatNGramLogitsProcessor`], but applied exclusively to prevent + the repetition of n-grams present in the prompt. + + It was designed to promote chattiness in a language model, by preventing the generation of n-grams present in + previous conversation rounds. Args: encoder_ngram_size (`int`): All ngrams of size `ngram_size` can only occur within the encoder input ids. encoder_input_ids (`int`): The encoder_input_ids that should not be repeated within the decoder ids. + + Examples: + + ```py + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + + >>> inputs = tokenizer("Alice: I love cats. What do you love?\nBob:", return_tensors="pt") + + >>> # With greedy decoding, we see Bob repeating Alice's opinion. If Bob was a chatbot, it would be a poor one. + >>> outputs = model.generate(**inputs) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + Alice: I love cats. What do you love? + Bob: I love cats. What do you + + >>> # With this logits processor, we can prevent Bob from repeating Alice's opinion. + >>> outputs = model.generate(**inputs, encoder_no_repeat_ngram_size=2) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + Alice: I love cats. What do you love? + Bob: My cats are very cute. + ``` """ def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor): @@ -1060,6 +1184,40 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor): arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID `batch_id`. + + Examples: + + ```py + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + + >>> inputs = tokenizer("Alice and Bob", return_tensors="pt") + + >>> # By default, it continues generating according to the model's logits + >>> outputs = model.generate(**inputs, max_new_tokens=5) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + Alice and Bob are friends + + >>> # We can contrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix. + >>> # For instance, we can force an entire entity to be generated when its beginning is detected. + >>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens + >>> def prefix_allowed_tokens_fn(batch_id, input_ids): + ... ''' + ... Attempts to generate 'Bob Marley' when 'Bob' is detected. + ... In this case, `batch_id` is not used, but you can set rules for each batch member. + ... ''' + ... if input_ids[-1] == entity[0]: + ... return entity[1] + ... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]: + ... return entity[2] + ... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens + + >>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + Alice and Bob Marley + ``` """ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): @@ -1084,56 +1242,20 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more details. - - - Diverse beam search can be particularly useful in scenarios where a variety of different outputs is desired, rather - than multiple similar sequences. It allows the model to explore different generation paths and provides a broader - coverage of possible outputs. - - - - - - This logits processor can be resource-intensive, especially when using large models or long sequences. - - - Traditional beam search often generates very similar sequences across different beams. `HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other beams in the same time step. - How It Works: - - **Grouping Beams**: Beams are divided into groups. Each group selects tokens independently of the others. - - **Penalizing Repeated Tokens**: If a beam in a group selects a token already chosen by another group in the - same step, a penalty is applied to that token's score. - - **Promoting Diversity**: This penalty discourages beams within a group from selecting the same tokens as - beams in other groups. - - Benefits: - - **Diverse Outputs**: Produces a variety of different sequences. - - **Exploration**: Allows the model to explore different paths. - Args: diversity_penalty (`float`): This value is subtracted from a beam's score if it generates a token same as any beam from other group at a - particular time. Note that `diversity_penalty` is only effective if group beam search is enabled. The - penalty applied to a beam's score when it generates a token that has already been chosen by another beam - within the same group during the same time step. A higher `diversity_penalty` will enforce greater - diversity among the beams, making it less likely for multiple beams to choose the same token. Conversely, a - lower penalty will allow beams to more freely choose similar tokens. Adjusting this value can help strike a - balance between diversity and natural likelihood. + particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting + this value can help strike a balance between diversity and natural likelihood. num_beams (`int`): - Number of beams used for group beam search. Beam search is a method used that maintains beams (or "multiple - hypotheses") at each step, expanding each one and keeping the top-scoring sequences. A higher `num_beams` - will explore more potential sequences. This can increase chances of finding a high-quality output but also - increases computational cost. + Number of beams for beam search. 1 means no beam search. num_beam_groups (`int`): Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. - Each group of beams will operate independently, selecting tokens without considering the choices of other - groups. This division promotes diversity by ensuring that beams within different groups explore different - paths. For instance, if `num_beams` is 6 and `num_beam_groups` is 2, there will be 2 groups each containing - 3 beams. The choice of `num_beam_groups` should be made considering the desired level of output diversity - and the total number of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. Examples: @@ -1146,7 +1268,13 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> # A long text about the solar system - >>> text = "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant interstellar molecular cloud." + >>> text = ( + ... "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, " + ... "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight " + ... "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System " + ... "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant " + ... "interstellar molecular cloud." + ... ) >>> inputs = tokenizer("summarize: " + text, return_tensors="pt") >>> # Generate diverse summary @@ -1241,11 +1369,34 @@ def __call__( class ForcedBOSTokenLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] that enforces the specified token as the first generated token. + [`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder + models. Args: bos_token_id (`int`): The id of the token to force as the first generated token. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + + >>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") + + >>> inputs = tokenizer("Translate from English to German: I love cats.", return_tensors="pt") + + >>> # By default, it continues generating according to the model's logits + >>> outputs = model.generate(**inputs, max_new_tokens=10) + >>> print(tokenizer.batch_decode(outputs)[0]) + Ich liebe Kitty. + + >>> # We can use `forced_bos_token_id` to force the start of generation with an encoder-decoder model + >>> # (including forcing it to end straight away with an EOS token) + >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_bos_token_id=tokenizer.eos_token_id) + >>> print(tokenizer.batch_decode(outputs)[0]) + + ``` """ def __init__(self, bos_token_id: int): @@ -1271,6 +1422,27 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): eos_token_id (`Union[int, List[int]]`): The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + + >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt") + + >>> # By default, it continues generating according to the model's logits + >>> outputs = model.generate(**inputs, max_new_tokens=10) + >>> print(tokenizer.batch_decode(outputs)[0]) + A sequence: 1, 2, 3, 4, 5, 6, 7, 8 + + >>> # `forced_eos_token_id` ensures the generation ends with a EOS token + >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_eos_token_id=tokenizer.eos_token_id) + >>> print(tokenizer.batch_decode(outputs)[0]) + A sequence: 1, 2, 3, 4, 5, 6, 7,<|endoftext|> + ``` """ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): @@ -1294,6 +1466,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using the logits processor should only be used if necessary since it can slow down the generation method. + + This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants + its use. """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -1405,6 +1580,29 @@ class LogitNormalization(LogitsProcessor, LogitsWarper): the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that the scores are normalized when comparing the hypotheses. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> import torch + + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + + >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt") + + >>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability + >>> # distribution, summing to 1 + >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True) + >>> print(torch.sum(torch.exp(outputs.scores[-1]))) + tensor(816.3250) + + >>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application + >>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True) + >>> print(torch.sum(torch.exp(outputs.scores[-1]))) + tensor(1.0000) + ``` """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -1416,8 +1614,36 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): r""" [`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts - generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not - sampled at the begining of the generation. + generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are + not generated at the begining. Originally created for + [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper). + + Examples: + + ```python + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + + >>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means + >>> # it can't generate and EOS token in the first iteration, but it can in the others. + >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True) + >>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token + tensor(-inf) + >>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS + tensor(29.9010) + + >>> # If we disable `begin_suppress_tokens`, we can generate EOS in the first iteration. + >>> outputs = model.generate( + ... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None + ... ) + >>> print(outputs.scores[1][0, 50256]) + tensor(11.2027) + ``` """ def __init__(self, begin_suppress_tokens, begin_index): @@ -1433,8 +1659,33 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class SuppressTokensLogitsProcessor(LogitsProcessor): - r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they - are not sampled.""" + r""" + This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so + that they are not generated. Originally created for + [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper). + + Examples: + + ```python + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + + >>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default. + >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True) + >>> print(outputs.scores[1][0, 1]) # 1 (and not 0) is the first freely generated token + tensor(-inf) + + >>> # If we disable `suppress_tokens`, we can generate it. + >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None) + >>> print(outputs.scores[1][0, 1]) + tensor(5.7738) + ``` + """ def __init__(self, suppress_tokens): self.suppress_tokens = list(suppress_tokens) @@ -1446,9 +1697,42 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class ForceTokensLogitsProcessor(LogitsProcessor): - r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token - indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are - sampled at their corresponding index.""" + r""" + This processor takes a list of pairs of integers which indicates a mapping from generation indices to token + indices that will be forced before generation. The processor will set their log probs to `inf` so that they are + sampled at their corresponding index. Originally created for + [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper). + + Examples: + ```python + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + + >>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e. + >>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out. + >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True) + >>> print( + ... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362) + ... ) + True + >>> print(outputs.scores[0][0, 50362]) + tensor(0.) + + >>> # If we disable `forced_decoder_ids`, we stop seeing that effect + >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None) + >>> print( + ... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362) + ... ) + False + >>> print(outputs.scores[0][0, 50362]) + tensor(19.3140) + ``` + """ def __init__(self, force_token_map: List[List[int]]): self.force_token_map = dict(force_token_map) @@ -1492,7 +1776,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): Examples: ``` python >>> import torch - >>> from transformers import AutoProcessor, WhisperForConditionalGeneration,GenerationConfig + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig >>> from datasets import load_dataset >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") @@ -1588,18 +1872,42 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): - r"""Logits processor for classifier free guidance (CFG). The scores are split over the batch dimension, + r""" + [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, where the first half correspond to the conditional logits (predicted from the input prompt) and the second half correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. See [the paper](https://arxiv.org/abs/2306.05284) for more information. + + + This logits processor is exclusivelly compatible with + [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) + + + Args: guidance_scale (float): The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate samples that are more closely linked to the input prompt, usually at the expense of poorer quality. + + Examples: + + ```python + >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") + >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + + >>> inputs = processor( + ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + ... padding=True, + ... return_tensors="pt", + ... ) + >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256) + ``` """ def __init__(self, guidance_scale): @@ -1629,7 +1937,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class AlternatingCodebooksLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] enforcing alternated generation between the two codebooks of [`Bark`]'s fine submodel. + [`LogitsProcessor`] enforcing alternated generation between the two codebooks of Bark. + + + + This logits processor is exclusivelly compatible with + [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation + for examples. + + Args: input_start_len (`int`): @@ -1664,10 +1980,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): - r"""Logits processor for Classifier-Free Guidance (CFG). The processors - computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits, - parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with - the `unconditional_ids` branch. + r""" + Logits processor for Classifier-Free Guidance (CFG). The processors computes a weighted average across scores + from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`. + The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch. See [the paper](https://arxiv.org/abs/2306.17806) for more information. @@ -1784,6 +2100,13 @@ def __call__(self, input_ids, scores): class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`. + + + This logits processor is exclusivelly compatible with + [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples. + + + Args: eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7040b98dd91c10..89ab9886daee18 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1031,16 +1031,9 @@ def _get_logits_processor( generation_config.encoder_no_repeat_ngram_size is not None and generation_config.encoder_no_repeat_ngram_size > 0 ): - if self.config.is_encoder_decoder: - processors.append( - EncoderNoRepeatNGramLogitsProcessor( - generation_config.encoder_no_repeat_ngram_size, encoder_input_ids - ) - ) - else: - raise ValueError( - "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" - ) + processors.append( + EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) + ) if generation_config.bad_words_ids is not None: processors.append( NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)