diff --git a/src/models.js b/src/models.js index 11e51c169..a832bb6c1 100644 --- a/src/models.js +++ b/src/models.js @@ -5524,7 +5524,7 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]], ]); -const MODEL_FOR_XVECTOR_MAPPING_NAMES = new Map([ +const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([ ['wavlm', ['WavLMForXVector', WavLMForXVector]], ]); @@ -5568,7 +5568,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], - [MODEL_FOR_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], ]; for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { @@ -5788,7 +5788,7 @@ export class AutoModelForAudioClassification extends PretrainedMixin { } export class AutoModelForXVector extends PretrainedMixin { - static MODEL_CLASS_MAPPINGS = [MODEL_FOR_XVECTOR_MAPPING_NAMES]; + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES]; } export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin { @@ -5844,13 +5844,13 @@ export class SequenceClassifierOutput extends ModelOutput { } /** - * Base class for outputs of x-vector models. + * Base class for outputs of XVector models. */ export class XVectorOutput extends ModelOutput { /** * @param {Object} output The output of the model. - * @param {Tensor} output.logits classification (or regression if config.num_labels==1) scores (before SoftMax). - * @param {Tensor} output.embeddings The embeddings of the input sequence. + * @param {Tensor} output.logits Classification hidden states before AMSoftmax, of shape `(batch_size, config.xvector_output_dim)`. + * @param {Tensor} output.embeddings Utterance embeddings used for vector similarity-based retrieval, of shape `(batch_size, config.xvector_output_dim)`. */ constructor({ logits, embeddings }) { super();