From 400574883a8d1bf4ae2608cc707d719706601ee3 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 22 Apr 2024 14:01:30 +0100 Subject: [PATCH] Update docstrings for text generation pipeline (#30343) * Update docstrings for text generation pipeline * Fix docstring arg * Update docstring to explain chat mode * Fix doctests * Fix doctests --- src/transformers/pipelines/text_generation.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 0b358291717ee0..2f1ad71c781e4a 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -37,10 +37,11 @@ def __init__(self, messages: Dict): class TextGenerationPipeline(Pipeline): """ Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a - specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts, - where each dict contains "role" and "content" keys. + specified text prompt. When the underlying model is a conversational model, it can also accept one or more chats, + in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s). + Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys. - Example: + Examples: ```python >>> from transformers import pipeline @@ -53,6 +54,15 @@ class TextGenerationPipeline(Pipeline): >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False) ``` + ```python + >>> from transformers import pipeline + + >>> generator = pipeline(model="HuggingFaceH4/zephyr-7b-beta") + >>> # Zephyr-beta is a conversational model, so let's pass it a chat instead of a single string + >>> generator([{"role": "user", "content": "What is the capital of France? Answer in one word."}], do_sample=False, max_new_tokens=2) + [{'generated_text': [{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'Paris'}]}] + ``` + Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about text generation parameters in [Text generation strategies](../generation_strategies) and [Text @@ -62,8 +72,9 @@ class TextGenerationPipeline(Pipeline): `"text-generation"`. The models that this pipeline can use are models that have been trained with an autoregressive language modeling - objective, which includes the uni-directional models in the library (e.g. openai-community/gpt2). See the list of available models - on [huggingface.co/models](https://huggingface.co/models?filter=text-generation). + objective. See the list of available [text completion models](https://huggingface.co/models?filter=text-generation) + and the list of [conversational models](https://huggingface.co/models?other=conversational) + on [huggingface.co/models]. """ # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia @@ -194,8 +205,11 @@ def __call__(self, text_inputs, **kwargs): Complete the prompt(s) given as inputs. Args: - text_inputs (`str` or `List[str]`): - One or several prompts (or one list of prompts) to complete. + text_inputs (`str`, `List[str]`, List[Dict[str, str]], or `List[List[Dict[str, str]]]`): + One or several prompts (or one list of prompts) to complete. If strings or a list of string are + passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list + of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed, + the model's chat template will be used to format them before passing them to the model. return_tensors (`bool`, *optional*, defaults to `False`): Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to `True`, the decoded text is not returned. @@ -222,7 +236,7 @@ def __call__(self, text_inputs, **kwargs): corresponding to your framework [here](./model#generative-models)). Return: - A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination + A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination of both `generated_text` and `generated_token_ids`): - **generated_text** (`str`, present when `return_text=True`) -- The generated text.