Skip to content

Commit

Permalink
Fix Document QA pipeline (#987)
Browse files Browse the repository at this point in the history
* Fix Document QA pipeline

* Add `DonutImageProcessor`

* Update unit tests
  • Loading branch information
xenova authored Oct 24, 2024
1 parent e8c0f77 commit cf0c9c1
Show file tree
Hide file tree
Showing 4 changed files with 754 additions and 520 deletions.
55 changes: 30 additions & 25 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ function replaceTensors(obj) {

/**
* Converts an array or Tensor of integers to an int64 Tensor.
* @param {Array|Tensor} items The input integers to be converted.
* @param {any[]|Tensor} items The input integers to be converted.
* @returns {Tensor} The int64 Tensor with the converted values.
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
* @private
Expand Down Expand Up @@ -1334,35 +1334,37 @@ export class PreTrainedModel extends Callable {
let { decoder_input_ids, ...model_inputs } = model_kwargs;

// Prepare input ids if the user has not defined `decoder_input_ids` manually.
if (!decoder_input_ids) {
decoder_start_token_id ??= bos_token_id;

if (this.config.model_type === 'musicgen') {
// Custom logic (TODO: move to Musicgen class)
decoder_input_ids = Array.from({
length: batch_size * this.config.decoder.num_codebooks
}, () => [decoder_start_token_id]);

} else if (Array.isArray(decoder_start_token_id)) {
if (decoder_start_token_id.length !== batch_size) {
throw new Error(
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
)
if (!(decoder_input_ids instanceof Tensor)) {
if (!decoder_input_ids) {
decoder_start_token_id ??= bos_token_id;

if (this.config.model_type === 'musicgen') {
// Custom logic (TODO: move to Musicgen class)
decoder_input_ids = Array.from({
length: batch_size * this.config.decoder.num_codebooks
}, () => [decoder_start_token_id]);

} else if (Array.isArray(decoder_start_token_id)) {
if (decoder_start_token_id.length !== batch_size) {
throw new Error(
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
)
}
decoder_input_ids = decoder_start_token_id;
} else {
decoder_input_ids = Array.from({
length: batch_size,
}, () => [decoder_start_token_id]);
}
decoder_input_ids = decoder_start_token_id;
} else {
} else if (!Array.isArray(decoder_input_ids[0])) {
// Correct batch size
decoder_input_ids = Array.from({
length: batch_size,
}, () => [decoder_start_token_id]);
}, () => decoder_input_ids);
}
} else if (!Array.isArray(decoder_input_ids[0])) {
// Correct batch size
decoder_input_ids = Array.from({
length: batch_size,
}, () => decoder_input_ids);
decoder_input_ids = toI64Tensor(decoder_input_ids);
}

decoder_input_ids = toI64Tensor(decoder_input_ids);
model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids);

return { input_ids: decoder_input_ids, model_inputs };
Expand Down Expand Up @@ -3185,8 +3187,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
export class VisionEncoderDecoderModel extends PreTrainedModel {
main_input_name = 'pixel_values';
forward_params = [
// Encoder inputs
'pixel_values',
'input_ids',

// Decoder inpputs
'decoder_input_ids',
'encoder_hidden_states',
'past_key_values',
];
Expand Down
1 change: 0 additions & 1 deletion src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -2566,7 +2566,6 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:

/** @type {DocumentQuestionAnsweringPipelineCallback} */
async _call(image, question, generate_kwargs = {}) {
throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented

// NOTE: For now, we only support a batch size of 1

Expand Down
2 changes: 2 additions & 0 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,7 @@ export class DonutFeatureExtractor extends ImageFeatureExtractor {
});
}
}
export class DonutImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor
export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor

/**
Expand Down Expand Up @@ -2569,6 +2570,7 @@ export class AutoProcessor {
MaskFormerFeatureExtractor,
YolosFeatureExtractor,
DonutFeatureExtractor,
DonutImageProcessor,
NougatImageProcessor,
EfficientNetImageProcessor,

Expand Down
Loading

0 comments on commit cf0c9c1

Please sign in to comment.