Skip to content

Commit

Permalink
Apply our fix to the brand-new Wav2Vec2-BERT model!
Browse files Browse the repository at this point in the history
  • Loading branch information
nevikw39 committed Feb 6, 2024
1 parent 6e7e8e1 commit ca515e5
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
Expand Down Expand Up @@ -1372,8 +1372,14 @@ def forward(

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
# move labels to correct device to enable PP
labels = labels.to(logits.device)
if self.config.num_labels > 1:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
else:
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
Expand Down Expand Up @@ -1467,8 +1473,16 @@ def forward(

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
if self.config.num_labels > 1:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
else:
loss_fct = MSELoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
Expand Down

0 comments on commit ca515e5

Please sign in to comment.