From ca515e546e6567c0ba92c7be8d0e7ca3f2164495 Mon Sep 17 00:00:00 2001
From: nevikw39
Date: Tue, 6 Feb 2024 15:43:50 +0800
Subject: [PATCH] Apply our fix to brand new model Wav2Vec2-BERT!

---
 .../wav2vec2_bert/modeling_wav2vec2_bert.py | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
index 858f270a87f138..fcf22df4fdf142 100644
--- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
+++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
@@ -22,7 +22,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -1372,8 +1372,14 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            # move labels to correct device to enable PP
+            labels = labels.to(logits.device)
+            if self.config.num_labels > 1:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            else:
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
@@ -1467,8 +1473,16 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+            if self.config.num_labels > 1:
+                loss_fct = CrossEntropyLoss()
+                # move labels to correct device to enable PP
+                labels = labels.to(logits.device)
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            else:
+                loss_fct = MSELoss()
+                # move labels to correct device to enable PP
+                labels = labels.to(logits.device)
+                loss = loss_fct(logits.view(-1), labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
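
For context, a minimal standalone sketch of the loss selection this patch applies in both heads: cross-entropy when the model is configured for classification (num_labels > 1), mean-squared error for a single regression output, with labels moved to the logits device so pipeline-parallel setups work. The helper name compute_loss and the dummy tensors below are illustrative assumptions, not part of the patch or the transformers API.

    import torch
    from torch.nn import CrossEntropyLoss, MSELoss

    def compute_loss(logits, labels, num_labels):
        # move labels to the logits device so pipeline-parallel (PP) runs don't fail
        labels = labels.to(logits.device)
        if num_labels > 1:
            # classification: flatten to (batch * seq, num_labels) vs (batch * seq,)
            return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
        # regression: single output per example, compare flattened values directly
        return MSELoss()(logits.view(-1), labels.view(-1))

    # usage with dummy tensors: a single-output head treated as regression
    logits = torch.randn(4, 1)
    labels = torch.randn(4)
    print(compute_loss(logits, labels, num_labels=1))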