diff --git a/README.md b/README.md
index 8b3fb7e83..96a8548e0 100644
--- a/README.md
+++ b/README.md
@@ -343,10 +343,10 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
 1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
 1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
 1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
 1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
 1. **[YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
-
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index cf92f6648..6d4082379 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -78,9 +78,9 @@
 1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
 1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
 1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
 1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
 1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
 1. **[YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
-
diff --git a/scripts/convert.py b/scripts/convert.py
index 664c2b8a0..78232aab7 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -93,6 +93,7 @@ MODELS_WITHOUT_TOKENIZERS = [
     'wav2vec2',
+    'wav2vec2-bert',
     'wavlm',
     'hubert',
 ]
@@ -331,7 +332,7 @@ def main():
             **get_main_export_kwargs(config, "automatic-speech-recognition")
         )

-    elif config.model_type in ('wav2vec2', 'hubert'):
+    elif config.model_type in ('wav2vec2', 'wav2vec2-bert', 'hubert'):
         if tokenizer is not None:
             from .extra.wav2vec2 import generate_tokenizer_json
             tokenizer_json = generate_tokenizer_json(tokenizer)
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index d75cfce22..6ca884e43 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -980,6 +980,16 @@
             'facebook/mms-1b-fl102',
         ],
     },
+    'wav2vec2-bert': {
+        'feature-extraction': [
+            'facebook/w2v-bert-2.0',
+        ],
+
+        # Automatic speech recognition
+        'automatic-speech-recognition': [
+            'hf-audio/wav2vec2-bert-CV16-en',
+        ],
+    },
     'wavlm': {
         # Feature extraction
         'feature-extraction': [
diff --git a/src/models.js b/src/models.js
index 190fad3bb..789f853fe 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4527,6 +4527,44 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
 }
 //////////////////////////////////////////////////

+//////////////////////////////////////////////////
+// Wav2Vec2Bert models
+export class Wav2Vec2BertPreTrainedModel extends PreTrainedModel { };
+
+/**
+ * The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class Wav2Vec2BertModel extends Wav2Vec2BertPreTrainedModel { }
+
+/**
+ * Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
+ */
+export class Wav2Vec2BertForCTC extends Wav2Vec2BertPreTrainedModel {
+    /**
+     * @param {Object} model_inputs
+     * @param {Tensor} model_inputs.input_features Float values of input mel-spectrogram.
+     * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
+     */
+    async _call(model_inputs) {
+        return new CausalLMOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output).
+ */
+export class Wav2Vec2BertForSequenceClassification extends Wav2Vec2BertPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
+     */
+    async _call(model_inputs) {
+        return new SequenceClassifierOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 // Hubert models
 export class HubertPreTrainedModel extends PreTrainedModel { }
@@ -5160,6 +5198,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['mobilebert', ['MobileBertModel', MobileBertModel]],
     ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
     ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
+    ['wav2vec2-bert', ['Wav2Vec2BertModel', Wav2Vec2BertModel]],
     ['hubert', ['HubertModel', HubertModel]],
     ['wavlm', ['WavLMModel', WavLMModel]],
     ['audio-spectrogram-transformer', ['ASTModel', ASTModel]],
@@ -5380,12 +5419,14 @@ const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([

 const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
     ['wav2vec2', ['Wav2Vec2ForCTC', Wav2Vec2ForCTC]],
+    ['wav2vec2-bert', ['Wav2Vec2BertForCTC', Wav2Vec2BertForCTC]],
     ['wavlm', ['WavLMForCTC', WavLMForCTC]],
     ['hubert', ['HubertForCTC', HubertForCTC]],
 ]);

 const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]],
+    ['wav2vec2-bert', ['Wav2Vec2BertForSequenceClassification', Wav2Vec2BertForSequenceClassification]],
     ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]],
     ['hubert', ['HubertForSequenceClassification', HubertForSequenceClassification]],
     ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
diff --git a/src/pipelines.js b/src/pipelines.js
index 161094270..a9af251cb 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -1525,6 +1525,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
             case 'whisper':
                 return this._call_whisper(audio, kwargs)
             case 'wav2vec2':
+            case 'wav2vec2-bert':
             case 'hubert':
                 return this._call_wav2vec2(audio, kwargs)
             default:
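With the model classes, auto-model mappings, and pipeline routing above in place, the new architecture is reachable through the existing `automatic-speech-recognition` pipeline, in the same way as `wav2vec2` and `hubert`. A minimal usage sketch, assuming the converted `Xenova/wav2vec2-bert-CV16-en` checkpoint referenced by the tests below and an illustrative audio URL:

```js
import { pipeline } from '@xenova/transformers';

// CTC-based speech recognition via the newly added Wav2Vec2-BERT support.
const transcriber = await pipeline(
    'automatic-speech-recognition',
    'Xenova/wav2vec2-bert-CV16-en',
);

// Any audio URL or Float32Array waveform works here; the pipeline resamples as needed.
const output = await transcriber('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav');
// e.g. { text: '...' }
```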
diff --git a/src/processors.js b/src/processors.js
index 1c85bb533..0dbec5c0c 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -1538,6 +1538,182 @@ export class Wav2Vec2FeatureExtractor extends FeatureExtractor {
     }
 }

+export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
+
+    constructor(config) {
+        super(config);
+
+        const sampling_rate = this.config.sampling_rate;
+        const mel_filters = mel_filter_bank(
+            256, // num_frequency_bins
+            this.config.num_mel_bins, // num_mel_filters
+            20, // min_frequency
+            Math.floor(sampling_rate / 2), // max_frequency
+            sampling_rate, // sampling_rate
+            null, // norm
+            "kaldi", // mel_scale
+            true, // triangularize_in_mel_space
+        );
+
+        // Do padding:
+        for (let i = 0; i < mel_filters.length; ++i) {
+            mel_filters[i].push(0);
+        }
+        this.mel_filters = mel_filters;
+
+        this.window = window_function(400, 'povey', {
+            periodic: false,
+        })
+    }
+
+    /**
+     * Computes the log-Mel spectrogram of the provided audio waveform.
+     * @param {Float32Array|Float64Array} waveform The audio waveform to process.
+     * @param {number} max_length The maximum number of frames to return.
+     * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
+     */
+    _extract_fbank_features(waveform, max_length) {
+        // NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
+
+        // Kaldi compliance: 16-bit signed integers
+        // 32768 == 2 ** 15
+        waveform = waveform.map((/** @type {number} */ x) => x * 32768)
+
+        return spectrogram(
+            waveform,
+            this.window, // window
+            400, // frame_length
+            160, // hop_length
+            {
+                fft_length: 512,
+                power: 2.0,
+                center: false,
+                preemphasis: 0.97,
+                mel_filters: this.mel_filters,
+                log_mel: 'log',
+                mel_floor: 1.192092955078125e-07,
+                remove_dc_offset: true,
+
+                // Custom
+                max_num_frames: max_length,
+                transpose: true,
+            }
+        )
+    }
+
+    /**
+     * Asynchronously extracts features from a given audio using the provided configuration.
+     * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
+     * @param {Object} options Optional parameters for feature extraction.
+     * @param {boolean} [options.padding=true] Whether to pad the sequence to a multiple of `pad_to_multiple_of`.
+     * @param {number} [options.pad_to_multiple_of=2] The number to pad the sequence to a multiple of.
+     * @param {boolean} [options.do_normalize_per_mel_bins=true] Whether or not to zero-mean unit-variance normalize the input per mel-channel.
+     * @param {boolean} [options.return_attention_mask=true] Whether to return the attention mask.
+     * @returns {Promise<{ input_features: Tensor, attention_mask?: Tensor }>} A Promise resolving to an object containing the extracted input features and attention masks as Tensors.
+     */
+    async _call(audio, {
+        padding = true,
+        pad_to_multiple_of = 2,
+        do_normalize_per_mel_bins = true,
+        return_attention_mask = true,
+    } = {}) {
+        validate_audio_inputs(audio, 'SeamlessM4TFeatureExtractor');
+
+        let features = this._extract_fbank_features(audio, this.config.max_length);
+
+        if (do_normalize_per_mel_bins) {
+            const [num_features, feature_size] = features.dims;
+            for (let i = 0; i < feature_size; ++i) {
+                let sum = 0;
+                for (let j = 0; j < num_features; ++j) {
+                    sum += features.data[j * feature_size + i];
+                }
+
+                const mean = sum / num_features;
+
+                let variance = 0;
+                for (let j = 0; j < num_features; ++j) {
+                    variance += (features.data[j * feature_size + i] - mean) ** 2;
+                }
+                variance /= num_features - 1; // NOTE: We use ddof=1
+
+                const std = Math.sqrt(variance + 1e-7);
+                for (let j = 0; j < num_features; ++j) {
+                    const index = j * feature_size + i;
+                    features.data[index] = (features.data[index] - mean) / std;
+                }
+            }
+        }
+
+        let padded_attention_mask;
+        if (padding) {
+            const [num_frames, num_channels] = features.dims;
+
+            const pad_size = num_frames % pad_to_multiple_of;
+            if (pad_size > 0) {
+                const padded_data = new Float32Array(num_channels * (num_frames + pad_size));
+                padded_data.set(features.data)
+                padded_data.fill(this.config.padding_value, features.data.length)
+
+                const numPaddedFrames = num_frames + pad_size;
+                features = {
+                    data: padded_data,
+                    dims: [numPaddedFrames, num_channels],
+                }
+
+                if (return_attention_mask) {
+                    padded_attention_mask = new Tensor(
+                        'int64',
+                        new BigInt64Array(numPaddedFrames),
+                        [1, numPaddedFrames],
+                    )
+                    padded_attention_mask.data.fill(1n, 0, num_frames);
+                }
+            }
+        }
+
+        const [num_frames, num_channels] = features.dims;
+
+        const stride = this.config.stride;
+        const remainder = num_frames % stride;
+        if (remainder !== 0) {
+            throw new Error(`The number of frames (${num_frames}) must be a multiple of the stride (${stride}).`)
+        }
+
+        const input_features = new Tensor('float32',
+            features.data,
+            features.dims,
+        ).view(
+            1,
+            Math.floor(num_frames / stride),
+            num_channels * stride,
+        );
+
+        const result = { input_features }
+
+        if (return_attention_mask) {
+            const reshapedNumFrames = input_features.dims[1];
+
+            const attention_mask = new Tensor(
+                'int64',
+                new BigInt64Array(reshapedNumFrames),
+                [1, reshapedNumFrames],
+            );
+            if (padded_attention_mask) {
+                for (let i = 1, j = 0; i < num_frames; i += stride, ++j) {
+                    attention_mask.data[j] = padded_attention_mask.data[i];
+                }
+            } else {
+                attention_mask.data.fill(1n);
+            }
+
+            result.attention_mask = attention_mask;
+        }
+
+        return result;
+    }
+}
+
 export class ASTFeatureExtractor extends FeatureExtractor {
@@ -1944,6 +2120,7 @@ export class AutoProcessor {
         SamImageProcessor,
         Swin2SRImageProcessor,
         Wav2Vec2FeatureExtractor,
+        SeamlessM4TFeatureExtractor,
         SpeechT5FeatureExtractor,
         ASTFeatureExtractor,
         ClapFeatureExtractor,
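The final `view` in `_call` stacks `stride` consecutive frames into a single feature vector: with 80 mel bins and a stride of 2 (the values implied by the `[1, ..., 160]` shapes asserted in the tests below), a `[num_frames, 80]` Kaldi-style fbank matrix becomes a `[1, num_frames / 2, 160]` input, after `num_frames` is first padded up to a multiple of 2. A sketch of calling the extractor directly, mirroring the test; `read_audio` and the URL are assumptions for illustration:

```js
import { AutoProcessor, read_audio } from '@xenova/transformers';

// Loads preprocessor_config.json and instantiates the new SeamlessM4TFeatureExtractor.
const processor = await AutoProcessor.from_pretrained('Xenova/wav2vec2-bert-CV16-en');

// 16 kHz mono waveform.
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000);

const { input_features, attention_mask } = await processor(audio);
// input_features.dims -> [1, ceil(num_frames / 2), 160]
// attention_mask.dims -> [1, ceil(num_frames / 2)]
```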
diff --git a/src/utils/audio.js b/src/utils/audio.js
index 082870de8..9c2382bda 100644
--- a/src/utils/audio.js
+++ b/src/utils/audio.js
@@ -647,6 +647,9 @@ export function window_function(window_length, name, {
         case 'hann_window':
             window = hanning(length);
             break;
+        case 'povey':
+            window = hanning(length).map(x => Math.pow(x, 0.85));
+            break;
         default:
             throw new Error(`Unknown window type ${name}.`);
     }
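For reference, the Kaldi-style `'povey'` window added above is simply a Hann window raised to the power 0.85. A standalone sketch of that relationship (an independent re-implementation for illustration, not the library's `hanning` helper):

```js
// Symmetric Hann window: w[i] = 0.5 * (1 - cos(2 * pi * i / (N - 1)))
const hann = (N) => Float64Array.from(
    { length: N },
    (_, i) => 0.5 * (1 - Math.cos(2 * Math.PI * i / (N - 1))),
);

// Povey window: element-wise hann(N) ** 0.85, matching the `case 'povey'` branch above.
const povey = (N) => hann(N).map((x) => Math.pow(x, 0.85));

console.log(povey(400).slice(0, 4)); // tapers up from 0, slightly wider than the plain Hann window
```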
diff --git a/tests/processors.test.js b/tests/processors.test.js
index 2bf0bd3d1..38f47bb17 100644
--- a/tests/processors.test.js
+++ b/tests/processors.test.js
@@ -7,9 +7,8 @@ import { compare } from './test_utils.js';
 env.allowLocalModels = false;
 env.useFSCache = false;

-const avg = (array) => {
-    return Number(array.reduce((a, b) => a + b, array instanceof BigInt64Array ? 0n : 0)) / array.length;
-}
+const sum = array => Number(array.reduce((a, b) => a + b, array instanceof BigInt64Array ? 0n : 0));
+const avg = array => sum(array) / array.length;

 describe('Processors', () => {
@@ -478,6 +477,42 @@
             }
         }, MAX_TEST_EXECUTION_TIME);

+        it('SeamlessM4TFeatureExtractor', async () => {
+            const audio = await audioPromise;
+            const processor = await AutoProcessor.from_pretrained('Xenova/wav2vec2-bert-CV16-en');
+            { // normal
+                const { input_features, attention_mask } = await processor(audio);
+                compare(input_features.dims, [1, 649, 160]);
+                compare(attention_mask.dims, [1, 649]);
+
+                expect(avg(input_features.data)).toBeCloseTo(-2.938903875815413e-08);
+                expect(input_features.data[0]).toBeCloseTo(1.1939343214035034);
+                expect(input_features.data[1]).toBeCloseTo(0.7874255180358887);
+                expect(input_features.data[160]).toBeCloseTo(-0.712975025177002);
+                expect(input_features.data[161]).toBeCloseTo(0.045802414417266846);
+                expect(input_features.data.at(-1)).toBeCloseTo(-1.3328346014022827);
+
+                expect(sum(attention_mask.data)).toEqual(649);
+            }
+            { // padding (pad_to_multiple_of=2)
+                const { input_features, attention_mask } = await processor(audio.slice(0, 10000));
+
+                // [1, 61, 80] -> [1, 62, 80] -> [1, 31, 160]
+                compare(input_features.dims, [1, 31, 160]);
+                compare(attention_mask.dims, [1, 31]);
+
+                expect(avg(input_features.data)).toBeCloseTo(0.01612919569015503);
+                expect(input_features.data[0]).toBeCloseTo(0.9657132029533386);
+                expect(input_features.data[1]).toBeCloseTo(0.12912897765636444);
+                expect(input_features.data[160]).toBeCloseTo(-1.2364212274551392);
+                expect(input_features.data[161]).toBeCloseTo(-0.9703778028488159);
+                expect(input_features.data.at(-1)).toBeCloseTo(1); // padding value
+
+                expect(sum(attention_mask.data)).toEqual(30);
+            }
+        }, MAX_TEST_EXECUTION_TIME);
+
         it('ClapFeatureExtractor', async () => {
             const audio = await audioPromise;
             const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');