Skip to content

Commit

Permalink
Generate: improve docstrings for custom stopping criteria (#26863)
Browse files Browse the repository at this point in the history
improve docstrings
  • Loading branch information
gante authored Oct 18, 2023
1 parent ef42cb6 commit e893b1e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
kwargs (`Dict[str, Any]`, *optional*):
Additional stopping criteria specific kwargs.
Expand All @@ -34,7 +35,11 @@


class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
"""Abstract base class for all stopping criteria that can be applied during generation.
If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
output_scores=True` to `generate`.
"""

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,9 @@ def generate(
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
Expand Down

0 comments on commit e893b1e

Please sign in to comment.