Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for WavlmForXVector #603

Merged
merged 8 commits into from
Feb 28, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -4735,6 +4735,47 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
}
}

/**
* WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
*
* **Example:** Extract speaker embeddings with `WavLMForXVector`.
* ```javascript
* import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
*
* const processor = await AutoProcessor.from_pretrained('D4ve-R/wavlm-base-plus-sv');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* const audio = await read_audio(url, 16000);
* const inputs = await processor(audio);

* const model = await AutoModel.from_pretrained('D4ve-R/wavlm-base-plus-sv');
* const embeddings = await model(inputs);
* // {
* // embeddings: Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [-0.349443256855011, ...],
* // size: 512
* // },
* // logits: Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [0.022836603224277496, ...],
* // size: 512
* // }
* // }
* ```
*/
export class WavLMForXVector extends WavLMPreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<XVectorOutput>} An object containing the model's output logits for sequence classification.
*/
async _call(model_inputs) {
return new XVectorOutput(await super._call(model_inputs));
}
}

//////////////////////////////////////////////////
// SpeechT5 models
/**
Expand Down Expand Up @@ -5483,6 +5524,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);

const MODEL_FOR_XVECTOR_MAPPING_NAMES = new Map([
xenova marked this conversation as resolved.
Show resolved Hide resolved
['wavlm', ['WavLMForXVector', WavLMForXVector]],
]);

const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
]);
Expand Down Expand Up @@ -5523,6 +5568,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
xenova marked this conversation as resolved.
Show resolved Hide resolved
];

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
Expand Down Expand Up @@ -5741,6 +5787,10 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES];
}

export class AutoModelForXVector extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_XVECTOR_MAPPING_NAMES];
xenova marked this conversation as resolved.
Show resolved Hide resolved
}

export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
}
Expand Down Expand Up @@ -5793,6 +5843,22 @@ export class SequenceClassifierOutput extends ModelOutput {
}
}

/**
* Base class for outputs of x-vector models.
*/
export class XVectorOutput extends ModelOutput {
/**
* @param {Object} output The output of the model.
* @param {Tensor} output.logits classification (or regression if config.num_labels==1) scores (before SoftMax).
* @param {Tensor} output.embeddings The embeddings of the input sequence.
*/
constructor({ logits, embeddings }) {
super();
this.logits = logits;
this.embeddings = embeddings;
}
}

xenova marked this conversation as resolved.
Show resolved Hide resolved
/**
* Base class for outputs of token classification models.
*/
Expand Down