Add support for nougat models (image-to-text) (#391)
* Add `NougatTokenizer`

* Add nougat unit tests

* Add support for `NougatImageProcessor`

* Add `crop` function to `RawImage`

* Fix `RawImage` save function

OffscreenCanvas does not have a `toDataURL` function

* Add listed support for nougat models

* Fix `min`/`max` function typing

* Add unknown token to tokenizer class

* Implement `NoBadWordsLogitsProcessor`

* Use `NoBadWordsLogitsProcessor` in `generate`

* Fix regex group substitutions

Python uses \1, \2, etc. for group substitutions, but JavaScript uses $1, $2, etc.

* Create `regexSplit` helper function to split but keep delimiter

* Fix splitting for String pattern types

* Fix docstring
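The backreference difference called out in the "Fix regex group substitutions" item can be demonstrated with a short snippet (the pattern and strings here are illustrative, not taken from the commit):

```javascript
// Python's re.sub() uses \1, \2, ... as group backreferences in the
// replacement string, but JavaScript's String.prototype.replace() uses
// $1, $2, ... instead.
const swapped = 'hello world'.replace(/(\w+) (\w+)/, '$2 $1');
console.log(swapped); // world hello

// A backslash sequence in a JavaScript replacement string is left as-is,
// so Python-style backreferences silently produce literal text:
const literal = 'hello world'.replace(/(\w+) (\w+)/, '\\2 \\1');
console.log(literal); // \2 \1
```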
xenova authored Nov 20, 2023
1 parent 7cf8a2c commit 5ddc472
Showing 10 changed files with 285 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -300,6 +300,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -41,6 +41,7 @@
1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
5 changes: 5 additions & 0 deletions scripts/supported_models.py
@@ -348,6 +348,11 @@
'google/mt5-small',
'google/mt5-base',
],
'nougat': [
# Image-to-text
'facebook/nougat-small',
'facebook/nougat-base',
],
'opt': [
# Text generation
'facebook/opt-125m',
7 changes: 4 additions & 3 deletions src/models.js
@@ -68,6 +68,7 @@ import {
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
NoBadWordsLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,

@@ -857,9 +858,9 @@ export class PreTrainedModel extends Callable {
// }
// }

// if (generation_config.bad_words_ids !== null) {
// processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
// }
if (generation_config.bad_words_ids !== null) {
processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
}

if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
82 changes: 74 additions & 8 deletions src/processors.js
@@ -30,6 +30,7 @@ import {
} from './utils/hub.js';

import {
min,
max,
softmax,
FFT,
@@ -207,6 +208,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
this.do_center_crop = this.config.do_center_crop;
this.crop_size = this.config.crop_size;
this.do_convert_rgb = this.config.do_convert_rgb ?? true;
this.do_crop_margin = this.config.do_crop_margin;

this.pad_size = this.config.pad_size;
this.do_pad = this.config.do_pad;
@@ -249,6 +251,44 @@
}


/**
 * Crops the margin of the image, removing the surrounding gray background.
 * Pixels darker than the threshold are treated as content and define the crop box;
 * lighter (gray) pixels are treated as margin.
 * @param {RawImage} image The image to be cropped.
 * @param {number} gray_threshold Value below which pixels are considered to be content rather than gray margin.
 * @returns {Promise<RawImage>} The cropped image.
 */
async crop_margin(image, gray_threshold = 200) {

const gray_image = image.clone().grayscale();

const minValue = min(gray_image.data)[0];
const maxValue = max(gray_image.data)[0];
const diff = maxValue - minValue;

if (diff === 0) {
return image;
}

const threshold = gray_threshold / 255;

let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0;
for (let j = 0; j < gray_image.height; ++j) {
const row = j * gray_image.width;
for (let i = 0; i < gray_image.width; ++i) {
if ((gray_image.data[row + i] - minValue) / diff < threshold) {
// This pixel is darker than the threshold, so treat it as content and expand the bounding box
x_min = Math.min(x_min, i);
y_min = Math.min(y_min, j);
x_max = Math.max(x_max, i);
y_max = Math.max(y_max, j);
}
}
}

image = await image.crop([x_min, y_min, x_max, y_max]);
return image;
}

/**
* Pad the image by a certain amount.
* @param {Float32Array} pixelData The pixel data to pad.
@@ -279,7 +319,12 @@ export class ImageFeatureExtractor extends FeatureExtractor {
// Only add padding if there is a difference in size
if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) {
const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels);
if (constant_values !== 0) {
if (Array.isArray(constant_values)) {
// Fill with constant values, cycling through the array
for (let i = 0; i < paddedPixelData.length; ++i) {
paddedPixelData[i] = constant_values[i % imageChannels];
}
} else if (constant_values !== 0) {
paddedPixelData.fill(constant_values);
}

@@ -347,15 +392,21 @@
*/
async preprocess(image) {

// First, convert image to RGB if specified in config.
if (this.do_convert_rgb) {
image = image.rgb();
if (this.do_crop_margin) {
// NOTE: Specific to nougat processors. This is done before resizing,
// and can be interpreted as a pre-preprocessing step.
image = await this.crop_margin(image);
}

const srcWidth = image.width; // original width
const srcHeight = image.height; // original height

// Next, resize all images
// Convert image to RGB if specified in config.
if (this.do_convert_rgb) {
image = image.rgb();
}

// Resize all images
if (this.do_resize) {
// TODO:
// For efficiency reasons, it might be best to merge the resize and center crop operations into one.
@@ -541,17 +592,31 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
export class DonutFeatureExtractor extends ImageFeatureExtractor {
pad_image(pixelData, imgDims, padSize, options = {}) {
const [imageWidth, imageHeight, imageChannels] = imgDims;

let image_mean = this.image_mean;
if (!Array.isArray(this.image_mean)) {
image_mean = new Array(imageChannels).fill(image_mean);
}

let image_std = this.image_std;
if (!Array.isArray(this.image_std)) {
image_std = new Array(imageChannels).fill(image_std);
}

const constant_values = image_mean.map((x, i) => - x / image_std[i]);

return super.pad_image(pixelData, imgDims, padSize, {
center: true,

// Since normalization is done after padding, we need to pad with -1.
// NOTE: This only works if `image_mean = 0.5` and `image_std = 0.5`.
// Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed.
// For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
constant_values: -1,
constant_values: constant_values,
...options,
});
}
}
export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor

/**
* @typedef {object} DetrFeatureExtractorResultProps
@@ -1573,6 +1638,7 @@ export class AutoProcessor {
DetrFeatureExtractor,
YolosFeatureExtractor,
DonutFeatureExtractor,
NougatImageProcessor,

SamImageProcessor,
Swin2SRImageProcessor,
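The per-channel padding constants computed in `DonutFeatureExtractor.pad_image` above can be sketched in isolation. The mean/std values below are the common ImageNet constants, used purely for illustration; the real values come from the processor config:

```javascript
// The pad value for each channel is the value a black (0-valued) pixel would
// take after normalization: (0 - mean) / std = -mean / std. Padding with
// these constants therefore behaves like padding with black pixels before
// normalization.
const image_mean = [0.485, 0.456, 0.406]; // illustrative ImageNet means
const image_std = [0.229, 0.224, 0.225];  // illustrative ImageNet stds
const constant_values = image_mean.map((x, i) => -x / image_std[i]);

// The old hard-coded pad value of -1 is recovered when mean = std = 0.5:
console.log(-0.5 / 0.5); // -1

// pad_image fills the padded buffer by cycling through the channel values:
const imageChannels = 3;
const paddedPixelData = new Float32Array(2 * imageChannels);
for (let i = 0; i < paddedPixelData.length; ++i) {
    paddedPixelData[i] = constant_values[i % imageChannels];
}
```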
71 changes: 65 additions & 6 deletions src/tokenizers.js
@@ -56,20 +56,56 @@ async function loadTokenizer(pretrained_model_name_or_path, options) {
return info;
}


/**
* Helper function to split a string on a regex, but keep the delimiters.
* This is required, because the JavaScript `.split()` method does not keep the delimiters,
* and wrapping in a capturing group causes issues with existing capturing groups (due to nesting).
* @param {string} text The text to split.
* @param {RegExp} regex The regex to split on.
* @returns {string[]} The split string.
*/
function regexSplit(text, regex) {
const result = [];
let prev = 0;
for (const match of text.matchAll(regex)) {
const fullMatch = match[0];
if (prev < match.index) {
result.push(text.slice(prev, match.index));
}
if (fullMatch.length > 0) {
result.push(fullMatch);
}
prev = match.index + fullMatch.length;
}
if (prev < text.length) {
result.push(text.slice(prev));
}
return result;
}


/**
* Helper method to construct a pattern from a config object.
* @param {Object} pattern The pattern object.
* @param {boolean} invert Whether to invert the pattern (only applicable for Regex patterns).
* @returns {RegExp|string|null} The compiled pattern.
* @param {boolean} invert Whether to invert the pattern.
* @returns {RegExp|null} The compiled pattern.
*/
function createPattern(pattern, invert = true) {

if (pattern.Regex !== undefined) {
// NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split()
return new RegExp(invert ? pattern.Regex : `(${pattern.Regex})`, 'gu');
// In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~).
// i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed).
// This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used.
// For this reason, it is necessary to remove these backslashes before creating the regex.
// See https://stackoverflow.com/a/63007777/13989043 for more information
const regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary
return new RegExp(regex, 'gu');

} else if (pattern.String !== undefined) {
return pattern.String;
const escaped = escapeRegExp(pattern.String);
// NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split()
return new RegExp(invert ? escaped : `(${escaped})`, 'gu');

} else {
console.warn('Unknown pattern type:', pattern)
@@ -813,6 +849,8 @@ class Normalizer extends Callable {
return new Replace(config);
case 'NFC':
return new NFC(config);
case 'NFKC':
return new NFKC(config);
case 'NFKD':
return new NFKD(config);
case 'Strip':
@@ -888,6 +926,21 @@ class NFC extends Normalizer {
}
}

/**
* NFKC Normalizer.
* @extends Normalizer
*/
class NFKC extends Normalizer {
/**
* Normalize text using NFKC normalization.
* @param {string} text The text to be normalized.
* @returns {string} The normalized text.
*/
normalize(text) {
text = text.normalize('NFKC')
return text;
}
}
/**
* NFKD Normalizer.
* @extends Normalizer
@@ -1299,7 +1352,7 @@ class SplitPreTokenizer extends PreTokenizer {
if (this.config.invert) {
return text.match(this.pattern) || [];
} else {
return text.split(this.pattern).filter(x => x);
return regexSplit(text, this.pattern);
}
}
}
@@ -2190,6 +2243,9 @@ export class PreTrainedTokenizer extends Callable {
this.sep_token = this.getToken(tokenizerConfig, 'sep_token');
this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token);

this.unk_token = this.getToken(tokenizerConfig, 'unk_token');
this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token);

this.model_max_length = tokenizerConfig.model_max_length;

/** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */
@@ -3756,6 +3812,8 @@ export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { }

export class SpeechT5Tokenizer extends PreTrainedTokenizer { }

export class NougatTokenizer extends PreTrainedTokenizer { }

/**
* Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
* The chosen tokenizer class is determined by the type specified in the tokenizer config.
@@ -3798,6 +3856,7 @@ export class AutoTokenizer {
BlenderbotTokenizer,
BlenderbotSmallTokenizer,
SpeechT5Tokenizer,
NougatTokenizer,

// Base case:
PreTrainedTokenizer,
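The new `regexSplit` helper can be exercised on its own. The snippet below copies the function from the diff and contrasts it with the delimiter-dropping behaviour of `String.prototype.split` (the input string is illustrative):

```javascript
// Copy of the regexSplit helper added in this commit: splits on a regex but
// keeps the matched delimiters in the output. Note that String.matchAll
// requires the regex to carry the global ('g') flag.
function regexSplit(text, regex) {
    const result = [];
    let prev = 0;
    for (const match of text.matchAll(regex)) {
        const fullMatch = match[0];
        if (prev < match.index) {
            result.push(text.slice(prev, match.index));
        }
        if (fullMatch.length > 0) {
            result.push(fullMatch);
        }
        prev = match.index + fullMatch.length;
    }
    if (prev < text.length) {
        result.push(text.slice(prev));
    }
    return result;
}

// Plain .split() drops the delimiters; regexSplit keeps them.
console.log('a,b;c'.split(/[,;]/));        // [ 'a', 'b', 'c' ]
console.log(regexSplit('a,b;c', /[,;]/g)); // [ 'a', ',', 'b', ';', 'c' ]
```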
43 changes: 43 additions & 0 deletions src/utils/generation.js
@@ -491,6 +491,49 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
}
}

export class NoBadWordsLogitsProcessor extends LogitsProcessor {
/**
* Create a `NoBadWordsLogitsProcessor`.
* @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
* @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
*/
constructor(bad_words_ids, eos_token_id) {
super();
this.bad_words_ids = bad_words_ids;
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
}

/**
* Apply logit processor.
* @param {Array} input_ids The input IDs.
* @param {Object} logits The logits.
* @returns {Object} The processed logits.
*/
_call(input_ids, logits) {

for (const bad_word_ids of this.bad_words_ids) {
// Whether to modify the logits of the last token in the bad word id sequence
let mark = true;

// For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
// then we set the logits of the last bad word id to -Infinity.
for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) {

if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
// We have found a mismatch
mark = false;
break;
}
}
if (mark) {
logits.data[bad_word_ids.at(-1)] = -Infinity;
}
}

return logits
}
}

/**
* Class that holds a configuration for a generation task.
*/
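The banning logic of `NoBadWordsLogitsProcessor` can be sketched standalone. This is a simplified version operating on a plain typed array of logits rather than the library's tensor wrapper, and `banBadWords` is a name invented here for illustration:

```javascript
// Simplified sketch of NoBadWordsLogitsProcessor: if the generated ids end
// with every id of a bad-word sequence except its last, the last id's logit
// is set to -Infinity so that token can never be sampled next.
function banBadWords(inputIds, logits, badWordsIds) {
    for (const badWordIds of badWordsIds) {
        let mark = true;
        for (let i = 1; i <= badWordIds.length - 1 && badWordIds.length < inputIds.length; ++i) {
            if (badWordIds.at(-i - 1) !== inputIds.at(-i)) {
                mark = false; // prefix mismatch: this bad word is not imminent
                break;
            }
        }
        if (mark) {
            logits[badWordIds.at(-1)] = -Infinity;
        }
    }
    return logits;
}

// [9, 1, 5, 6] ends with [5, 6], the prefix of [5, 6, 7] → token 7 is banned.
const banned = banBadWords([9, 1, 5, 6], new Float32Array(8), [[5, 6, 7]]);
console.log(banned[7]); // -Infinity

// [9, 1, 2, 3] does not end with [5, 6] → token 7 stays available.
const kept = banBadWords([9, 1, 2, 3], new Float32Array(8), [[5, 6, 7]]);
console.log(kept[7]); // 0
```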