Add WavLM- & Wav2Vec2ForAudioFrameClassification support (#611)
* Add WavLMForXVector support

* fix model docs

* Add WavLMForAudioFrameClassification

* Add missing Wav2Vec2ForAudioFrameClassification

* Add doc comment

* Add doc string wav2vec2

* update comment

* make example like python

* Update src/models.js

---------

Co-authored-by: Joshua Lochner <[email protected]>
D4ve-R and xenova authored Mar 7, 2024
1 parent 5bb8d25 commit 8eef154
Showing 1 changed file with 68 additions and 0 deletions.
src/models.js
@@ -4571,6 +4571,20 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}

/**
* Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
*/
export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}
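
Unlike its WavLM counterpart added further down, this class lands without a usage example. A minimal sketch of how it would be called, assuming an ONNX checkpoint with a wav2vec2 frame-classification head (the `Xenova/wav2vec2-base-sd` model id is hypothetical):

```javascript
import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';

// Hypothetical checkpoint name: substitute any ONNX-converted wav2vec2
// model with an audio-frame-classification (speaker diarization) head.
const model_id = 'Xenova/wav2vec2-base-sd';

// Read and preprocess audio at the 16 kHz rate the feature extractor expects
const processor = await AutoProcessor.from_pretrained(model_id);
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const audio = await read_audio(url, 16000);
const inputs = await processor(audio);

// Run the model; `logits` has shape [batch_size, num_frames, num_speakers]
const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id);
const { logits } = await model(inputs);
```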
//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -4868,6 +4882,54 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
}
}

/**
* WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
*
* **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`.
* ```javascript
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* const audio = await read_audio(url, 16000);
* const inputs = await processor(audio);
*
* // Run model with inputs
* const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
* const { logits } = await model(inputs);
* // {
* // logits: Tensor {
* // dims: [ 1, 549, 2 ], // [batch_size, num_frames, num_speakers]
* // type: 'float32',
* // data: Float32Array(1098) [-3.5301010608673096, ...],
* // size: 1098
* // }
* // }
*
* const labels = logits[0].sigmoid().tolist().map(
* frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
* );
 * console.log(labels); // labels is a binary array of shape (num_frames, num_speakers)
* // [
* // [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
* // [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
* // [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
* // ...
* // ]
* ```
*/
export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}
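
The per-frame labels produced in the example above are often more useful as time-stamped segments. A small follow-up sketch, not part of this commit, assuming frames are evenly spaced over the clip and taking the `labels` array plus a known clip `duration` in seconds:

```javascript
/**
 * Convert per-frame binary speaker labels into time-stamped segments.
 * @param {number[][]} labels Array of shape (num_frames, num_speakers), as above.
 * @param {number} duration Clip length in seconds (frames assumed evenly spaced).
 */
function labelsToSegments(labels, duration) {
    const frameDuration = duration / labels.length;
    const segments = [];
    for (let s = 0; s < labels[0].length; ++s) {
        let start = null;
        // Iterate one past the end so a segment still open at the last frame is closed
        for (let f = 0; f <= labels.length; ++f) {
            const active = f < labels.length && labels[f][s] === 1;
            if (active && start === null) {
                start = f; // speaker s becomes active
            } else if (!active && start !== null) {
                segments.push({ speaker: s, start: start * frameDuration, end: f * frameDuration });
                start = null; // speaker s becomes inactive
            }
        }
    }
    return segments; // each entry: which speaker was active, and over which time span
}
```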

//////////////////////////////////////////////////
// SpeechT5 models
/**
@@ -5695,6 +5757,8 @@ const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([

const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]],
['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]],
['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]],
]);

const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
@@ -5961,6 +6025,10 @@ export class AutoModelForXVector extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
}

export class AutoModelForAudioFrameClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES];
}
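
With the two new entries registered in `MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES` above, this auto class can resolve the concrete implementation from a checkpoint's `model_type`. A quick sketch, reusing the checkpoint from the WavLM example:

```javascript
import { AutoModelForAudioFrameClassification } from '@xenova/transformers';

// `from_pretrained` reads `model_type` from the checkpoint's config.json and
// looks it up in the frame-classification mapping registered above.
const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
console.log(model.constructor.name); // 'WavLMForAudioFrameClassification'
```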

export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
}
