Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova authored Feb 28, 2024
1 parent 69f5652 commit c3c6e01
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -5524,7 +5524,7 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);

const MODEL_FOR_XVECTOR_MAPPING_NAMES = new Map([
const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
['wavlm', ['WavLMForXVector', WavLMForXVector]],
]);

Expand Down Expand Up @@ -5568,7 +5568,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
];

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
Expand Down Expand Up @@ -5788,7 +5788,7 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
}

export class AutoModelForXVector extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_XVECTOR_MAPPING_NAMES];
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
}

export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
Expand Down Expand Up @@ -5844,13 +5844,13 @@ export class SequenceClassifierOutput extends ModelOutput {
}

/**
* Base class for outputs of x-vector models.
* Base class for outputs of XVector models.
*/
export class XVectorOutput extends ModelOutput {
/**
* @param {Object} output The output of the model.
* @param {Tensor} output.logits classification (or regression if config.num_labels==1) scores (before SoftMax).
* @param {Tensor} output.embeddings The embeddings of the input sequence.
* @param {Tensor} output.logits Classification hidden states before AMSoftmax, of shape `(batch_size, config.xvector_output_dim)`.
* @param {Tensor} output.embeddings Utterance embeddings used for vector similarity-based retrieval, of shape `(batch_size, config.xvector_output_dim)`.
*/
constructor({ logits, embeddings }) {
super();
Expand Down

0 comments on commit c3c6e01

Please sign in to comment.