diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f4251b98304c4e..fbba155f19d57c 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1176,11 +1176,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Enable model parallelism + # Ensure tensors are on the same device shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: