Skip to content

Commit

Permalink
Fix text2text generation pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Jun 6, 2024
1 parent 31f9d8c commit f4247ea
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,6 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP

/** @type {Text2TextGenerationPipelineCallback} */
async _call(texts, generate_kwargs = {}) {
throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented
if (!Array.isArray(texts)) {
texts = [texts];
}
Expand All @@ -706,19 +705,18 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP
padding: true,
truncation: true,
}
let input_ids;
let inputs;
if (this instanceof TranslationPipeline && '_build_translation_inputs' in tokenizer) {
// TODO: move to Translation pipeline?
// Currently put here to avoid code duplication
// @ts-ignore
input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs).input_ids;
inputs = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs);

} else {
input_ids = tokenizer(texts, tokenizer_options).input_ids;
inputs = tokenizer(texts, tokenizer_options);
}

const outputTokenIds = await this.model.generate({ inputs: input_ids, ...generate_kwargs });

const outputTokenIds = await this.model.generate({ ...inputs, ...generate_kwargs });
return tokenizer.batch_decode(/** @type {Tensor} */(outputTokenIds), {
skip_special_tokens: true,
}).map(text => ({ [this._key]: text }));
Expand Down

0 comments on commit f4247ea

Please sign in to comment.