diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index 858f270a87f138..fcf22df4fdf142 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -22,7 +22,7 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled @@ -1372,8 +1372,14 @@ def forward( loss = None if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + # move labels to correct device to enable PP + labels = labels.to(logits.device) + if self.config.num_labels > 1: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + else: + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] @@ -1467,8 +1473,16 @@ def forward( loss = None if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1)) + if self.config.num_labels > 1: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + else: + loss_fct = MSELoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1), labels.view(-1)) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]