diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 123f935208..ccbe254019 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -666,8 +666,7 @@ def odds_ratio_loss( log_odds = (policy_chosen_logps - policy_rejected_logps) - ( torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) ) - sig_ratio = F.sigmoid(log_odds) - ratio = torch.log(sig_ratio) + ratio = F.logsigmoid(log_odds) losses = self.beta * ratio chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()