Skip to content

Commit

Permalink
Update docstrings for text generation pipeline (#30343)
Browse files Browse the repository at this point in the history
* Update docstrings for text generation pipeline

* Fix docstring arg

* Update docstring to explain chat mode

* Fix doctests

* Fix doctests
  • Loading branch information
Rocketknight1 authored Apr 22, 2024
1 parent 2d92db8 commit 0e9d44d
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions src/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def __init__(self, messages: Dict):
class TextGenerationPipeline(Pipeline):
"""
Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts,
where each dict contains "role" and "content" keys.
specified text prompt. When the underlying model is a conversational model, it can also accept one or more chats,
in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
Example:
Examples:
```python
>>> from transformers import pipeline
Expand All @@ -53,6 +54,15 @@ class TextGenerationPipeline(Pipeline):
>>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
```
```python
>>> from transformers import pipeline
>>> generator = pipeline(model="HuggingFaceH4/zephyr-7b-beta")
>>> # Zephyr-beta is a conversational model, so let's pass it a chat instead of a single string
>>> generator([{"role": "user", "content": "What is the capital of France? Answer in one word."}], do_sample=False, max_new_tokens=2)
[{'generated_text': [{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'Paris'}]}]
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
text generation parameters in [Text generation strategies](../generation_strategies) and [Text
Expand All @@ -62,8 +72,9 @@ class TextGenerationPipeline(Pipeline):
`"text-generation"`.
The models that this pipeline can use are models that have been trained with an autoregressive language modeling
objective, which includes the uni-directional models in the library (e.g. openai-community/gpt2). See the list of available models
on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
objective. See the list of available [text completion models](https://huggingface.co/models?filter=text-generation)
and the list of [conversational models](https://huggingface.co/models?other=conversational)
on [huggingface.co/models].
"""

# Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
Expand Down Expand Up @@ -194,8 +205,11 @@ def __call__(self, text_inputs, **kwargs):
Complete the prompt(s) given as inputs.
Args:
text_inputs (`str` or `List[str]`):
One or several prompts (or one list of prompts) to complete.
text_inputs (`str`, `List[str]`, List[Dict[str, str]], or `List[List[Dict[str, str]]]`):
One or several prompts (or one list of prompts) to complete. If strings or a list of string are
passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list
of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed,
the model's chat template will be used to format them before passing them to the model.
return_tensors (`bool`, *optional*, defaults to `False`):
Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
`True`, the decoded text is not returned.
Expand All @@ -222,7 +236,7 @@ def __call__(self, text_inputs, **kwargs):
corresponding to your framework [here](./model#generative-models)).
Return:
A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination
A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination
of both `generated_text` and `generated_token_ids`):
- **generated_text** (`str`, present when `return_text=True`) -- The generated text.
Expand Down

0 comments on commit 0e9d44d

Please sign in to comment.