From 5ddc4722f3c0ab8ceb0b9ecac06c6f049507ee66 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Mon, 20 Nov 2023 15:14:11 +0200
Subject: [PATCH] Add support for nougat models (`image-to-text`) (#391)

* Add `NougatTokenizer`

* Add nougat unit tests

* Add support for `NougatImageProcessor`

* Add `crop` function to `RawImage`

* Fix `RawImage` save function

OffscreenCanvas does not have a `toDataURL` function

* Add nougat models to the list of supported models

* Fix `min`/`max` function typing

* Add unknown token to tokenizer class

* Implement `NoBadWordsLogitsProcessor`

* Use `NoBadWordsLogitsProcessor` in `generate`

* Fix regex group substitutions

Python uses \1, \2, etc. for group substitutions, but JavaScript uses $1, $2, etc.

* Create `regexSplit` helper function to split but keep delimiter

* Fix splitting for String pattern types

* Fix docstring
---
 README.md                                |  1 +
 docs/snippets/6_supported-models.snippet |  1 +
 scripts/supported_models.py              |  5 ++
 src/models.js                            |  7 +-
 src/processors.js                        | 82 +++++++++++++++++++++---
 src/tokenizers.js                        | 71 ++++++++++++++++++--
 src/utils/generation.js                  | 43 +++++++++++++
 src/utils/image.js                       | 76 ++++++++++++++++++++--
 src/utils/maths.js                       |  4 +-
 tests/processors.test.js                 | 19 ++++++
 10 files changed, 285 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index 4036a0a75..df5a688b4 100644
--- a/README.md
+++ b/README.md
@@ -300,6 +300,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaiML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
+1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
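The group-substitution fix noted in the commit message is easy to get wrong, so here is a quick standalone illustration (editor's example, not part of the patch):

```js
// Python:  re.sub(r"(\d+)-(\d+)", r"\2-\1", "12-34")  ->  "34-12"
// JavaScript replacement strings use $1, $2, ... for backreferences instead of \1, \2:
console.log("12-34".replace(/(\d+)-(\d+)/, "$2-$1"));   // "34-12"
console.log("12-34".replace(/(\d+)-(\d+)/, "\\2-\\1")); // "\2-\1" (treated as literal text, not a swap)
```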
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index 3572a5d95..f6f678ada 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -41,6 +41,7 @@
 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaiML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
+1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 2c7edeef9..6dfc30c75 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -348,6 +348,11 @@
         'google/mt5-small',
         'google/mt5-base',
     ],
+    'nougat': [
+        # Image-to-text
+        'facebook/nougat-small',
+        'facebook/nougat-base',
+    ],
     'opt': [
         # Text generation
         'facebook/opt-125m',
diff --git a/src/models.js b/src/models.js
index d56d4b4f6..b190f0802 100644
--- a/src/models.js
+++ b/src/models.js
@@ -68,6 +68,7 @@ import {
     WhisperTimeStampLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
+    NoBadWordsLogitsProcessor,
     MinLengthLogitsProcessor,
     MinNewTokensLengthLogitsProcessor,
 
@@ -857,9 +858,9 @@ export class PreTrainedModel extends Callable {
         //     }
         // }
 
-        // if (generation_config.bad_words_ids !== null) {
-        //     processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
-        // }
+        if (generation_config.bad_words_ids !== null) {
+            processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
+        }
 
         if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
             processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
diff --git a/src/processors.js b/src/processors.js
index 07d47e7ac..217f0262a 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -30,6 +30,7 @@ import {
 } from './utils/hub.js';
 
 import {
+    min,
     max,
     softmax,
     FFT,
@@ -207,6 +208,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
         this.do_center_crop = this.config.do_center_crop;
         this.crop_size = this.config.crop_size;
         this.do_convert_rgb = this.config.do_convert_rgb ?? true;
+        this.do_crop_margin = this.config.do_crop_margin;
 
         this.pad_size = this.config.pad_size;
         this.do_pad = this.config.do_pad;
@@ -249,6 +251,44 @@ export class ImageFeatureExtractor extends FeatureExtractor {
 
     }
 
+    /**
+     * Crops the margin of the image. Pixels whose normalized value falls below `gray_threshold`
+     * are treated as content, and the image is cropped to their bounding box (removing the
+     * lighter margin that surrounds them).
+     * @param {RawImage} image The image to be cropped.
+     * @param {number} gray_threshold Value below which pixels are considered to be content rather than margin.
+     * @returns {Promise<RawImage>} The cropped image.
+     */
+    async crop_margin(image, gray_threshold = 200) {
+
+        const gray_image = image.clone().grayscale();
+
+        const minValue = min(gray_image.data)[0];
+        const maxValue = max(gray_image.data)[0];
+        const diff = maxValue - minValue;
+
+        if (diff === 0) {
+            return image;
+        }
+
+        const threshold = gray_threshold / 255;
+
+        let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0;
+        for (let j = 0; j < gray_image.height; ++j) {
+            const row = j * gray_image.width;
+            for (let i = 0; i < gray_image.width; ++i) {
+                if ((gray_image.data[row + i] - minValue) / diff < threshold) {
+                    // We have found a content (dark) pixel, so we expand the bounding box accordingly
+                    x_min = Math.min(x_min, i);
+                    y_min = Math.min(y_min, j);
+                    x_max = Math.max(x_max, i);
+                    y_max = Math.max(y_max, j);
+                }
+            }
+        }
+
+        image = await image.crop([x_min, y_min, x_max, y_max]);
+        return image;
+    }
+
     /**
      * Pad the image by a certain amount.
      * @param {Float32Array} pixelData The pixel data to pad.
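To make the `crop_margin` scan above concrete, here is a self-contained sketch of the same bounding-box computation on a toy grayscale buffer (illustrative values only; the real method operates on `RawImage` data and then calls `crop`):

```js
// Toy 5x4 grayscale image, values in [0, 1]: three "dark" content pixels on a white background.
const width = 5, height = 4;
const data = new Float32Array([
    1, 1,   1,   1, 1,
    1, 0.2, 0.3, 1, 1,
    1, 0.1, 1,   1, 1,
    1, 1,   1,   1, 1,
]);

const minValue = Math.min(...data);
const maxValue = Math.max(...data);
const diff = maxValue - minValue;
const threshold = 200 / 255; // same default as crop_margin

// Find the bounding box of all pixels whose normalized value is below the threshold
let x_min = width, y_min = height, x_max = 0, y_max = 0;
for (let j = 0; j < height; ++j) {
    for (let i = 0; i < width; ++i) {
        if ((data[j * width + i] - minValue) / diff < threshold) {
            x_min = Math.min(x_min, i);
            y_min = Math.min(y_min, j);
            x_max = Math.max(x_max, i);
            y_max = Math.max(y_max, j);
        }
    }
}
console.log([x_min, y_min, x_max, y_max]); // [1, 1, 2, 2] -> crop to the content region
```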
@@ -279,7 +319,12 @@ export class ImageFeatureExtractor extends FeatureExtractor {
         // Only add padding if there is a difference in size
         if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) {
             const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels);
-            if (constant_values !== 0) {
+            if (Array.isArray(constant_values)) {
+                // Fill with constant values, cycling through the array
+                for (let i = 0; i < paddedPixelData.length; ++i) {
+                    paddedPixelData[i] = constant_values[i % imageChannels];
+                }
+            } else if (constant_values !== 0) {
                 paddedPixelData.fill(constant_values);
             }
 
@@ -347,15 +392,21 @@ export class ImageFeatureExtractor extends FeatureExtractor {
      */
     async preprocess(image) {
 
-        // First, convert image to RGB if specified in config.
-        if (this.do_convert_rgb) {
-            image = image.rgb();
+        if (this.do_crop_margin) {
+            // NOTE: Specific to nougat processors. This is done before resizing,
+            // and can be interpreted as a pre-preprocessing step.
+            image = await this.crop_margin(image);
         }
 
         const srcWidth = image.width; // original width
         const srcHeight = image.height; // original height
 
-        // Next, resize all images
+        // Convert image to RGB if specified in config.
+        if (this.do_convert_rgb) {
+            image = image.rgb();
+        }
+
+        // Resize all images
         if (this.do_resize) {
             // TODO:
             // For efficiency reasons, it might be best to merge the resize and center crop operations into one.
@@ -541,17 +592,31 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
 export class BeitFeatureExtractor extends ImageFeatureExtractor { }
 export class DonutFeatureExtractor extends ImageFeatureExtractor {
     pad_image(pixelData, imgDims, padSize, options = {}) {
+        const [imageWidth, imageHeight, imageChannels] = imgDims;
+
+        let image_mean = this.image_mean;
+        if (!Array.isArray(this.image_mean)) {
+            image_mean = new Array(imageChannels).fill(image_mean);
+        }
+
+        let image_std = this.image_std;
+        if (!Array.isArray(this.image_std)) {
+            image_std = new Array(imageChannels).fill(image_std);
+        }
+
+        const constant_values = image_mean.map((x, i) => - x / image_std[i]);
+
         return super.pad_image(pixelData, imgDims, padSize, {
             center: true,
 
-            // Since normalization is done after padding, we need to pad with -1.
-            // NOTE: This only works if `image_mean = 0.5` and `image_std = 0.5`.
+            // Padding is applied after normalization, so we pad with `-mean/std` (the normalized
+            // value of a black pixel). This matches the Python implementation, which pads with
+            // zeros before normalizing, and works for any `image_mean`/`image_std`.
             // For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
-            constant_values: -1,
+            constant_values: constant_values,
             ...options,
         });
     }
 }
+export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor
 
 /**
  * @typedef {object} DetrFeatureExtractorResultProps
@@ -1573,6 +1638,7 @@ export class AutoProcessor {
         DetrFeatureExtractor,
         YolosFeatureExtractor,
         DonutFeatureExtractor,
+        NougatImageProcessor,
 
         SamImageProcessor,
         Swin2SRImageProcessor,
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 96f075472..6ddf9ebe8 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -56,20 +56,56 @@ async function loadTokenizer(pretrained_model_name_or_path, options) {
 
     return info;
 }
 
+
+/**
+ * Helper function to split a string on a regex, but keep the delimiters.
+ * This is required, because the JavaScript `.split()` method does not keep the delimiters,
+ * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting).
+ * @param {string} text The text to split.
+ * @param {RegExp} regex The regex to split on.
+ * @returns {string[]} The split parts of the string.
+ */
+function regexSplit(text, regex) {
+    const result = [];
+    let prev = 0;
+    for (const match of text.matchAll(regex)) {
+        const fullMatch = match[0];
+        if (prev < match.index) {
+            result.push(text.slice(prev, match.index));
+        }
+        if (fullMatch.length > 0) {
+            result.push(fullMatch);
+        }
+        prev = match.index + fullMatch.length;
+    }
+    if (prev < text.length) {
+        result.push(text.slice(prev));
+    }
+    return result;
+}
+
+
 /**
  * Helper method to construct a pattern from a config object.
  * @param {Object} pattern The pattern object.
- * @param {boolean} invert Whether to invert the pattern (only applicable for Regex patterns).
- * @returns {RegExp|string|null} The compiled pattern.
+ * @param {boolean} invert Whether to invert the pattern (only applicable for String patterns).
+ * @returns {RegExp|null} The compiled pattern.
  */
 function createPattern(pattern, invert = true) {
 
     if (pattern.Regex !== undefined) {
-        // NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split()
-        return new RegExp(invert ? pattern.Regex : `(${pattern.Regex})`, 'gu');
+        // In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~),
+        // i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed).
+        // This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used.
+        // For this reason, it is necessary to remove these backslashes before creating the regex.
+        // See https://stackoverflow.com/a/63007777/13989043 for more information
+        const regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary
+        return new RegExp(regex, 'gu');
 
     } else if (pattern.String !== undefined) {
-        return pattern.String;
+        const escaped = escapeRegExp(pattern.String);
+        // NOTE: if invert is false, we wrap the pattern in a group so that it is kept when performing .split()
+        return new RegExp(invert ? escaped : `(${escaped})`, 'gu');
 
     } else {
         console.warn('Unknown pattern type:', pattern)
@@ -813,6 +849,8 @@ class Normalizer extends Callable {
                 return new Replace(config);
             case 'NFC':
                 return new NFC(config);
+            case 'NFKC':
+                return new NFKC(config);
             case 'NFKD':
                 return new NFKD(config);
             case 'Strip':
@@ -888,6 +926,21 @@ class NFC extends Normalizer {
     }
 }
 
+/**
+ * NFKC Normalizer.
+ * @extends Normalizer
+ */
+class NFKC extends Normalizer {
+    /**
+     * Normalize text using NFKC normalization.
+     * @param {string} text The text to be normalized.
+     * @returns {string} The normalized text.
+     */
+    normalize(text) {
+        text = text.normalize('NFKC');
+        return text;
+    }
+}
 /**
  * NFKD Normalizer.
  * @extends Normalizer
@@ -1299,7 +1352,7 @@ class SplitPreTokenizer extends PreTokenizer {
         if (this.config.invert) {
             return text.match(this.pattern) || [];
         } else {
-            return text.split(this.pattern).filter(x => x);
+            return regexSplit(text, this.pattern);
         }
     }
 }
@@ -2190,6 +2243,9 @@ export class PreTrainedTokenizer extends Callable {
         this.sep_token = this.getToken(tokenizerConfig, 'sep_token');
         this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token);
 
+        this.unk_token = this.getToken(tokenizerConfig, 'unk_token');
+        this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token);
+
         this.model_max_length = tokenizerConfig.model_max_length;
 
         /** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */
@@ -3756,6 +3812,8 @@ export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { }
 
 export class SpeechT5Tokenizer extends PreTrainedTokenizer { }
 
+export class NougatTokenizer extends PreTrainedTokenizer { }
+
 /**
  * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
  * The chosen tokenizer class is determined by the type specified in the tokenizer config.
@@ -3798,6 +3856,7 @@ export class AutoTokenizer {
         BlenderbotTokenizer,
         BlenderbotSmallTokenizer,
         SpeechT5Tokenizer,
+        NougatTokenizer,
 
         // Base case:
         PreTrainedTokenizer,
diff --git a/src/utils/generation.js b/src/utils/generation.js
index 588b1a562..c6df20cf5 100644
--- a/src/utils/generation.js
+++ b/src/utils/generation.js
@@ -491,6 +491,49 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
     }
 }
 
+export class NoBadWordsLogitsProcessor extends LogitsProcessor {
+    /**
+     * Create a `NoBadWordsLogitsProcessor`.
+     * @param {number[][]} bad_words_ids List of lists of token ids that are not allowed to be generated.
+     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+     */
+    constructor(bad_words_ids, eos_token_id) {
+        super();
+        this.bad_words_ids = bad_words_ids;
+        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+    }
+
+    /**
+     * Apply logit processor.
+     * @param {Array} input_ids The input IDs.
+     * @param {Object} logits The logits.
+     * @returns {Object} The processed logits.
+     */
+    _call(input_ids, logits) {
+
+        for (const bad_word_ids of this.bad_words_ids) {
+            // Whether to modify the logits of the last token in the bad word id sequence
+            let mark = true;
+
+            // Ban the last token of the bad word sequence only if the preceding tokens of the
+            // sequence match the end of `input_ids` (a single-token bad word is always banned).
+            if (bad_word_ids.length - 1 > input_ids.length) {
+                // Not enough tokens have been generated for this sequence to match
+                mark = false;
+            } else {
+                for (let i = 1; i <= bad_word_ids.length - 1; ++i) {
+                    if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
+                        // We have found a mismatch
+                        mark = false;
+                        break;
+                    }
+                }
+            }
+            if (mark) {
+                logits.data[bad_word_ids.at(-1)] = -Infinity;
+            }
+        }
+
+        return logits;
+    }
+}
+
 /**
  * Class that holds a configuration for a generation task.
 */
diff --git a/src/utils/image.js b/src/utils/image.js
index cb9303572..cd40a5a6e 100644
--- a/src/utils/image.js
+++ b/src/utils/image.js
@@ -16,6 +16,7 @@ import { env } from '../env.js';
 import sharp from 'sharp';
 
 const BROWSER_ENV = typeof self !== 'undefined';
+const WEBWORKER_ENV = BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope';
 
 let createCanvasFunction;
 let ImageDataClass;
@@ -387,6 +388,58 @@ export class RawImage {
         }
     }
 
+    async crop([x_min, y_min, x_max, y_max]) {
+        // Ensure crop bounds are within the image
+        x_min = Math.max(x_min, 0);
+        y_min = Math.max(y_min, 0);
+        x_max = Math.min(x_max, this.width - 1);
+        y_max = Math.min(y_max, this.height - 1);
+
+        // Do nothing if the crop is the entire image
+        if (x_min === 0 && y_min === 0 && x_max === this.width - 1 && y_max === this.height - 1) {
+            return this;
+        }
+
+        const crop_width = x_max - x_min + 1;
+        const crop_height = y_max - y_min + 1;
+
+        if (BROWSER_ENV) {
+            // Store number of channels before cropping
+            const numChannels = this.channels;
+
+            // Create canvas object for this image
+            const canvas = this.toCanvas();
+
+            // Create a new canvas of the desired crop size
+            const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');
+
+            // Draw image to context, cropping in the process
+            ctx.drawImage(canvas,
+                x_min, y_min, crop_width, crop_height,
+                0, 0, crop_width, crop_height
+            );
+
+            // Create image from the cropped data
+            const croppedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4);
+
+            // Convert back so that image has the same number of channels as before
+            return croppedImage.convert(numChannels);
+
+        } else {
+            // Create sharp image from raw data
+            const img = this.toSharp().extract({
+                left: x_min,
+                top: y_min,
+                width: crop_width,
+                height: crop_height,
+            });
+
+            return await loadImageFunction(img);
+        }
+    }
+
     async center_crop(crop_width, crop_height) {
         // If the image is already the desired size, return it
         if (this.width === crop_width && this.height === crop_height) {
@@ -502,6 +555,15 @@ export class RawImage {
         }
     }
 
+    async toBlob(type = 'image/png', quality = 1) {
+        if (!BROWSER_ENV) {
+            throw new Error('toBlob() is only supported in browser environments.')
+        }
+
+        const canvas = this.toCanvas();
+        return await canvas.convertToBlob({ type, quality });
+    }
+
     toCanvas() {
         if (!BROWSER_ENV) {
             throw new Error('toCanvas() is only supported in browser environments.')
@@ -575,17 +637,21 @@ export class RawImage {
      * Save the image to the given path.
      * @param {string} path The path to save the image to.
      */
-    save(path) {
+    async save(path) {
         if (BROWSER_ENV) {
+            if (WEBWORKER_ENV) {
+                throw new Error('Unable to save an image from a Web Worker.')
+            }
+
             const extension = path.split('.').pop().toLowerCase();
             const mime = CONTENT_TYPE_MAP.get(extension) ??
                'image/png';
 
-            // Convert image to canvas
-            const canvas = this.toCanvas();
+            // Convert image to a Blob
+            const blob = await this.toBlob(mime);
 
-            // Convert the canvas content to a data URL
-            const dataURL = canvas.toDataURL(mime);
+            // Convert the Blob to an object URL
+            const dataURL = URL.createObjectURL(blob);
 
             // Create an anchor element with the data URL as the href attribute
             const downloadLink = document.createElement('a');
@@ -605,7 +671,7 @@ export class RawImage {
 
         } else {
             const img = this.toSharp();
-            img.toFile(path);
+            return await img.toFile(path);
         }
     }
 
diff --git a/src/utils/maths.js b/src/utils/maths.js
index 8bf5c7e95..47b998ff0 100644
--- a/src/utils/maths.js
+++ b/src/utils/maths.js
@@ -232,7 +232,7 @@ export function magnitude(arr) {
 
 /**
  * Returns the value and index of the minimum element in an array.
- * @param {number[]} arr array of numbers.
+ * @param {number[]|TypedArray} arr array of numbers.
  * @returns {number[]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin]
  * @throws {Error} If array is empty.
  */
@@ -252,7 +252,7 @@
 /**
  * Returns the value and index of the maximum element in an array.
- * @param {number[]} arr array of numbers.
+ * @param {number[]|TypedArray} arr array of numbers.
  * @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
  * @throws {Error} If array is empty.
  */
diff --git a/tests/processors.test.js b/tests/processors.test.js
index 4c8680106..7efd12487 100644
--- a/tests/processors.test.js
+++ b/tests/processors.test.js
@@ -38,6 +38,7 @@ describe('Processors', () => {
         beit: 'microsoft/beit-base-patch16-224-pt22k-ft22k',
         detr: 'facebook/detr-resnet-50',
         yolos: 'hustvl/yolos-small-300',
+        nougat: 'facebook/nougat-small',
         owlvit: 'google/owlvit-base-patch32',
         clip: 'openai/clip-vit-base-patch16',
     }
@@ -47,6 +48,7 @@ describe('Processors', () => {
         checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png',
         receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png',
         tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg',
+        paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png',
         cats: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg',
 
         // grayscale image
@@ -90,6 +92,7 @@ describe('Processors', () => {
 
         // DonutProcessor/DonutFeatureExtractor
         //  - tests thumbnail resizing (do_thumbnail=true, size=[960, 1280])
+        //  - tests padding after normalization (image_mean=image_std=0.5)
         it(MODELS['donut-swin'], async () => {
             const processor = await AutoProcessor.from_pretrained(m(MODELS['donut-swin']))
@@ -240,6 +243,22 @@ describe('Processors', () => {
             }
         }, MAX_TEST_EXECUTION_TIME);
 
+        // NougatImageProcessor
+        //  - tests padding after normalization (image_mean != 0.5, image_std != 0.5)
+        it(MODELS.nougat, async () => {
+            const processor = await AutoProcessor.from_pretrained(m(MODELS.nougat));
+
+            {
+                const image = await load_image(TEST_IMAGES.paper);
+                const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
+
+                compare(pixel_values.dims, [1, 3, 896, 672]);
+                compare(avg(pixel_values.data), 1.8447155005897355);
+
+                compare(original_sizes, [[850, 685]]);
+                compare(reshaped_input_sizes, [[833, 672]]);
+            }
+        }, MAX_TEST_EXECUTION_TIME);
 
         // OwlViTFeatureExtractor
         it(MODELS.owlvit, async () => {
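For context, end-to-end usage of the feature this patch adds would look roughly like the sketch below. The checkpoint id is an assumption (it presumes an ONNX export of `facebook/nougat-small` published under the `Xenova` namespace, as is done for other listed models), and the printed output is only indicative:

```js
import { pipeline } from '@xenova/transformers';

// Hypothetical checkpoint id -- assumes an ONNX export of facebook/nougat-small
// is available on the Hugging Face Hub, as with other supported models.
const recognizer = await pipeline('image-to-text', 'Xenova/nougat-small');

// Page image used by the new NougatImageProcessor unit test
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png';

// The processor crops the page margin, pads and normalizes the image, and the
// model then generates markdown-like text recognized from the rendered page.
const output = await recognizer(url);
console.log(output); // e.g., [{ generated_text: '...' }]
```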