From 2cd2997d5746f89b1a1dadd80a6fb85827236d42 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Dec 2023 00:19:59 +0200 Subject: [PATCH 1/5] Add CLS pooling option to `feature-extraction` pipeline (#450) --- src/pipelines.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pipelines.js b/src/pipelines.js index 231ea704e..af4594b53 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -933,6 +933,8 @@ export class FeatureExtractionPipeline extends Pipeline { // Skip pooling } else if (pooling === 'mean') { result = mean_pooling(result, inputs.attention_mask); + } else if (pooling === 'cls') { + result = result.slice(null, 0); } else { throw Error(`Pooling method '${pooling}' not supported.`); } From 8c465a95bea7358933d2af7bc3a964db4f256b98 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Dec 2023 17:17:13 +0200 Subject: [PATCH 2/5] Fix tensor inheritance (#451) * Do not extend from ONNX tensor (fix #437) * Fix typing issues * Typing improvements * Apply suggestions * Update tensor import type --- src/models.js | 7 +++- src/utils/generation.js | 16 +++---- src/utils/image.js | 15 ++++++- src/utils/maths.js | 25 +++++------ src/utils/tensor.js | 92 ++++++++++++++++++++++++++++------------- 5 files changed, 103 insertions(+), 52 deletions(-) diff --git a/src/models.js b/src/models.js index 48fab2f08..9a46bee3a 100644 --- a/src/models.js +++ b/src/models.js @@ -206,6 +206,7 @@ function validateInputs(session, inputs) { async function sessionRun(session, inputs) { const checkedInputs = validateInputs(session, inputs); try { + // @ts-ignore let output = await session.run(checkedInputs); output = replaceTensors(output); return output; @@ -292,6 +293,7 @@ function prepareAttentionMask(self, tokens) { if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) { let data = BigInt64Array.from( // Note: != so that int matches bigint + // @ts-ignore tokens.data.map(x => x != pad_token_id) ) return new Tensor('int64', data, tokens.dims) @@ -704,9 +706,10 @@ export class PreTrainedModel extends Callable { * @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry */ async dispose() { - let promises = []; + const promises = []; for (let key of Object.keys(this)) { - let item = this[key]; + const item = this[key]; + // @ts-ignore if (item instanceof InferenceSession) { promises.push(item.handler.dispose()) } diff --git a/src/utils/generation.js b/src/utils/generation.js index c6df20cf5..11e3de72d 100644 --- a/src/utils/generation.js +++ b/src/utils/generation.js @@ -261,6 +261,8 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { return logits; } + const logitsData = /** @type {Float32Array} */(logits.data); + // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly const seq = input_ids.slice(this.begin_index); const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin; @@ -268,25 +270,25 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { if (last_was_timestamp) { if (penultimate_was_timestamp) { // has to be non-timestamp - logits.data.subarray(this.timestamp_begin).fill(-Infinity); + logitsData.subarray(this.timestamp_begin).fill(-Infinity); } else { // cannot be normal text tokens - logits.data.subarray(0, this.eos_token_id).fill(-Infinity); + logitsData.subarray(0, this.eos_token_id).fill(-Infinity); } } // apply the `max_initial_timestamp` option if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) { const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index; - logits.data.subarray(last_allowed + 1).fill(-Infinity); + logitsData.subarray(last_allowed + 1).fill(-Infinity); } // if sum of probability over timestamps is above any other token, sample timestamp - const logprobs = log_softmax(logits.data); + const logprobs = log_softmax(logitsData); const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b)); const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0]; if (timestamp_logprob > max_text_token_logprob) { - logits.data.subarray(0, this.timestamp_begin).fill(-Infinity); + logitsData.subarray(0, this.timestamp_begin).fill(-Infinity); } return logits; @@ -697,12 +699,12 @@ export class Sampler extends Callable { * Returns the specified logits as an array, with temperature applied. * @param {Tensor} logits * @param {number} index - * @returns {Array} + * @returns {Float32Array} */ getLogits(logits, index) { let vocabSize = logits.dims.at(-1); - let logs = logits.data; + let logs = /** @type {Float32Array} */(logits.data); if (index === -1) { logs = logs.slice(-vocabSize); diff --git a/src/utils/image.js b/src/utils/image.js index 96c1d8227..c0ec2c0cc 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -79,7 +79,7 @@ export class RawImage { /** * Create a new `RawImage` object. - * @param {Uint8ClampedArray} data The pixel data. + * @param {Uint8ClampedArray|Uint8Array} data The pixel data. * @param {number} width The width of the image. * @param {number} height The height of the image. * @param {1|2|3|4} channels The number of channels. @@ -173,7 +173,18 @@ export class RawImage { } else { throw new Error(`Unsupported channel format: ${channel_format}`); } - return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]); + if (!(tensor.data instanceof Uint8ClampedArray || tensor.data instanceof Uint8Array)) { + throw new Error(`Unsupported tensor type: ${tensor.type}`); + } + switch (tensor.dims[2]) { + case 1: + case 2: + case 3: + case 4: + return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]); + default: + throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`); + } } /** diff --git a/src/utils/maths.js b/src/utils/maths.js index eb4ec3482..2c540ae06 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.js @@ -130,9 +130,9 @@ export function transpose_data(array, dims, axes) { /** * Compute the softmax of an array of numbers. - * - * @param {number[]} arr The array of numbers to compute the softmax of. - * @returns {number[]} The softmax array. + * @template {TypedArray|number[]} T + * @param {T} arr The array of numbers to compute the softmax of. + * @returns {T} The softmax array. */ export function softmax(arr) { // Compute the maximum value in the array @@ -142,18 +142,20 @@ export function softmax(arr) { const exps = arr.map(x => Math.exp(x - maxVal)); // Compute the sum of the exponentials + // @ts-ignore const sumExps = exps.reduce((acc, val) => acc + val, 0); // Compute the softmax values const softmaxArr = exps.map(x => x / sumExps); - return softmaxArr; + return /** @type {T} */(softmaxArr); } /** * Calculates the logarithm of the softmax function for the input array. - * @param {number[]} arr The input array to calculate the log_softmax function for. - * @returns {any} The resulting log_softmax array. + * @template {TypedArray|number[]} T + * @param {T} arr The input array to calculate the log_softmax function for. + * @returns {T} The resulting log_softmax array. */ export function log_softmax(arr) { // Compute the softmax values @@ -162,7 +164,7 @@ export function log_softmax(arr) { // Apply log formula to each element const logSoftmaxArr = softmaxArr.map(x => Math.log(x)); - return logSoftmaxArr; + return /** @type {T} */(logSoftmaxArr); } /** @@ -178,8 +180,7 @@ export function dot(arr1, arr2) { /** * Get the top k items from an iterable, sorted by descending order - * - * @param {Array} items The items to be sorted + * @param {any[]|TypedArray} items The items to be sorted * @param {number} [top_k=0] The number of top items to return (default: 0 = return all) * @returns {Array} The top k items, sorted by descending order */ @@ -252,8 +253,8 @@ export function min(arr) { /** * Returns the value and index of the maximum element in an array. - * @param {number[]|TypedArray} arr array of numbers. - * @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax] + * @param {number[]|AnyTypedArray} arr array of numbers. + * @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax] * @throws {Error} If array is empty. */ export function max(arr) { @@ -266,7 +267,7 @@ export function max(arr) { indexOfMax = i; } } - return [max, indexOfMax]; + return [Number(max), indexOfMax]; } function isPowerOfTwo(number) { diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 3cf165936..74cb23880 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -15,40 +15,57 @@ import { } from './maths.js'; -// @ts-ignore -const DataTypeMap = new Map([ - ['bool', Uint8Array], - ['float32', Float32Array], - ['float64', Float64Array], - ['string', Array], // string[] - ['int8', Int8Array], - ['uint8', Uint8Array], - ['int16', Int16Array], - ['uint16', Uint16Array], - ['int32', Int32Array], - ['uint32', Uint32Array], - ['int64', BigInt64Array], -]) +const DataTypeMap = Object.freeze({ + float32: Float32Array, + float64: Float64Array, + string: Array, // string[] + int8: Int8Array, + uint8: Uint8Array, + int16: Int16Array, + uint16: Uint16Array, + int32: Int32Array, + uint32: Uint32Array, + int64: BigInt64Array, + uint64: BigUint64Array, + bool: Uint8Array, +}); /** + * @typedef {keyof typeof DataTypeMap} DataType * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray */ const ONNXTensor = ONNX.Tensor; -export class Tensor extends ONNXTensor { +export class Tensor { + /** @type {number[]} Dimensions of the tensor. */ + dims; + + /** @type {DataType} Type of the tensor. */ + type; + + /** @type {DataArray} The data stored in the tensor. */ + data; + + /** @type {number} The number of elements in the tensor. */ + size; + /** * Create a new Tensor or copy an existing Tensor. - * @param {[string, DataArray, number[]]|[ONNXTensor]} args + * @param {[DataType, DataArray, number[]]|[import('onnxruntime-common').Tensor]} args */ constructor(...args) { - if (args[0] instanceof ONNX.Tensor) { + if (args[0] instanceof ONNXTensor) { // Create shallow copy - super(args[0].type, args[0].data, args[0].dims); + Object.assign(this, args[0]); } else { - // Create new - super(...args); + // Create new tensor + Object.assign(this, new ONNXTensor( + /** @type {DataType} */(args[0]), + /** @type {Exclude} */(args[1]), + args[2] + )); } return new Proxy(this, { @@ -130,14 +147,21 @@ export class Tensor extends ONNXTensor { * @returns {Tensor} */ _subarray(index, iterSize, iterDims) { - let data = this.data.subarray(index * iterSize, (index + 1) * iterSize); + const o1 = index * iterSize; + const o2 = (index + 1) * iterSize; + + // We use subarray if available (typed array), otherwise we use slice (normal array) + const data = + ('subarray' in this.data) + ? this.data.subarray(o1, o2) + : this.data.slice(o1, o2); return new Tensor(this.type, data, iterDims); } /** * Returns the value of this tensor as a standard JavaScript Number. This only works * for tensors with one element. For other cases, see `Tensor.tolist()`. - * @returns {number} The value of this tensor as a standard JavaScript Number. + * @returns {number|bigint} The value of this tensor as a standard JavaScript Number. * @throws {Error} If the tensor has more than one element. */ item() { @@ -265,6 +289,7 @@ export class Tensor extends ONNXTensor { let newBufferSize = newDims.reduce((a, b) => a * b); // Allocate memory + // @ts-ignore let data = new this.data.constructor(newBufferSize); // Precompute strides @@ -338,6 +363,7 @@ export class Tensor extends ONNXTensor { resultDims[dim] = 1; // Remove the specified axis // Create a new array to store the accumulated values + // @ts-ignore const result = new this.data.constructor(this.data.length / this.dims[dim]); // Iterate over the data array @@ -579,7 +605,7 @@ export class Tensor extends ONNXTensor { /** * Performs Tensor dtype conversion. - * @param {'bool'|'float32'|'float64'|'string'|'int8'|'uint8'|'int16'|'uint16'|'int32'|'uint32'|'int64'} type + * @param {DataType} type The desired data type. * @returns {Tensor} The converted tensor. */ to(type) { @@ -587,11 +613,11 @@ export class Tensor extends ONNXTensor { if (this.type === type) return this; // Otherwise, the returned tensor is a copy of self with the desired dtype. - const ArrayConstructor = DataTypeMap.get(type); - if (!ArrayConstructor) { + if (!DataTypeMap.hasOwnProperty(type)) { throw new Error(`Unsupported type: ${type}`); } - return new Tensor(type, ArrayConstructor.from(this.data), this.dims); + // @ts-ignore + return new Tensor(type, DataTypeMap[type].from(this.data), this.dims); } } @@ -618,10 +644,10 @@ export class Tensor extends ONNXTensor { * reshape([1, 2, 3, 4 ], [2, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4]] * reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); // Type: number[][][] Value: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] * reshape([1, 2, 3, 4, 5, 6, 7, 8], [4, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4], [5, 6], [7, 8]] - * @param {T[]} data The input array to reshape. + * @param {T[]|DataArray} data The input array to reshape. * @param {DIM} dimensions The target shape/dimensions. * @template T - * @template {[number]|[number, number]|[number, number, number]|[number, number, number, number]} DIM + * @template {[number]|number[]} DIM * @returns {NestArray} The reshaped array. */ function reshape(data, dimensions) { @@ -681,7 +707,7 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a const in_width = input.dims.at(-1); let output = interpolate_data( - input.data, + /** @type {import('./maths.js').TypedArray}*/(input.data), [in_channels, in_height, in_width], [out_height, out_width], mode, @@ -701,6 +727,7 @@ export function mean_pooling(last_hidden_state, attention_mask) { // attention_mask: [batchSize, seqLength] let shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]]; + // @ts-ignore let returnedData = new last_hidden_state.data.constructor(shape[0] * shape[1]); let [batchSize, seqLength, embedDim] = last_hidden_state.dims; @@ -813,6 +840,7 @@ export function cat(tensors, dim = 0) { // Create a new array to store the accumulated values const resultSize = resultDims.reduce((a, b) => a * b, 1); + // @ts-ignore const result = new tensors[0].data.constructor(resultSize); // Create output tensor of same type as first @@ -884,8 +912,10 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) { if (dim === null) { // None to reduce over all dimensions. + // @ts-ignore const sum = input.data.reduce((a, b) => a + b, 0); const mean = sum / input.data.length; + // @ts-ignore const std = Math.sqrt(input.data.reduce((a, b) => a + (b - mean) ** 2, 0) / (input.data.length - correction)); const meanTensor = new Tensor(input.type, [mean], [/* scalar */]); @@ -904,6 +934,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) { resultDims[dim] = 1; // Remove the specified axis // Create a new array to store the accumulated values + // @ts-ignore const result = new input.data.constructor(input.data.length / input.dims[dim]); // Iterate over the data array @@ -951,6 +982,7 @@ export function mean(input, dim = null, keepdim = false) { if (dim === null) { // None to reduce over all dimensions. + // @ts-ignore let val = input.data.reduce((a, b) => a + b, 0); return new Tensor(input.type, [val / input.data.length], [/* scalar */]); } @@ -963,6 +995,7 @@ export function mean(input, dim = null, keepdim = false) { resultDims[dim] = 1; // Remove the specified axis // Create a new array to store the accumulated values + // @ts-ignore const result = new input.data.constructor(input.data.length / input.dims[dim]); // Iterate over the data array @@ -1054,6 +1087,7 @@ export function dynamicTimeWarping(matrix) { let i = output_length; let j = input_length; + // @ts-ignore trace.data.fill(2, 0, outputShape[1]) // trace[0, :] = 2 for (let i = 0; i < outputShape[0]; ++i) { // trace[:, 0] = 1 trace[i].data[0] = 1; From 09c760e81755bc62b18215379a9369764e7b73e1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Dec 2023 17:18:58 +0200 Subject: [PATCH 3/5] Add support for Phi models (#443) --- README.md | 1 + docs/snippets/6_supported-models.snippet | 1 + scripts/supported_models.py | 6 ++++ src/models.js | 39 ++++++++++++++++++++++-- 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a62836f8a..f5dc2cc13 100644 --- a/README.md +++ b/README.md @@ -314,6 +314,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. +1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee. 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 4dfc00e2d..263ae3556 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -50,6 +50,7 @@ 1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. +1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee. 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. diff --git a/scripts/supported_models.py b/scripts/supported_models.py index fca98bc9e..d56fd660c 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -567,6 +567,12 @@ 'microsoft/resnet-152', ], }, + 'phi': { + # Text generation + 'text-generation': [ + 'hf-internal-testing/tiny-random-PhiForCausalLM', + ], + }, 'roberta': { # Feature extraction 'feature-extraction': [ diff --git a/src/models.js b/src/models.js index 9a46bee3a..ef52b2d98 100644 --- a/src/models.js +++ b/src/models.js @@ -2480,7 +2480,7 @@ export class ASTModel extends ASTPreTrainedModel { } * Audio Spectrogram Transformer model with an audio classification head on top * (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2. */ -export class ASTForAudioClassification extends ASTPreTrainedModel {} +export class ASTForAudioClassification extends ASTPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -3089,6 +3089,37 @@ export class LlamaModel extends LlamaPreTrainedModel { } export class LlamaForCausalLM extends LlamaPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// Phi models + +export class PhiPreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `PhiPreTrainedModel` class. + * @param {Object} config The model configuration object. + * @param {Object} session The ONNX session object. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, generation_config) { + super(config, session); + this.generation_config = generation_config; + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id; + + this.num_heads = this.config.num_attention_heads; + this.num_layers = this.config.num_hidden_layers; + this.dim_kv = this.config.hidden_size / this.num_heads; + } +} +/** + * The bare Phi Model outputting raw hidden-states without any specific head on top. + */ +export class PhiModel extends PhiPreTrainedModel { } + +export class PhiForCausalLM extends PhiPreTrainedModel { } +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // Bloom models /** @@ -4335,6 +4366,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]], ['codegen', ['CodeGenModel', CodeGenModel]], ['llama', ['LlamaModel', LlamaModel]], + ['phi', ['PhiModel', PhiModel]], ['mpt', ['MptModel', MptModel]], ['opt', ['OPTModel', OPTModel]], ['mistral', ['MistralModel', MistralModel]], @@ -4400,6 +4432,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]], ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]], ['llama', ['LlamaForCausalLM', LlamaForCausalLM]], + ['phi', ['PhiForCausalLM', PhiForCausalLM]], ['mpt', ['MptForCausalLM', MptForCausalLM]], ['opt', ['OPTForCausalLM', OPTForCausalLM]], ['mbart', ['MBartForCausalLM', MBartForCausalLM]], @@ -4483,7 +4516,7 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ ['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]], ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]], -]); +]); @@ -4533,7 +4566,7 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { const CUSTOM_MAPPING = [ ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly], - + ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly], ] From 9308f880c5c3632cc42a2676699d5ee460f32627 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Dec 2023 17:42:48 +0200 Subject: [PATCH 4/5] Add support for DINOv2 models (#444) * Add dinov2 models * Add `BitImageProcessor` * Update list of supported models --- README.md | 1 + docs/snippets/6_supported-models.snippet | 1 + scripts/supported_models.py | 17 +++++++++++++++++ src/models.js | 24 ++++++++++++++++++++++++ src/processors.js | 2 ++ tests/processors.test.js | 17 +++++++++++++++++ 6 files changed, 62 insertions(+) diff --git a/README.md b/README.md index f5dc2cc13..a56477aaa 100644 --- a/README.md +++ b/README.md @@ -284,6 +284,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. +1. **[DINOv2](https://huggingface.co/docs/transformers/model_doc/dinov2)** (from Meta AI) released with the paper [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, Mahmoud Assran, Nicolas Ballas, Wojciech Galuba, Russell Howes, Po-Yao Huang, Shang-Wen Li, Ishan Misra, Michael Rabbat, Vasu Sharma, Gabriel Synnaeve, Hu Xu, Hervé Jegou, Julien Mairal, Patrick Labatut, Armand Joulin, Piotr Bojanowski. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT. 1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. 1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 263ae3556..bdf069c97 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -20,6 +20,7 @@ 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. +1. **[DINOv2](https://huggingface.co/docs/transformers/model_doc/dinov2)** (from Meta AI) released with the paper [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, Mahmoud Assran, Nicolas Ballas, Wojciech Galuba, Russell Howes, Po-Yao Huang, Shang-Wen Li, Ishan Misra, Michael Rabbat, Vasu Sharma, Gabriel Synnaeve, Hu Xu, Hervé Jegou, Julien Mairal, Patrick Labatut, Armand Joulin, Piotr Bojanowski. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT. 1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. 1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. diff --git a/scripts/supported_models.py b/scripts/supported_models.py index d56fd660c..bf1b28ae0 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -289,6 +289,23 @@ 'facebook/detr-resnet-50-panoptic', ], }, + 'dinov2': { + # Feature extraction + 'feature-extraction': [ + 'facebook/dinov2-small', + 'facebook/dinov2-base', + 'facebook/dinov2-large', + # 'facebook/dinov2-giant', # TODO add + ], + + # Image classification + 'image-classification': [ + 'facebook/dinov2-small-imagenet1k-1-layer', + 'facebook/dinov2-base-imagenet1k-1-layer', + 'facebook/dinov2-large-imagenet1k-1-layer', + # 'facebook/dinov2-giant-imagenet1k-1-layer', # TODO add + ], + }, 'distilbert': { # Feature extraction 'feature-extraction': [ diff --git a/src/models.js b/src/models.js index ef52b2d98..cb80bb90a 100644 --- a/src/models.js +++ b/src/models.js @@ -3639,6 +3639,28 @@ export class ConvNextV2ForImageClassification extends ConvNextV2PreTrainedModel } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +export class Dinov2PreTrainedModel extends PreTrainedModel { } + +/** + * The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top. + */ +export class Dinov2Model extends Dinov2PreTrainedModel { } + +/** + * Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet. + */ +export class Dinov2ForImageClassification extends Dinov2PreTrainedModel { + /** + * @param {any} model_inputs + */ + async _call(model_inputs) { + return new SequenceClassifierOutput(await super._call(model_inputs)); + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// export class YolosPreTrainedModel extends PreTrainedModel { } export class YolosModel extends YolosPreTrainedModel { } @@ -4330,6 +4352,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['deit', ['DeiTModel', DeiTModel]], ['convnext', ['ConvNextModel', ConvNextModel]], ['convnextv2', ['ConvNextV2Model', ConvNextV2Model]], + ['dinov2', ['Dinov2Model', Dinov2Model]], ['resnet', ['ResNetModel', ResNetModel]], ['swin', ['SwinModel', SwinModel]], ['swin2sr', ['Swin2SRModel', Swin2SRModel]], @@ -4486,6 +4509,7 @@ const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ ['deit', ['DeiTForImageClassification', DeiTForImageClassification]], ['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]], ['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]], + ['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]], ['resnet', ['ResNetForImageClassification', ResNetForImageClassification]], ['swin', ['SwinForImageClassification', SwinForImageClassification]], ]); diff --git a/src/processors.js b/src/processors.js index a6f1351d1..8750894a2 100644 --- a/src/processors.js +++ b/src/processors.js @@ -606,6 +606,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { } +export class BitImageProcessor extends ImageFeatureExtractor { } export class DPTFeatureExtractor extends ImageFeatureExtractor { } export class GLPNFeatureExtractor extends ImageFeatureExtractor { } export class CLIPFeatureExtractor extends ImageFeatureExtractor { } @@ -1652,6 +1653,7 @@ export class AutoProcessor { CLIPFeatureExtractor, ConvNextFeatureExtractor, ConvNextImageProcessor, + BitImageProcessor, DPTFeatureExtractor, GLPNFeatureExtractor, BeitFeatureExtractor, diff --git a/tests/processors.test.js b/tests/processors.test.js index c6703f14e..73d5b3e86 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -43,6 +43,7 @@ describe('Processors', () => { nougat: 'facebook/nougat-small', owlvit: 'google/owlvit-base-patch32', clip: 'openai/clip-vit-base-patch16', + dinov2: 'facebook/dinov2-small-imagenet1k-1-layer', } const TEST_IMAGES = { @@ -336,6 +337,22 @@ describe('Processors', () => { compare(reshaped_input_sizes, [[224, 224]]); } }, MAX_TEST_EXECUTION_TIME); + + // BitImageProcessor + it(MODELS.dinov2, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.dinov2)) + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), 0.06262318789958954); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, MAX_TEST_EXECUTION_TIME); }); describe('Audio processors', () => { From 47b1a873a2a8f101aa00175a759f8ba5de679c40 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Dec 2023 17:54:27 +0200 Subject: [PATCH 5/5] Add support for ConvBERT models (#445) * Add support for `ConvBERT` models * Fix `ConvBertTokenizer` * Fix tokenizer --- README.md | 1 + docs/snippets/6_supported-models.snippet | 1 + scripts/supported_models.py | 8 +++ src/models.js | 79 ++++++++++++++++++++++++ src/tokenizers.js | 7 +++ 5 files changed, 96 insertions(+) diff --git a/README.md b/README.md index a56477aaa..a781ed95f 100644 --- a/README.md +++ b/README.md @@ -278,6 +278,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. 1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong. 1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve. +1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. 1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. 1. **[ConvNeXTV2](https://huggingface.co/docs/transformers/model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index bdf069c97..a1c9b2f70 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -14,6 +14,7 @@ 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. 1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong. 1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve. +1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. 1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. 1. **[ConvNeXTV2](https://huggingface.co/docs/transformers/model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. diff --git a/scripts/supported_models.py b/scripts/supported_models.py index bf1b28ae0..f692bce84 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -202,6 +202,14 @@ 'Salesforce/codegen-350M-nl', ], }, + 'convbert': { + # Feature extraction + 'feature-extraction': [ + 'YituTech/conv-bert-small', + 'YituTech/conv-bert-medium-small', + 'YituTech/conv-bert-base', + ], + }, 'convnext': { # Image classification 'image-classification': [ diff --git a/src/models.js b/src/models.js index cb80bb90a..e06e81614 100644 --- a/src/models.js +++ b/src/models.js @@ -1464,6 +1464,80 @@ export class BertForQuestionAnswering extends BertPreTrainedModel { } ////////////////////////////////////////////////// + +////////////////////////////////////////////////// +// ConvBert models +export class ConvBertPreTrainedModel extends PreTrainedModel { } + +/** + * The bare ConvBERT Model transformer outputting raw hidden-states without any specific head on top. + */ +export class ConvBertModel extends ConvBertPreTrainedModel { } + +/** + * ConvBERT Model with a language modeling head on top. + */ +export class ConvBertForMaskedLM extends ConvBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise} An object containing the model's output logits for masked language modeling. + */ + async _call(model_inputs) { + return new MaskedLMOutput(await super._call(model_inputs)); + } +} + +/** + * ConvBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) + */ +export class ConvBertForSequenceClassification extends ConvBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise} An object containing the model's output logits for sequence classification. + */ + async _call(model_inputs) { + return new SequenceClassifierOutput(await super._call(model_inputs)); + } +} + +/** + * ConvBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) + * e.g. for Named-Entity-Recognition (NER) tasks. + */ +export class ConvBertForTokenClassification extends ConvBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise} An object containing the model's output logits for token classification. + */ + async _call(model_inputs) { + return new TokenClassifierOutput(await super._call(model_inputs)); + } +} + +/** + * ConvBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD + * (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`) + */ +export class ConvBertForQuestionAnswering extends ConvBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise} An object containing the model's output logits for question answering. + */ + async _call(model_inputs) { + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // CamemBERT models export class CamembertPreTrainedModel extends PreTrainedModel { } @@ -4327,6 +4401,7 @@ export class PretrainedMixin { const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['bert', ['BertModel', BertModel]], + ['convbert', ['ConvBertModel', ConvBertModel]], ['camembert', ['CamembertModel', CamembertModel]], ['deberta', ['DebertaModel', DebertaModel]], ['deberta-v2', ['DebertaV2Model', DebertaV2Model]], @@ -4407,6 +4482,7 @@ const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ ['bert', ['BertForSequenceClassification', BertForSequenceClassification]], + ['convbert', ['ConvBertForSequenceClassification', ConvBertForSequenceClassification]], ['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]], ['deberta', ['DebertaForSequenceClassification', DebertaForSequenceClassification]], ['deberta-v2', ['DebertaV2ForSequenceClassification', DebertaV2ForSequenceClassification]], @@ -4424,6 +4500,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ ['bert', ['BertForTokenClassification', BertForTokenClassification]], + ['convbert', ['ConvBertForTokenClassification', ConvBertForTokenClassification]], ['camembert', ['CamembertForTokenClassification', CamembertForTokenClassification]], ['deberta', ['DebertaForTokenClassification', DebertaForTokenClassification]], ['deberta-v2', ['DebertaV2ForTokenClassification', DebertaV2ForTokenClassification]], @@ -4466,6 +4543,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ ['bert', ['BertForMaskedLM', BertForMaskedLM]], + ['convbert', ['ConvBertForMaskedLM', ConvBertForMaskedLM]], ['camembert', ['CamembertForMaskedLM', CamembertForMaskedLM]], ['deberta', ['DebertaForMaskedLM', DebertaForMaskedLM]], ['deberta-v2', ['DebertaV2ForMaskedLM', DebertaV2ForMaskedLM]], @@ -4481,6 +4559,7 @@ const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ ['bert', ['BertForQuestionAnswering', BertForQuestionAnswering]], + ['convbert', ['ConvBertForQuestionAnswering', ConvBertForQuestionAnswering]], ['camembert', ['CamembertForQuestionAnswering', CamembertForQuestionAnswering]], ['deberta', ['DebertaForQuestionAnswering', DebertaForQuestionAnswering]], ['deberta-v2', ['DebertaV2ForQuestionAnswering', DebertaV2ForQuestionAnswering]], diff --git a/src/tokenizers.js b/src/tokenizers.js index c0cf7afb8..284428bd7 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2736,6 +2736,12 @@ export class HerbertTokenizer extends PreTrainedTokenizer { return add_token_types(inputs); } } +export class ConvBertTokenizer extends PreTrainedTokenizer { + /** @type {add_token_types} */ + prepare_model_inputs(inputs) { + return add_token_types(inputs); + } +} export class DistilBertTokenizer extends PreTrainedTokenizer { } export class CamembertTokenizer extends PreTrainedTokenizer { } export class XLMTokenizer extends PreTrainedTokenizer { @@ -3860,6 +3866,7 @@ export class AutoTokenizer { DebertaV2Tokenizer, BertTokenizer, HerbertTokenizer, + ConvBertTokenizer, XLMTokenizer, MobileBertTokenizer, SqueezeBertTokenizer,