Commit dbcb2f0
Allow separate devices for target/ref models. (#1190)
* Allow separate devices for target/ref models.

* Remove original/duplicate.

* Cleanup original, black formatting.

---------

Co-authored-by: Jon Durbin <[email protected]>
jondurbin and jon-convai authored Jan 8, 2024
1 parent d5910b0 commit dbcb2f0
Showing 1 changed file with 15 additions and 2 deletions.
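
What this enables in practice: the policy ("target") model and the frozen reference model no longer need to share a device. Below is a minimal usage sketch, assuming two CUDA devices, a hypothetical SFT checkpoint name, and a placeholder preference dataset; the constructor arguments follow the DPOTrainer API as of this TRL release. Without the .to(...) moves added in this commit, combining log-probabilities produced on different GPUs would raise a device-mismatch error in PyTorch.

    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import DPOTrainer

    # Hypothetical checkpoint name; any SFT'd causal LM would do.
    name = "my-org/my-sft-model"
    tokenizer = AutoTokenizer.from_pretrained(name)

    # Policy (trained) model on one GPU, frozen reference model on another.
    model = AutoModelForCausalLM.from_pretrained(name).to("cuda:0")
    ref_model = AutoModelForCausalLM.from_pretrained(name).to("cuda:1")

    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        beta=0.1,
        args=TrainingArguments(output_dir="dpo-out", per_device_train_batch_size=1),
        train_dataset=preference_dataset,  # placeholder: needs "prompt", "chosen", "rejected" columns
        tokenizer=tokenizer,
    )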
17 changes: 15 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -818,6 +818,8 @@ def dpo_loss(
         else:
             ref_logratios = reference_chosen_logps - reference_rejected_logps
 
+        pi_logratios = pi_logratios.to(self.accelerator.device)
+        ref_logratios = ref_logratios.to(self.accelerator.device)
         logits = pi_logratios - ref_logratios
 
         # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
@@ -853,8 +855,19 @@ def dpo_loss(
                 f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
             )
 
-        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
-        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+        chosen_rewards = (
+            self.beta
+            * (
+                policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
+            ).detach()
+        )
+        rejected_rewards = (
+            self.beta
+            * (
+                policy_rejected_logps.to(self.accelerator.device)
+                - reference_rejected_logps.to(self.accelerator.device)
+            ).detach()
+        )
 
         return losses, chosen_rewards, rejected_rewards
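
For orientation, here is a self-contained PyTorch sketch of the "sigmoid" branch of dpo_loss that these hunks sit in (the hinge/ipo/kto_pair variants are omitted), with a plain device argument standing in for self.accelerator.device. It shows why the explicit .to(...) calls matter once the policy and reference models live on different GPUs.

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
        beta: float = 0.1,
        device: str = "cuda:0",  # stands in for self.accelerator.device
    ):
        # Each model produces its log-probs on whatever device it lives on, so the
        # policy and reference log-ratios may start out on different GPUs.
        pi_logratios = (policy_chosen_logps - policy_rejected_logps).to(device)
        ref_logratios = (reference_chosen_logps - reference_rejected_logps).to(device)
        logits = pi_logratios - ref_logratios

        # "sigmoid" loss type: -log sigmoid(beta * logits); beta acts as a temperature.
        losses = -F.logsigmoid(beta * logits)

        # Implicit rewards, detached so they are only logged, never back-propagated.
        chosen_rewards = beta * (
            policy_chosen_logps.to(device) - reference_chosen_logps.to(device)
        ).detach()
        rejected_rewards = beta * (
            policy_rejected_logps.to(device) - reference_rejected_logps.to(device)
        ).detach()
        return losses, chosen_rewards, rejected_rewards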

