diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 53c11ea169a00d..f74cd0fd019937 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1242,56 +1242,20 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more
details.
-
-
- Diverse beam search can be particularly useful in scenarios where a variety of different outputs is desired, rather
- than multiple similar sequences. It allows the model to explore different generation paths and provides a broader
- coverage of possible outputs.
-
-
-
-
-
- This logits processor can be resource-intensive, especially when using large models or long sequences.
-
-
-
Traditional beam search often generates very similar sequences across different beams.
`HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
beams in the same time step.
- How It Works:
- - **Grouping Beams**: Beams are divided into groups. Each group selects tokens independently of the others.
- - **Penalizing Repeated Tokens**: If a beam in a group selects a token already chosen by another group in the
- same step, a penalty is applied to that token's score.
- - **Promoting Diversity**: This penalty discourages beams within a group from selecting the same tokens as
- beams in other groups.
-
- Benefits:
- - **Diverse Outputs**: Produces a variety of different sequences.
- - **Exploration**: Allows the model to explore different paths.
-
Args:
diversity_penalty (`float`):
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
- particular time. Note that `diversity_penalty` is only effective if group beam search is enabled. The
- penalty applied to a beam's score when it generates a token that has already been chosen by another beam
- within the same group during the same time step. A higher `diversity_penalty` will enforce greater
- diversity among the beams, making it less likely for multiple beams to choose the same token. Conversely, a
- lower penalty will allow beams to more freely choose similar tokens. Adjusting this value can help strike a
- balance between diversity and natural likelihood.
+ particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting
+ this value can help strike a balance between diversity and natural likelihood.
num_beams (`int`):
- Number of beams used for group beam search. Beam search is a method used that maintains beams (or "multiple
- hypotheses") at each step, expanding each one and keeping the top-scoring sequences. A higher `num_beams`
- will explore more potential sequences. This can increase chances of finding a high-quality output but also
- increases computational cost.
+ Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
- Each group of beams will operate independently, selecting tokens without considering the choices of other
- groups. This division promotes diversity by ensuring that beams within different groups explore different
- paths. For instance, if `num_beams` is 6 and `num_beam_groups` is 2, there will be 2 groups each containing
- 3 beams. The choice of `num_beam_groups` should be made considering the desired level of output diversity
- and the total number of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
Examples:
@@ -1304,7 +1268,13 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> # A long text about the solar system
- >>> text = "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant interstellar molecular cloud."
+ >>> text = (
+ ... "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
+ ... "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
+ ... "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
+ ... "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
+ ... "interstellar molecular cloud."
+ ... )
>>> inputs = tokenizer("summarize: " + text, return_tensors="pt")
>>> # Generate diverse summary
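The `diversity_penalty` subtraction described in the Args above boils down to a per-step frequency penalty across beam groups. A minimal sketch of that idea for a single batch item (an illustration based on the description, not the library's exact implementation; the function name is made up):

```python
import torch


def apply_hamming_diversity(scores, previous_group_tokens, diversity_penalty):
    # scores: (beams_in_group, vocab_size) logits of the current group at this step
    # previous_group_tokens: 1D tensor of tokens already picked by earlier groups at this step
    vocab_size = scores.shape[-1]
    # count how often each token was already chosen by the other groups
    token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size)
    # subtract the penalty once per prior occurrence, discouraging cross-group repeats
    return scores - diversity_penalty * token_frequency
```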
@@ -1399,11 +1369,34 @@ def __call__(
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
r"""
- [`LogitsProcessor`] that enforces the specified token as the first generated token.
+ [`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder
+ models.
Args:
bos_token_id (`int`):
The id of the token to force as the first generated token.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+
+ >>> inputs = tokenizer("Translate from English to German: I love cats.", return_tensors="pt")
+
+ >>> # By default, it continues generating according to the model's logits
+ >>> outputs = model.generate(**inputs, max_new_tokens=10)
+ >>> print(tokenizer.batch_decode(outputs)[0])
+ Ich liebe Kitty.
+
+ >>> # We can use `forced_bos_token_id` to force the start of generation with an encoder-decoder model
+ >>> # (including forcing it to end straight away with an EOS token)
+ >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_bos_token_id=tokenizer.eos_token_id)
+ >>> print(tokenizer.batch_decode(outputs)[0])
+
+ ```
"""
def __init__(self, bos_token_id: int):
@@ -1429,6 +1422,27 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+
+ >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
+
+ >>> # By default, it continues generating according to the model's logits
+ >>> outputs = model.generate(**inputs, max_new_tokens=10)
+ >>> print(tokenizer.batch_decode(outputs)[0])
+ A sequence: 1, 2, 3, 4, 5, 6, 7, 8
+
+    >>> # `forced_eos_token_id` ensures the generation ends with an EOS token
+ >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_eos_token_id=tokenizer.eos_token_id)
+ >>> print(tokenizer.batch_decode(outputs)[0])
+ A sequence: 1, 2, 3, 4, 5, 6, 7,<|endoftext|>
+ ```
"""
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
@@ -1452,6 +1466,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
r"""
     [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method failing. Note that
     the logits processor should only be used if necessary since it can slow down the generation method.
+
+    This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that
+    warrants its use.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
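Although the docstring deliberately skips a `generate` example, the processor is easy to illustrate with a direct call. A minimal sketch, assuming the class is importable from the top-level `transformers` package like the other processors (the tensor values are made up):

```python
import torch
from transformers import InfNanRemoveLogitsProcessor

# scores containing a `nan` and an `inf`, which would otherwise break sampling
scores = torch.tensor([[1.0, float("nan"), float("inf"), -2.0]])
input_ids = torch.tensor([[0]])  # not used by this processor, but part of the processor call signature

processor = InfNanRemoveLogitsProcessor()
cleaned = processor(input_ids, scores)
# `nan` becomes 0.0 and `inf` is replaced by the largest finite value of the dtype
print(cleaned)
```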
@@ -1563,6 +1580,29 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
the scores are normalized when comparing the hypotheses.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+ >>> import torch
+
+ >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+
+ >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
+
+    >>> # By default, the scores are not normalized -- the exponentials of the scores do NOT sum to 1, i.e. they do
+    >>> # not form a normalized probability distribution
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+ >>> print(torch.sum(torch.exp(outputs.scores[-1])))
+ tensor(816.2668)
+
+    >>> # Normalizing them may have a positive impact on beam methods, or when using the scores in your application
+ >>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
+ >>> print(torch.sum(torch.exp(outputs.scores[-1])))
+ tensor(1.0000)
+ ```
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -1574,8 +1614,36 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
r"""
     [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
- generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not
- sampled at the begining of the generation.
+ generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
+    not generated at the beginning. Originally created for
+ [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+
+ >>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
+    >>> # it can't generate an EOS token in the first iteration, but it can in the others.
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+ >>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
+ tensor(-inf)
+ >>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS
+ tensor(29.9010)
+
+ >>> # If we disable `begin_suppress_tokens`, we can generate EOS in the first iteration.
+ >>> outputs = model.generate(
+ ... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
+ ... )
+ >>> print(outputs.scores[1][0, 50256])
+ tensor(11.2027)
+ ```
"""
def __init__(self, begin_suppress_tokens, begin_index):
@@ -1591,8 +1659,33 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class SuppressTokensLogitsProcessor(LogitsProcessor):
- r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
- are not sampled."""
+ r"""
+ This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
+ that they are not generated. Originally created for
+ [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+
+ >>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default.
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+ >>> print(outputs.scores[1][0, 1]) # 1 (and not 0) is the first freely generated token
+ tensor(-inf)
+
+ >>> # If we disable `suppress_tokens`, we can generate it.
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
+ >>> print(outputs.scores[1][0, 1])
+ tensor(5.7738)
+ ```
+ """
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
@@ -1604,9 +1697,42 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class ForceTokensLogitsProcessor(LogitsProcessor):
- r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
- indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are
- sampled at their corresponding index."""
+ r"""
+    This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
+    indices that will be forced before generation. The processor will set the log probs of all other tokens to
+    `-inf` so that the forced tokens are sampled at their corresponding index. Originally created for
+ [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+
+ >>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
+ >>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+ >>> print(
+ ... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
+ ... )
+ True
+ >>> print(outputs.scores[0][0, 50362])
+ tensor(0.)
+
+ >>> # If we disable `forced_decoder_ids`, we stop seeing that effect
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
+ >>> print(
+ ... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
+ ... )
+ False
+ >>> print(outputs.scores[0][0, 50362])
+ tensor(19.3140)
+ ```
+ """
def __init__(self, force_token_map: List[List[int]]):
self.force_token_map = dict(force_token_map)
@@ -1650,7 +1776,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
Examples:
``` python
>>> import torch
- >>> from transformers import AutoProcessor, WhisperForConditionalGeneration,GenerationConfig
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -1746,18 +1872,42 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
- r"""Logits processor for classifier free guidance (CFG). The scores are split over the batch dimension,
+ r"""
+ [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
See [the paper](https://arxiv.org/abs/2306.05284) for more information.
+
+    This logits processor is exclusively compatible with
+    [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen).
+
Args:
guidance_scale (float):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> inputs = processor(
+ ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+ ... padding=True,
+ ... return_tensors="pt",
+ ... )
+ >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
+ ```
"""
def __init__(self, guidance_scale):
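For intuition, the weighted average described above is the usual classifier-free guidance interpolation. A schematic sketch of the combination step (the function name and shapes are assumptions for illustration, not the class's actual code):

```python
import torch


def cfg_combine(scores: torch.FloatTensor, guidance_scale: float) -> torch.FloatTensor:
    # the batch holds conditional logits in the first half and unconditional ("null" prompt)
    # logits in the second half, as the docstring describes
    cond_logits, uncond_logits = scores.split(scores.shape[0] // 2, dim=0)
    # guidance_scale > 1 pulls the prediction towards the prompt-conditioned logits
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)
```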
@@ -1787,7 +1937,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
r"""
- [`LogitsProcessor`] enforcing alternated generation between the two codebooks of [`Bark`]'s fine submodel.
+    [`LogitsProcessor`] enforcing alternating generation between the two codebooks of Bark.
+
+    This logits processor is exclusively compatible with
+    [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
+    for examples.
+
Args:
input_start_len (`int`):
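To make the alternation concrete: on each step only one codebook's slice of the vocabulary is left unmasked, and the slice switches every step. A rough sketch under assumed vocabulary-layout parameters (`semantic_vocab_size`, `codebook_size`), which may differ from the class's exact signature:

```python
import torch


def mask_for_current_codebook(scores, curr_len, input_start_len, semantic_vocab_size, codebook_size):
    # even offsets from the start of generation use the first codebook, odd offsets the second
    use_first_codebook = (curr_len - input_start_len) % 2 == 0
    masked = scores.clone()
    if use_first_codebook:
        masked[:, :semantic_vocab_size] = -float("inf")
        masked[:, semantic_vocab_size + codebook_size :] = -float("inf")
    else:
        masked[:, : semantic_vocab_size + codebook_size] = -float("inf")
    return masked
```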
@@ -1822,10 +1980,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
- r"""Logits processor for Classifier-Free Guidance (CFG). The processors
- computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
- parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
- the `unconditional_ids` branch.
+ r"""
+    Logits processor for Classifier-Free Guidance (CFG). The processor computes a weighted average across scores
+ from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
+ The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
@@ -1942,6 +2100,13 @@ def __call__(self, input_ids, scores):
class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.
+
+    This logits processor is exclusively compatible with
+    [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.
+
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
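A minimal sketch of the prioritization rule (illustrative only, with assumed names and shapes): whenever the softmax probability of the EOS token exceeds `min_eos_p`, every other token is masked out so that EOS is selected.

```python
import torch


def prioritize_eos(scores, eos_token_id, min_eos_p):
    # probability of EOS under the current scores
    eos_prob = torch.nn.functional.softmax(scores, dim=-1)[:, eos_token_id]
    # a copy of the scores where everything except EOS is masked out
    eos_only = torch.full_like(scores, -float("inf"))
    eos_only[:, eos_token_id] = scores[:, eos_token_id]
    # rows where EOS is likely enough keep only the EOS logit; other rows are untouched
    return torch.where((eos_prob > min_eos_p).unsqueeze(-1), eos_only, scores)
```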