From f21d18538697e20085ebce0a5e2dd9edb022e00d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 23 Oct 2024 18:28:19 +0000 Subject: [PATCH] Fix Document QA pipeline --- src/models.js | 55 ++++++++++++++++++++++++++---------------------- src/pipelines.js | 1 - 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/models.js b/src/models.js index b7d2b0ee2..d357a83e4 100644 --- a/src/models.js +++ b/src/models.js @@ -411,7 +411,7 @@ function replaceTensors(obj) { /** * Converts an array or Tensor of integers to an int64 Tensor. - * @param {Array|Tensor} items The input integers to be converted. + * @param {any[]|Tensor} items The input integers to be converted. * @returns {Tensor} The int64 Tensor with the converted values. * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length. * @private @@ -1334,35 +1334,37 @@ export class PreTrainedModel extends Callable { let { decoder_input_ids, ...model_inputs } = model_kwargs; // Prepare input ids if the user has not defined `decoder_input_ids` manually. 
- if (!decoder_input_ids) { - decoder_start_token_id ??= bos_token_id; - - if (this.config.model_type === 'musicgen') { - // Custom logic (TODO: move to Musicgen class) - decoder_input_ids = Array.from({ - length: batch_size * this.config.decoder.num_codebooks - }, () => [decoder_start_token_id]); - - } else if (Array.isArray(decoder_start_token_id)) { - if (decoder_start_token_id.length !== batch_size) { - throw new Error( - `\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}` - ) + if (!(decoder_input_ids instanceof Tensor)) { + if (!decoder_input_ids) { + decoder_start_token_id ??= bos_token_id; + + if (this.config.model_type === 'musicgen') { + // Custom logic (TODO: move to Musicgen class) + decoder_input_ids = Array.from({ + length: batch_size * this.config.decoder.num_codebooks + }, () => [decoder_start_token_id]); + + } else if (Array.isArray(decoder_start_token_id)) { + if (decoder_start_token_id.length !== batch_size) { + throw new Error( + `\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}` + ) + } + decoder_input_ids = decoder_start_token_id; + } else { + decoder_input_ids = Array.from({ + length: batch_size, + }, () => [decoder_start_token_id]); } - decoder_input_ids = decoder_start_token_id; - } else { + } else if (!Array.isArray(decoder_input_ids[0])) { + // Correct batch size decoder_input_ids = Array.from({ length: batch_size, - }, () => [decoder_start_token_id]); + }, () => decoder_input_ids); } - } else if (!Array.isArray(decoder_input_ids[0])) { - // Correct batch size - decoder_input_ids = Array.from({ - length: batch_size, - }, () => decoder_input_ids); + decoder_input_ids = toI64Tensor(decoder_input_ids); } - decoder_input_ids = toI64Tensor(decoder_input_ids); model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids); return { input_ids: decoder_input_ids, model_inputs }; @@ -3185,8 +3187,11 @@ export class 
WhisperForConditionalGeneration extends WhisperPreTrainedModel { export class VisionEncoderDecoderModel extends PreTrainedModel { main_input_name = 'pixel_values'; forward_params = [ + // Encoder inputs 'pixel_values', - 'input_ids', + + // Decoder inputs + 'decoder_input_ids', 'encoder_hidden_states', 'past_key_values', ]; diff --git a/src/pipelines.js b/src/pipelines.js index d955803e6..3b7373cf9 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -2566,7 +2566,6 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: /** @type {DocumentQuestionAnsweringPipelineCallback} */ async _call(image, question, generate_kwargs = {}) { - throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented // NOTE: For now, we only support a batch size of 1