diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 8929bacd84a12b..18764ac94d9129 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -23,7 +23,8 @@ [What are input IDs?](../glossary#input-ids) scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax - or scores for each vocabulary token after SoftMax. + or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input, + make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. kwargs (`Dict[str, Any]`, *optional*): Additional stopping criteria specific kwargs. @@ -34,7 +35,11 @@ class StoppingCriteria(ABC): - """Abstract base class for all stopping criteria that can be applied during generation.""" + """Abstract base class for all stopping criteria that can be applied during generation. + + If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True, + output_scores=True` to `generate`. + """ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 606fbbe7060f93..1c412f8185dc34 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1397,7 +1397,9 @@ def generate( stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. + generation config an error is thrown. If your stopping criteria depends on the `scores` input, make + sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is + intended for advanced users. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and