Skip to content

Commit

Permalink
Move misplaced line (#29117)
Browse files Browse the repository at this point in the history
Move misplaced line, improve code comment
  • Loading branch information
kno10 authored Feb 20, 2024
1 parent 9094abe commit a7ff2f2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,11 +1176,11 @@ def forward(
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
# Ensure tensors are on the same device
shift_labels = shift_labels.to(shift_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
Expand Down

0 comments on commit a7ff2f2

Please sign in to comment.