From ac3c1553a51ee92fe9c450f4a364249db676a8b5 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 16 Jan 2024 15:30:07 +0200 Subject: [PATCH 1/4] Implement `unfold` tensor operator --- src/utils/tensor.js | 77 ++++++++++++++++++++++++++++++++++++++------ tests/tensor.test.js | 66 ++++++++++++++++++++++++++++++++++++- 2 files changed, 132 insertions(+), 11 deletions(-) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 74cb23880..dac53e8c0 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -327,7 +327,7 @@ export class Tensor { * * @param {number} [dim=null] The dimension or dimensions to reduce. If `null`, all dimensions are reduced. * @param {boolean} keepdim Whether the output tensor has `dim` retained or not. - * @returns The summed tensor + * @returns {Tensor} The summed tensor */ sum(dim = null, keepdim = false) { return this.norm(1, dim, keepdim); @@ -458,7 +458,7 @@ export class Tensor { * If you would like a copy, use `tensor.clone()` before squeezing. * * @param {number} [dim=null] If given, the input will be squeezed only in the specified dimensions. - * @returns The squeezed tensor + * @returns {Tensor} The squeezed tensor */ squeeze(dim = null) { return new Tensor( @@ -482,7 +482,7 @@ export class Tensor { * NOTE: The returned tensor shares the same underlying data with this tensor. * * @param {number} dim The index at which to insert the singleton dimension - * @returns The unsqueezed tensor + * @returns {Tensor} The unsqueezed tensor */ unsqueeze(dim = null) { return new Tensor( @@ -521,7 +521,7 @@ export class Tensor { * and ending with `end_dim` are flattened. The order of elements in input is unchanged. * @param {number} start_dim the first dim to flatten * @param {number} end_dim the last dim to flatten - * @returns The flattened tensor. + * @returns {Tensor} The flattened tensor. */ flatten(start_dim = 0, end_dim = -1) { return this.clone().flatten_(start_dim, end_dim); @@ -579,7 +579,7 @@ export class Tensor { * Clamps all elements in input into the range [ min, max ] * @param {number} min lower-bound of the range to be clamped to * @param {number} max upper-bound of the range to be clamped to - * @returns the output tensor. + * @returns {Tensor} the output tensor. */ clamp(min, max) { return this.clone().clamp_(min, max); @@ -597,12 +597,21 @@ export class Tensor { /** * Rounds elements of input to the nearest integer. - * @returns the output tensor. + * @returns {Tensor} the output tensor. */ round() { return this.clone().round_(); } + /** + * Reshape the tensor to the given shape. + * @param {...any} dims The new dimensions of the tensor. + * @returns {Tensor} The reshaped tensor. + */ + reshape(...dims) { + return new Tensor(this.type, this.data, dims); + } + /** * Performs Tensor dtype conversion. * @param {DataType} type The desired data type. @@ -766,7 +775,7 @@ export function mean_pooling(last_hidden_state, attention_mask) { * Helper function to calculate new dimensions when performing a squeeze operation. * @param {number[]} dims The dimensions of the tensor. * @param {number|number[]|null} dim The dimension(s) to squeeze. - * @returns The new dimensions. + * @returns {number[]} The new dimensions. * @private */ function calc_squeeze_dims(dims, dim) { @@ -789,7 +798,7 @@ function calc_squeeze_dims(dims, dim) { * Helper function to calculate new dimensions when performing an unsqueeze operation. * @param {number[]} dims The dimensions of the tensor. * @param {number} dim The dimension to unsqueeze. - * @returns The new dimensions. + * @returns {number[]} The new dimensions. * @private */ function calc_unsqueeze_dims(dims, dim) { @@ -976,7 +985,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) { * @param {Tensor} input the input tensor. * @param {number|null} dim the dimension to reduce. * @param {boolean} keepdim whether the output tensor has dim retained or not. - * @returns A new tensor with means taken along the specified dimension. + * @returns {Tensor} A new tensor with means taken along the specified dimension. */ export function mean(input, dim = null, keepdim = false) { @@ -1150,8 +1159,56 @@ export function ones(size) { /** * Returns a tensor filled with the scalar value 1, with the same size as input. * @param {Tensor} tensor The size of input will determine size of the output tensor. - * @returns The ones tensor. + * @returns {Tensor} The ones tensor. */ export function ones_like(tensor) { return ones(tensor.dims); } + +/** + * + * @param {Tensor} input + * @param {[number, number]} kernel_size + * @param {[number, number]} stride + * @returns {Tensor} + */ +export function unfold(input, kernel_size, stride) { + + const [batchSize, inputChannels, inputHeight, inputWidth] = input.dims; + const [kernelHeight, kernelWidth] = kernel_size; + const [strideHeight, strideWidth] = stride; + + const outputHeight = Math.floor((inputHeight - kernelHeight) / strideHeight) + 1; + const outputWidth = Math.floor((inputWidth - kernelWidth) / strideWidth) + 1; + const outputChannels = kernelHeight * kernelWidth * inputChannels; + + const newDims = [batchSize, outputChannels, outputHeight * outputWidth]; + + // @ts-ignore + const unfoldedData = new input.data.constructor(newDims.reduce((a, b) => a * b, 1)); + + const v1 = inputHeight * inputWidth; + const v2 = inputChannels * v1; + const v3 = strideHeight * inputWidth; + + let outputIndex = 0; + for (let b = 0; b < batchSize; ++b) { + const o1 = b * v2; + for (let c = 0; c < inputChannels; ++c) { + const o2 = c * v1 + o1; + for (let kh = 0; kh < kernelHeight; ++kh) { + const o3 = kh * inputWidth + o2; + for (let kw = 0; kw < kernelWidth; ++kw) { + const o4 = o3 + kw; + for (let i = 0; i < outputHeight; ++i) { + const o5 = i * v3 + o4; + for (let j = 0; j < outputWidth; ++j) { + unfoldedData[outputIndex++] = input.data[o5 + j * strideWidth]; + } + } + } + } + } + } + return new Tensor(input.type, unfoldedData, newDims); +} diff --git a/tests/tensor.test.js b/tests/tensor.test.js index 93d9fd0ad..59d1c9b95 100644 --- a/tests/tensor.test.js +++ b/tests/tensor.test.js @@ -1,7 +1,7 @@ import { Tensor } from '../src/transformers.js'; import { compare } from './test_utils.js'; -import { cat, mean, stack } from '../src/utils/tensor.js'; +import { cat, mean, stack, unfold } from '../src/utils/tensor.js'; describe('Tensor operations', () => { @@ -128,4 +128,68 @@ describe('Tensor operations', () => { }) }); + + describe('unfold', () => { + it('should unfold', async () => { + { // batch_size=1 + const dims = [1, 2, 4, 6]; + const data = new Float32Array(dims.reduce((a, b) => a * b, 1)).map((_, i) => i); + + const unfolded = unfold( + new Tensor('float32', data, dims), + [2, 2], // kernel_size + [2, 2], // stride + ) + + const target = new Tensor('float32', [ + 0, 2, 4, 12, 14, 16, + 1, 3, 5, 13, 15, 17, + 6, 8, 10, 18, 20, 22, + 7, 9, 11, 19, 21, 23, + 24, 26, 28, 36, 38, 40, + 25, 27, 29, 37, 39, 41, + 30, 32, 34, 42, 44, 46, + 31, 33, 35, 43, 45, 47 + ], [1, 8, 6]); + + compare(unfolded, target, 1e-3); + } + + { // batch_size > 1 + const dims = [2, 2, 4, 6]; + const data = new Float32Array(dims.reduce((a, b) => a * b, 1)).map((_, i) => i); + + const unfolded = unfold( + new Tensor('float32', data, dims), + [2, 2], // kernel_size + [3, 3], // stride + ) + + const target = new Tensor('float32', [ + 0, 3, 1, 4, 6, 9, 7, 10, 24, 27, 25, 28, 30, 33, 31, 34, + 48, 51, 49, 52, 54, 57, 55, 58, 72, 75, 73, 76, 78, 81, 79, 82 + ], [2, 8, 2]); + + compare(unfolded, target, 1e-3); + } + + { // mismatched kernel_size and stride, batch_size > 1 + const dims = [2, 2, 4, 6]; + const data = new Float32Array(dims.reduce((a, b) => a * b, 1)).map((_, i) => i); + + const unfolded = unfold( + new Tensor('float32', data, dims), + [2, 3], // kernel_size + [3, 2], // stride + ) + + const target = new Tensor('float32', [ + 0, 2, 1, 3, 2, 4, 6, 8, 7, 9, 8, 10, 24, 26, 25, 27, 26, 28, 30, 32, 31, 33, + 32, 34, 48, 50, 49, 51, 50, 52, 54, 56, 55, 57, 56, 58, 72, 74, 73, 75, 74, 76, 78, 80, 79, 81, 80, 82 + ], [2, 12, 2]); + + compare(unfolded, target, 1e-3); + } + }); + }) }); From 1f8dc764f6607c80111cd0854cd31895e830c76a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 16 Jan 2024 18:09:17 +0200 Subject: [PATCH 2/4] Add support for Pix2Struct models --- README.md | 1 + docs/snippets/6_supported-models.snippet | 1 + scripts/supported_models.py | 20 ++ src/models.js | 77 ++++++- src/pipelines.js | 23 +- src/processors.js | 270 +++++++++++++++++++---- 6 files changed, 345 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index d0f9bec49..af52643da 100644 --- a/README.md +++ b/README.md @@ -324,6 +324,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee. +1. **[Pix2Struct](https://huggingface.co/docs/transformers/model_doc/pix2struct)** (from Google) released with the paper [Pix2Struct: Screenshot Parsing as Pretraining for Visual Language Understanding](https://arxiv.org/abs/2210.03347) by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu, Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova. 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index d2054b93c..6336622ce 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -59,6 +59,7 @@ 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee. +1. **[Pix2Struct](https://huggingface.co/docs/transformers/model_doc/pix2struct)** (from Google) released with the paper [Pix2Struct: Screenshot Parsing as Pretraining for Visual Language Understanding](https://arxiv.org/abs/2210.03347) by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu, Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova. 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. diff --git a/scripts/supported_models.py b/scripts/supported_models.py index 0415c8cd3..bba0b5c72 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -722,6 +722,26 @@ 'susnato/phi-1_5_dev', ], }, + 'pix2struct': { + # Image-to-text + 'image-to-text': [ + 'fxmarty/pix2struct-tiny-random', + 'google/pix2struct-textcaps-base', + ], + + # Visual Question Answering (VQA) + 'visual-question-answering': [ + 'google/deplot', + + # TODO: + # 'google/pix2struct-docvqa-base', + # 'google/pix2struct-widget-captioning-base', + # 'google/pix2struct-ai2d-base', + # 'google/pix2struct-chartqa-base', + # 'google/pix2struct-screen2words-base', + # 'google/pix2struct-infographics-vqa-base', + ], + }, 'roberta': { # Feature extraction 'feature-extraction': [ diff --git a/src/models.js b/src/models.js index ae56ea952..b72d75b55 100644 --- a/src/models.js +++ b/src/models.js @@ -374,6 +374,15 @@ async function seq2seqForward(self, model_inputs) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } + if (self.decoder_merged_session.inputNames.includes('decoder_attention_mask')) { + // TODO: When we perform parallelism, we must adjust attention mask depending on + // location of pad token + decoderFeeds.decoder_attention_mask = new Tensor( + 'int64', + new BigInt64Array(model_inputs.decoder_input_ids.data.length).fill(1n), + model_inputs.decoder_input_ids.dims, + ) + } preparePositionIds(self.decoder_merged_session, decoderFeeds, use_cache_branch); self.addPastKeyValues(decoderFeeds, past_key_values); @@ -437,7 +446,9 @@ function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputToke } if (requires_attention_mask) { - start.attention_mask = prepareAttentionMask(self, tokens); + start.attention_mask = + generation_config.attention_mask + ?? prepareAttentionMask(self, tokens); } beams.push(start); @@ -981,7 +992,7 @@ export class PreTrainedModel extends Callable { * @typedef {Object} DecoderOutput * * Generates text based on the given inputs and generation configuration using the model. - * @param {Tensor|Array|TypedArray} inputs An array of input token IDs. + * @param {Tensor|Array|TypedArray|Object} inputs An array of input token IDs. * @param {Object|GenerationConfig|null} generation_config The generation configuration to use. If null, default configuration will be used. * @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created. * @param {Object} options options @@ -1006,8 +1017,8 @@ export class PreTrainedModel extends Callable { MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType) - // ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO - ?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType); + ?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType) + ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType); if (possibleInfo) { // TODO: support multiple possible classes @@ -1017,7 +1028,7 @@ export class PreTrainedModel extends Callable { } if (!(inputs instanceof Tensor) && !isTypedArray(inputs) && !Array.isArray(inputs)) { - throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`); + throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs?.constructor?.name}".`); } let input_ids_seq_length; @@ -3044,6 +3055,61 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { } ////////////////////////////////////////////////// +export class Pix2StructPreTrainedModel extends PreTrainedModel { } + +/** + * A conditional generation model with a language modeling head. Can be used for sequence generation tasks. + */ +export class Pix2StructForConditionalGeneration extends Pix2StructPreTrainedModel { + main_input_name = 'flattened_patches'; + + /** + * Creates a new instance of the `VisionEncoderDecoderModel` class. + * @param {Object} config The configuration object specifying the hyperparameters and other model settings. + * @param {Object} session The ONNX session containing the encoder model. + * @param {any} decoder_merged_session The ONNX session containing the merged decoder model. + * @param {Object} generation_config Configuration object for the generation process. + */ + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + const textConfig = this.config.text_config; + this.num_encoder_layers = this.num_decoder_layers = textConfig.num_layers; + this.num_encoder_heads = this.num_decoder_heads = textConfig.num_heads; + this.encoder_dim_kv = this.decoder_dim_kv = textConfig.d_kv; + } + + /** + * Generates outputs based on input and generation configuration. + * @param {Object} inputs Input data for the model. + * @param {Object} generation_config Configuration object for the generation process. + * @param {Object} logits_processor Optional logits processor object. + * @returns {Promise} Promise object represents the generated outputs. + */ + async generate( + inputs, + generation_config = null, + logits_processor = null, + ) { + const { flattened_patches, attention_mask } = inputs; + + // Create generation config object + generation_config = this._get_generation_config(generation_config); + + // Compute image embeddings + const outputs = await super.generate(flattened_patches, { + ...generation_config, + decoder_input_ids: [this.config.pad_token_id], + attention_mask, + }, logits_processor); + + return outputs + } + +} + ////////////////////////////////////////////////// // CLIP models export class CLIPPreTrainedModel extends PreTrainedModel { } @@ -5326,6 +5392,7 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], + ['pix2struct', ['Pix2StructForConditionalGeneration', Pix2StructForConditionalGeneration]], ]); const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ diff --git a/src/pipelines.js b/src/pipelines.js index 161094270..4dfe3b0d6 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -1743,15 +1743,30 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe const isBatched = Array.isArray(images); const preparedImages = await prepareImages(images); - const { pixel_values } = await this.processor(preparedImages); + const inputs = await this.processor(preparedImages); + + let batchedInputs = []; + + const main_input = inputs[this.model.main_input_name]; + if (this.model.config.model_type === 'pix2struct') { + const batch_size = main_input.dims[0]; + for (let i = 0; i < batch_size; ++i) { + const items = {}; + for (const key in inputs) { + items[key] = inputs[key][i].unsqueeze(0); + } + batchedInputs.push(items); + } + } else { + batchedInputs = main_input.unsqueeze(1); + } const toReturn = []; - for (const batch of pixel_values) { - batch.dims = [1, ...batch.dims] + for (const batch of batchedInputs) { const output = await this.model.generate(batch, generate_kwargs); const decoded = this.tokenizer.batch_decode(output, { skip_special_tokens: true, - }).map(x => ({ generated_text: x.trim() })) + }).map(generated_text => ({ generated_text })) toReturn.push(decoded); } diff --git a/src/processors.js b/src/processors.js index 6165e27be..eae25e8c9 100644 --- a/src/processors.js +++ b/src/processors.js @@ -36,7 +36,7 @@ import { } from './utils/maths.js'; -import { Tensor, transpose, cat, interpolate, stack } from './utils/tensor.js'; +import { Tensor, transpose, cat, interpolate, stack, unfold } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; import { @@ -269,6 +269,24 @@ export class ImageFeatureExtractor extends FeatureExtractor { return await image.resize(width, height, { resample }); } + /** + * Crops the image to the specified size. + * @param {RawImage} image The image to be cropped. + * @returns {Promise} The cropped image. + */ + async center_crop(image) { + let crop_width; + let crop_height; + if (Number.isInteger(this.crop_size)) { + crop_width = this.crop_size; + crop_height = this.crop_size; + } else { + crop_width = this.crop_size.width; + crop_height = this.crop_size.height; + } + + return await image.center_crop(crop_width, crop_height); + } /** * Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the threshold). @@ -407,18 +425,45 @@ export class ImageFeatureExtractor extends FeatureExtractor { } } + /** + * Normalize an image. image = (image - image_mean) / image_std. + * @param {Float32Array} pixelData + * @param {RawImage} image + */ + normalize(pixelData, image) { + let image_mean = this.image_mean; + if (!Array.isArray(this.image_mean)) { + image_mean = new Array(image.channels).fill(image_mean); + } + + let image_std = this.image_std; + if (!Array.isArray(this.image_std)) { + image_std = new Array(image.channels).fill(image_mean); + } + + if (image_mean.length !== image.channels || image_std.length !== image.channels) { + throw new Error(`When set to arrays, the length of \`image_mean\` (${image_mean.length}) and \`image_std\` (${image_std.length}) must match the number of channels in the image (${image.channels}).`); + } + + for (let i = 0; i < pixelData.length; i += image.channels) { + for (let j = 0; j < image.channels; ++j) { + pixelData[i + j] = (pixelData[i + j] - this.image_mean[j]) / this.image_std[j]; + } + } + } + /** * Find the target (width, height) dimension of the output image after * resizing given the input image and the desired size. - * @param {RawImage} image The image to resize. + * @param {[number, number]} inputSize The image size. * @param {any} size The size to use for resizing the image. * @returns {[number, number]} The target (width, height) dimension of the output image after resizing. */ - get_resize_output_image_size(image, size) { + get_resize_output_image_size(inputSize, size) { // `size` comes in many forms, so we need to handle them all here: // 1. `size` is an integer, in which case we resize the image to be a square - const [srcWidth, srcHeight] = image.size; + const [srcWidth, srcHeight] = inputSize; let shortest_edge; let longest_edge; @@ -483,7 +528,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { * @returns {Promise} The resized image. */ async resize(image) { - const [newWidth, newHeight] = this.get_resize_output_image_size(image, this.size); + const [newWidth, newHeight] = this.get_resize_output_image_size(image.size, this.size); return await image.resize(newWidth, newHeight, { resample: this.resample, }); @@ -538,18 +583,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { } if (this.do_center_crop) { - - let crop_width; - let crop_height; - if (Number.isInteger(this.crop_size)) { - crop_width = this.crop_size; - crop_height = this.crop_size; - } else { - crop_width = this.crop_size.width; - crop_height = this.crop_size.height; - } - - image = await image.center_crop(crop_width, crop_height); + image = await this.center_crop(image); } /** @type {HeightWidth} */ @@ -563,25 +597,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { } if (do_normalize ?? this.do_normalize) { - let image_mean = this.image_mean; - if (!Array.isArray(this.image_mean)) { - image_mean = new Array(image.channels).fill(image_mean); - } - - let image_std = this.image_std; - if (!Array.isArray(this.image_std)) { - image_std = new Array(image.channels).fill(image_mean); - } - - if (image_mean.length !== image.channels || image_std.length !== image.channels) { - throw new Error(`When set to arrays, the length of \`image_mean\` (${image_mean.length}) and \`image_std\` (${image_std.length}) must match the number of channels in the image (${image.channels}).`); - } - - for (let i = 0; i < pixelData.length; i += image.channels) { - for (let j = 0; j < image.channels; ++j) { - pixelData[i + j] = (pixelData[i + j] - this.image_mean[j]) / this.image_std[j]; - } - } + this.normalize(pixelData, image); } // do padding after rescaling/normalizing @@ -631,9 +647,186 @@ export class ImageFeatureExtractor extends FeatureExtractor { reshaped_input_sizes: imageData.map(x => x.reshaped_input_size), } } - } +export class Pix2StructImageProcessor extends ImageFeatureExtractor { + constructor(config) { + super(config); + + this.max_patches = config.max_patches; + this.patch_size = config.patch_size; + + this.do_rescale = false; + } + + /** + * The image std is to mimic the tensorflow implementation of the `per_image_standardization`: + * https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization + * @param {Float32Array} pixelData The pixel data to normalize. + */ + normalize(pixelData) { + // Take mean across the whole image + const mean = pixelData.reduce((a, b) => a + b, 0) / pixelData.length; + const std = Math.sqrt(pixelData.reduce((a, b) => a + (b - mean) ** 2, 0) / pixelData.length); + + const adjusted_stddev = Math.max(std, 1.0 / Math.sqrt(pixelData.length)); + + for (let i = 0; i < pixelData.length; ++i) { + pixelData[i] = (pixelData[i] - mean) / adjusted_stddev; + } + } + + extract_patches(image_tensor, patch_height, patch_width) { + + image_tensor = image_tensor.unsqueeze(0); + + let patches = unfold( + image_tensor, // input + [patch_height, patch_width], // kernel_size + [patch_height, patch_width], // stride + ) + + patches = patches.view( + image_tensor.dims.at(0), + image_tensor.dims.at(1), + patch_height, + patch_width, + -1, + ) + + patches = patches.transpose( + 0, 4, 2, 3, 1 + ).view( + Math.floor(image_tensor.dims.at(2) / patch_height), + Math.floor(image_tensor.dims.at(3) / patch_width), + image_tensor.dims.at(1) * patch_height * patch_width, + ) + + return patches.unsqueeze(0); + } + + extract_flattened_patches(pixel_values, max_patches, patch_size) { + + const [image_height, image_width] = pixel_values.dims.slice(-2); + const [patch_width, patch_height] = this.get_resize_output_image_size([image_height, image_width], patch_size); + + // maximize scale s.t. + const scale = Math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width)) + const num_feasible_rows = Math.max(Math.min(Math.floor(scale * image_height / patch_height), max_patches), 1) + const num_feasible_cols = Math.max(Math.min(Math.floor(scale * image_width / patch_width), max_patches), 1) + const resized_height = Math.max(num_feasible_rows * patch_height, 1) + const resized_width = Math.max(num_feasible_cols * patch_width, 1) + + // [ 3, 592, 880 ] + const resized = interpolate(pixel_values, [resized_height, resized_width], 'bilinear', false); + + let patches = this.extract_patches( + resized, patch_height, patch_width + ) + const [ + b, + rows, + columns, + depth, + ] = patches.dims; + + // [rows * columns, patch_height * patch_width * image_channels] + patches = patches.view( + rows * columns, depth, + ) + + const row_ids_data = new Float32Array(rows * columns); + const col_ids_data = new Float32Array(rows * columns); + for (let i = 0; i < row_ids_data.length; ++i) { + // NOTE: Offset by 1 so the ids do not contain zeros, which represent padding. + row_ids_data[i] = 1 + Math.floor(i / columns); + col_ids_data[i] = 1 + i % columns; + } + + const d = [max_patches, 2 + patches.dims.at(-1)]; + + const result = cat([ + new Tensor('float32', row_ids_data, [rows * columns, 1]), + new Tensor('float32', col_ids_data, [rows * columns, 1]), + patches, + ], -1) + + const diff = max_patches - (rows * columns); + const flattened_patches = cat([ + result, + new Tensor('float32', new Float32Array(diff * d[1]), [diff, d[1]]), + ], 0); + + const attention_mask = new Tensor('int64', new BigInt64Array(max_patches), [max_patches]); + /** @type {BigInt64Array} */ (attention_mask.data).fill(1n, 0, rows * columns); + + return { + flattened_patches, + attention_mask, + }; + } + + + /** + * Preprocesses the given image. + * + * @param {RawImage} image The image to preprocess. + * @param {Object} overrides The overrides for the preprocessing options. + * @returns {Promise} The preprocessed image. + */ + async preprocess(image, { + do_normalize = null, + do_pad = null, + do_convert_rgb = null, + do_convert_grayscale = null, + } = {}) { + + const { + original_size, + reshaped_input_size, + pixel_values, + } = await super.preprocess(image, { + do_normalize, + do_pad, + do_convert_rgb, + do_convert_grayscale, + }); + + return this.extract_flattened_patches(pixel_values, this.max_patches, this.patch_size); + + } + + /** + * @typedef {object} Pix2StructExtractorResult + * @property {Tensor} attention_mask The attention mask. + * @property {Tensor} flattened_patches The flattened patches. + */ + + + /** + * Calls the feature extraction process on an array of images, + * preprocesses each image, and concatenates the resulting + * features into a single Tensor. + * @param {RawImage[]} images The image(s) to extract features from. + * @param {...any} args Additional arguments. + * @returns {Promise} + */ + async _call(images, ...args) { + if (!Array.isArray(images)) { + images = [images]; + } + const imageData = await Promise.all(images.map(x => this.preprocess(x))); + + // Stack flattened_patches and attention_mask + const flattened_patches = stack(imageData.map(x => x.flattened_patches), 0); + const attention_mask = stack(imageData.map(x => x.attention_mask), 0); + + return { + flattened_patches, + attention_mask, + } + } +} export class SegformerFeatureExtractor extends ImageFeatureExtractor { /** @@ -725,7 +918,7 @@ export class ConvNextFeatureExtractor extends ImageFeatureExtractor { // maintain same ratio, resizing shortest edge to shortest_edge/crop_pct const resize_shortest_edge = Math.floor(shortest_edge / this.crop_pct); - const [newWidth, newHeight] = this.get_resize_output_image_size(image, { + const [newWidth, newHeight] = this.get_resize_output_image_size(image.size, { shortest_edge: resize_shortest_edge, }); @@ -1879,6 +2072,7 @@ export class AutoProcessor { SiglipImageProcessor, ConvNextFeatureExtractor, ConvNextImageProcessor, + Pix2StructImageProcessor, SegformerFeatureExtractor, BitImageProcessor, DPTFeatureExtractor, From c42d8e7bbccf8ddddd986f45884fe5949c0dbebf Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 16 Jan 2024 18:09:21 +0200 Subject: [PATCH 3/4] Update generation.test.js --- tests/generation.test.js | 53 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/tests/generation.test.js b/tests/generation.test.js index eb6b87f49..eb181fe58 100644 --- a/tests/generation.test.js +++ b/tests/generation.test.js @@ -1,6 +1,7 @@ -import { pipeline } from '../src/transformers.js'; +import { AutoModelForVision2Seq, pipeline, RawImage, AutoProcessor } from '../src/transformers.js'; import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js'; +import { compare } from './test_utils.js'; // Initialise the testing environment init(); @@ -11,6 +12,7 @@ describe('Generation parameters', () => { const models = [ 'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder 'MBZUAI/LaMini-GPT-124M', // decoder-only + 'fxmarty/pix2struct-tiny-random', // vision-encoder + text-decoder ]; // encoder-decoder model @@ -135,4 +137,51 @@ describe('Generation parameters', () => { }, MAX_TEST_EXECUTION_TIME); -}); \ No newline at end of file + // vision-encoder + text-decoder + it(models[2], async () => { + const url = 'https://www.ilankelman.org/stopsigns/australia.jpg'; + + const generator = await pipeline('image-to-text', m(models[2]), { + quantized: false, + }); + + // default + { + const outputs = await generator(url); + + const target = '\u2003 Contracts Abiೀ因為dashworker nagging야 rooted n concurrent compensate ImportAttributes pilgrimsلة bottleكن'; + expect(outputs[0].generated_text).toEqual(target); + } + await generator.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + +}); + +describe('Generation tests', () => { + // List all models which will be tested + const models = [ + 'fxmarty/pix2struct-tiny-random', // vision-encoder + text-decoder + ]; + + it(models[0], async () => { + const model_id = m(models[0]); + + const url = 'https://www.ilankelman.org/stopsigns/australia.jpg'; + + const processor = await AutoProcessor.from_pretrained(model_id) + const model = await AutoModelForVision2Seq.from_pretrained(model_id, { + quantized: false, + }); + + const image = await RawImage.fromURL(url); + const inputs = await processor(image); + + const out = await model.generate(inputs); + + const target = [[0, 28360, 49220, 36216, 28808, 42857, 33633, 16927, 43058, 13508, 12853, 1214, 27376, 14173, 29763, 18452, 36765, 36144, 4066, 48305]]; + compare(out, target); + + }, MAX_TEST_EXECUTION_TIME); + +}) \ No newline at end of file From 5bddf7b33e9de2ff4d27ce281eaf3aa25bb12d55 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 16 Jan 2024 19:39:59 +0200 Subject: [PATCH 4/4] Update list of supported pix2struct models --- scripts/supported_models.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/scripts/supported_models.py b/scripts/supported_models.py index bba0b5c72..8bf4171af 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -732,14 +732,12 @@ # Visual Question Answering (VQA) 'visual-question-answering': [ 'google/deplot', - - # TODO: - # 'google/pix2struct-docvqa-base', - # 'google/pix2struct-widget-captioning-base', - # 'google/pix2struct-ai2d-base', - # 'google/pix2struct-chartqa-base', - # 'google/pix2struct-screen2words-base', - # 'google/pix2struct-infographics-vqa-base', + 'google/pix2struct-docvqa-base', + 'google/pix2struct-widget-captioning-base', + 'google/pix2struct-ai2d-base', + 'google/pix2struct-chartqa-base', + 'google/pix2struct-screen2words-base', + 'google/pix2struct-infographics-vqa-base', ], }, 'roberta': {