Add support for WavLMForXVector (#603)
* Add WavLMForXVector support

* fix model docs

* fix bad naming

* Apply suggestions from code review

* Update default `wavlm` quantization settings

* Update list of supported `wavlm` models

* Update JSDoc

* Fix typo

---------

Co-authored-by: Joshua Lochner <[email protected]>
D4ve-R and xenova authored Feb 28, 2024
1 parent 271c6f1 commit b5a548f
Showing 4 changed files with 79 additions and 1 deletion.
4 changes: 4 additions & 0 deletions scripts/convert.py
@@ -99,6 +99,10 @@
'per_channel': False,
'reduce_range': False,
},
'wavlm': {
'per_channel': False,
'reduce_range': False,
},
}

MODELS_WITHOUT_TOKENIZERS = [
6 changes: 6 additions & 0 deletions scripts/supported_models.py
@@ -1019,6 +1019,12 @@
'microsoft/wavlm-base-plus',
'microsoft/wavlm-large',
],

# Audio XVector (e.g., for speaker verification)
'audio-xvector': [
'microsoft/wavlm-base-plus-sv',
'microsoft/wavlm-base-sv',
],
},
'whisper': {
# Automatic speech recognition
68 changes: 68 additions & 0 deletions src/models.js
@@ -4735,6 +4735,49 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
}
}

/**
* WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
*
* **Example:** Extract speaker embeddings with `WavLMForXVector`.
* ```javascript
* import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* const audio = await read_audio(url, 16000);
* const inputs = await processor(audio);
*
* // Run model with inputs
* const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');
* const outputs = await model(inputs);
* // {
* // logits: Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [0.5847219228744507, ...],
* // size: 512
* // },
* // embeddings: Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [-0.09079201519489288, ...],
* // size: 512
* // }
* // }
* ```
*/
export class WavLMForXVector extends WavLMPreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<XVectorOutput>} An object containing the model's output logits and speaker embeddings.
*/
async _call(model_inputs) {
return new XVectorOutput(await super._call(model_inputs));
}
}

//////////////////////////////////////////////////
// SpeechT5 models
/**
@@ -5483,6 +5526,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);

const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
['wavlm', ['WavLMForXVector', WavLMForXVector]],
]);

const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
]);
@@ -5523,6 +5570,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
];

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5741,6 +5789,10 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES];
}

export class AutoModelForXVector extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
}
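
// A minimal sketch of loading the new checkpoint through this task-specific auto class
// rather than the generic `AutoModel` used in the JSDoc example above (assuming
// `AutoModelForXVector` is re-exported from the package entry point like the other auto classes):
//
// import { AutoProcessor, AutoModelForXVector, read_audio } from '@xenova/transformers';
//
// const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
// const model = await AutoModelForXVector.from_pretrained('Xenova/wavlm-base-plus-sv');
//
// // Preprocess 16kHz mono audio and extract the speaker embedding
// const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
// const { embeddings } = await model(await processor(await read_audio(url, 16000)));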

export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
}
@@ -5793,6 +5845,22 @@ export class SequenceClassifierOutput extends ModelOutput {
}
}

/**
* Base class for outputs of XVector models.
*/
export class XVectorOutput extends ModelOutput {
/**
* @param {Object} output The output of the model.
* @param {Tensor} output.logits Classification hidden states before AMSoftmax, of shape `(batch_size, config.xvector_output_dim)`.
* @param {Tensor} output.embeddings Utterance embeddings used for vector similarity-based retrieval, of shape `(batch_size, config.xvector_output_dim)`.
*/
constructor({ logits, embeddings }) {
super();
this.logits = logits;
this.embeddings = embeddings;
}
}
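
// The `embeddings` tensor is what a speaker-verification flow compares across utterances.
// Below is a minimal sketch using the library's `cos_sim` helper; the audio URLs and the
// 0.86 decision threshold are placeholders for illustration and should be tuned per application.
//
// import { AutoProcessor, AutoModel, read_audio, cos_sim } from '@xenova/transformers';
//
// const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
// const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');
//
// // Compute a speaker embedding (Float32Array of length 512) for one 16kHz clip
// async function embed(url) {
//     const audio = await read_audio(url, 16000);
//     const { embeddings } = await model(await processor(audio));
//     return embeddings.data;
// }
//
// // Compare two utterances with cosine similarity
// const a = await embed('https://example.com/speaker1_utt1.wav'); // placeholder URL
// const b = await embed('https://example.com/speaker1_utt2.wav'); // placeholder URL
// const similarity = cos_sim(Array.from(a), Array.from(b));
// const sameSpeaker = similarity >= 0.86; // illustrative threshold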

/**
* Base class for outputs of token classification models.
*/
2 changes: 1 addition & 1 deletion src/processors.js
@@ -158,7 +158,7 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
function validate_audio_inputs(audio, feature_extractor) {
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
`${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` +
`${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
}
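
// As the error message suggests, raw audio should be obtained with `read_audio` before it
// reaches a feature extractor. A minimal sketch (the audio URL is a placeholder):
//
// import { AutoProcessor, read_audio } from '@xenova/transformers';
//
// const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
//
// // read_audio decodes and resamples the file to a Float32Array at the requested
// // sampling rate, which is the type validate_audio_inputs checks for.
// const audio = await read_audio('https://example.com/audio.wav', 16000); // placeholder URL
// const inputs = await processor(audio); // passes validation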
