Skip to content

Commit

Permalink
Before update the tr_loss, make sure tr_loss_step is in the same devi…
Browse files Browse the repository at this point in the history
…ce. (#1439)

* before update the loss from dpo, make sure it's in the same device of tr_loss

* Update trl/trainer/dpo_trainer.py

Co-authored-by: guy1992l <[email protected]>

---------

Co-authored-by: guy1992l <[email protected]>
  • Loading branch information
pengwei715 and guy1992l authored Mar 19, 2024
1 parent abc7301 commit f976c6d
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ def compute_loss(
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
loss = loss.to(self.args.device)
# force log the metrics
self.store_metrics(metrics, train_eval="train")

Expand Down

0 comments on commit f976c6d

Please sign in to comment.