diff --git a/README.md b/README.md
index 755d0c505..f4b804c5f 100644
--- a/README.md
+++ b/README.md
@@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
-| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
| [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
| [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
| [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |
diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet
index ac71ee528..ee682ffca 100644
--- a/docs/snippets/5_supported-tasks.snippet
+++ b/docs/snippets/5_supported-tasks.snippet
@@ -5,7 +5,6 @@
| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
-| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
| [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
| [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
| [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |
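Note: the `conversational` task row is removed from both tables because chat-style inputs are now handled directly by the `text-generation` pipeline (see `src/pipelines.js` below). A minimal migration sketch — the model name here is an illustrative assumption; any model with a chat template applies:

```js
import { pipeline } from '@xenova/transformers';

// Previously covered by the (unsupported) 'conversational' task;
// now: pass an array of messages straight to 'text-generation'.
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
const output = await generator([
    { role: 'user', content: 'Tell me a joke.' },
], { max_new_tokens: 64 });
```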
diff --git a/src/pipelines.js b/src/pipelines.js
index 3b15af53d..421053b9c 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -841,18 +841,24 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC
    }
}

+function isChat(x) {
+    // Every element must be an object with `role` and `content` keys; the typeof guard avoids a TypeError from using `in` on a primitive
+    return Array.isArray(x) && x.every(m => m !== null && typeof m === 'object' && 'role' in m && 'content' in m);
+}
+
/**
+ * @typedef {import('./tokenizers.js').Message[]} Chat
+ *
* @typedef {Object} TextGenerationSingle
- * @property {string} generated_text The generated text.
+ * @property {string|Chat} generated_text The generated text.
* @typedef {TextGenerationSingle[]} TextGenerationOutput
*
* @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
* @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
+ * @property {boolean} [return_full_text=true] If set to `false`, only the newly added text is returned; otherwise, the full text (prompt and completion) is returned.
* @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig
*
* @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
- * @param {string|string[]} texts One or several prompts (or one list of prompts) to complete.
+ * @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
* @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise<TextGenerationOutput|TextGenerationOutput[]>} An array or object containing the generated texts.
*
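For reference, the new `Chat` type is simply an array of `Message` objects, and the `isChat` guard accepts exactly that shape. A quick sketch of the inputs the pipeline now takes:

```js
// A single Chat: an array of { role, content } messages
const chat = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Tell me a joke.' },
];

// A batch of Chats: an array of such arrays
const chats = [
    chat,
    [{ role: 'user', content: 'Hi there!' }],
];
```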
@@ -921,17 +927,46 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
    /** @type {TextGenerationPipelineCallback} */
    async _call(texts, generate_kwargs = {}) {
-        const isBatched = Array.isArray(texts);
-        if (!isBatched) {
-            texts = [/** @type {string}*/ (texts)];
+        let isBatched = false;
+        let isChatInput = false;
+
+        // Normalize inputs
+        /** @type {string[]} */
+        let inputs;
+        if (typeof texts === 'string') {
+            inputs = texts = [texts];
+        } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) {
+            isBatched = true;
+            inputs = /** @type {string[]} */(texts);
+        } else {
+            if (isChat(texts)) {
+                texts = [/** @type {Chat} */(texts)];
+            } else if (Array.isArray(texts) && texts.every(isChat)) {
+                isBatched = true;
+            } else {
+                throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats');
+            }
+            isChatInput = true;
+            // The input is a chat, so apply the tokenizer's chat template to flatten each conversation into a single prompt string
+            inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map(
+                x => this.tokenizer.apply_chat_template(x, {
+                    tokenize: false,
+                    add_generation_prompt: true,
+                })
+            ));
        }
        // By default, do not add special tokens
        const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
+        // By default, return the full text; for chat inputs, only the newly generated reply is kept (and appended to the conversation below)
+        const return_full_text = isChatInput
+            ? false
+            : generate_kwargs.return_full_text ?? true;
+
        this.tokenizer.padding_side = 'left';
-        const { input_ids, attention_mask } = this.tokenizer(texts, {
+        const { input_ids, attention_mask } = this.tokenizer(inputs, {
            add_special_tokens,
            padding: true,
            truncation: true,
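What `apply_chat_template` returns depends entirely on the tokenizer's template. As a rough illustration only (the markers below follow a Zephyr-style template and are not produced by every model):

```js
const prompt = tokenizer.apply_chat_template(chat, {
    tokenize: false,             // return a string rather than token ids
    add_generation_prompt: true, // append the assistant marker so the model continues as the assistant
});
// e.g. '<|system|>\nYou are a helpful assistant.</s>\n<|user|>\nTell me a joke.</s>\n<|assistant|>\n'
```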
@@ -941,17 +976,34 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
            inputs_attention_mask: attention_mask
        });

-        const decoded = this.tokenizer.batch_decode(outputTokenIds, {
+        let decoded = this.tokenizer.batch_decode(outputTokenIds, {
            skip_special_tokens: true,
        });
+
+        let promptLengths; // decoded prompt lengths, used to strip the prompt prefix when `return_full_text` is false
+        if (!return_full_text && input_ids.dims.at(-1) > 0) {
+            promptLengths = this.tokenizer.batch_decode(input_ids, {
+                skip_special_tokens: true,
+            }).map(x => x.length);
+        }
+
        /** @type {TextGenerationOutput[]} */
        const toReturn = Array.from({ length: texts.length }, _ => []);
        for (let i = 0; i < decoded.length; ++i) {
            const textIndex = Math.floor(i / outputTokenIds.length * texts.length); // map each output back to its input (num_return_sequences outputs per input)
+            if (promptLengths) {
+                // Trim the decoded text so that only the newly generated part remains
+                decoded[i] = decoded[i].slice(promptLengths[textIndex]);
+            }
            toReturn[textIndex].push({
-                generated_text: decoded[i]
+                generated_text: isChatInput
+                    ? [
+                        ...((/** @type {Chat[]} */(texts)[textIndex])),
+                        { role: 'assistant', content: decoded[i] },
+                    ]
+                    : decoded[i]
            });
        }
        return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn;
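Note that the `return_full_text=false` path trims by decoded string length rather than token count; this relies on the decoded output starting with the decoded prompt as an exact string prefix. A sketch of the idea (values schematic):

```js
// batch_decode(input_ids)[i] and batch_decode(outputTokenIds)[i], schematically:
const decodedPrompt = 'Once upon a time,';
const decodedOutput = 'Once upon a time, there was a girl';
const generatedOnly = decodedOutput.slice(decodedPrompt.length); // ' there was a girl'
```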
diff --git a/src/tokenizers.js b/src/tokenizers.js
index e671c8318..ca0c2ab6c 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -2429,6 +2429,12 @@ function truncateHelper(item, length) {
}
+/**
+ * @typedef {Object} Message
+ * @property {string} role The role of the message (e.g., "user", "assistant", or "system").
+ * @property {string} content The content of the message.
+ */
+
export class PreTrainedTokenizer extends Callable {
return_token_type_ids = false;
@@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable {
return this._default_chat_template;
}
- /**
- * @typedef {Object} Message
- * @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
- * @property {string} content The content of the message.
- */
-
/**
* Converts a list of message objects with `"role"` and `"content"` keys to a list of token
* ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
diff --git a/tests/generation.test.js b/tests/generation.test.js
index eb6b87f49..da50388aa 100644
--- a/tests/generation.test.js
+++ b/tests/generation.test.js
@@ -11,6 +11,8 @@ describe('Generation parameters', () => {
const models = [
'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
'MBZUAI/LaMini-GPT-124M', // decoder-only
+
+ 'Xenova/llama2.c-stories15M', // decoder-only
];
// encoder-decoder model
@@ -135,4 +137,37 @@ describe('Generation parameters', () => {
}, MAX_TEST_EXECUTION_TIME);
+    // decoder-only model
+    it(models[2], async () => {
+        const MAX_NEW_TOKENS = 1;
+
+        const text = [
+            'Once upon a time,',
+            'Lily',
+            'Suddenly,',
+        ];
+
+        const generator = await pipeline('text-generation', m(models[2]));
+
+        { // return_full_text=false
+            const output = await generator(text, {
+                return_full_text: false,
+                max_new_tokens: MAX_NEW_TOKENS,
+                num_beams: 2,
+                num_return_sequences: 2,
+            });
+
+            // Token counts of each generated continuation (the prompt is excluded when return_full_text is false)
+            const generatedTokenCounts = output.flatMap(
+                x => x.flatMap(
+                    y => generator.tokenizer.encode(y.generated_text.trim(), null, {
+                        add_special_tokens: false,
+                    }).length
+                )
+            );
+
+            expect(generatedTokenCounts.every(x => x === MAX_NEW_TOKENS)).toBe(true);
+        }
+
+        await generator.dispose();
+    }, MAX_TEST_EXECUTION_TIME);
+
});
\ No newline at end of file
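Taken together, the pipeline now supports both call styles. A usage sketch mirroring the test model above (outputs shown schematically; actual text will vary):

```js
import { pipeline } from '@xenova/transformers';

const generator = await pipeline('text-generation', 'Xenova/llama2.c-stories15M');

// Default (return_full_text: true): prompt + continuation
const [full] = await generator('Once upon a time,', { max_new_tokens: 5 });
// full.generated_text starts with 'Once upon a time,'

// return_full_text: false → continuation only
const [cont] = await generator('Once upon a time,', {
    max_new_tokens: 5,
    return_full_text: false,
});
// cont.generated_text contains only the newly generated words

// Chat input (assumes a model with a chat template; otherwise the tokenizer's
// default template is used): the whole conversation is returned, with the
// assistant's reply appended as the final message.
const [reply] = await generator([{ role: 'user', content: 'Tell me a story.' }]);
```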