diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index ba3bb3a32e..6ab8fba5a4 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -818,6 +818,8 @@ def dpo_loss(
         else:
             ref_logratios = reference_chosen_logps - reference_rejected_logps
 
+        pi_logratios = pi_logratios.to(self.accelerator.device)
+        ref_logratios = ref_logratios.to(self.accelerator.device)
         logits = pi_logratios - ref_logratios
 
         # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
@@ -853,8 +855,19 @@ def dpo_loss(
                 f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
             )
 
-        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
-        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+        chosen_rewards = (
+            self.beta
+            * (
+                policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
+            ).detach()
+        )
+        rejected_rewards = (
+            self.beta
+            * (
+                policy_rejected_logps.to(self.accelerator.device)
+                - reference_rejected_logps.to(self.accelerator.device)
+            ).detach()
+        )
 
         return losses, chosen_rewards, rejected_rewards
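
For context, here is a minimal standalone sketch of what the patched code computes for the `"sigmoid"` loss type, with the device alignment applied up front. The function name `dpo_sigmoid_loss`, the `device` argument, and the use of plain `torch` tensors are illustrative assumptions, not part of the patch or the `DPOTrainer` API.

```python
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps,
                     beta=0.1, device="cpu"):
    """Compute DPO losses and rewards with all inputs moved to one device."""
    # Align devices first, mirroring the .to(self.accelerator.device) calls
    # the patch adds inside DPOTrainer.dpo_loss.
    policy_chosen_logps = policy_chosen_logps.to(device)
    policy_rejected_logps = policy_rejected_logps.to(device)
    reference_chosen_logps = reference_chosen_logps.to(device)
    reference_rejected_logps = reference_rejected_logps.to(device)

    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios

    # "sigmoid" variant of the DPO loss; beta acts as a temperature.
    losses = -F.logsigmoid(beta * logits)
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards

# Usage sketch: once everything is moved to a common device, combining
# policy and reference log-probs no longer risks a device-mismatch error.
batch = 4
losses, chosen, rejected = dpo_sigmoid_loss(
    torch.randn(batch), torch.randn(batch),
    torch.randn(batch), torch.randn(batch),
    beta=0.1, device="cpu",
)
print(losses.shape, chosen.shape, rejected.shape)
```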