Skip to content

Commit

Permalink
🧘 Replace F.log(F.sigmoid(log_odds) with F.logsigmoid(log_odds) (#…
Browse files Browse the repository at this point in the history
…2274)

Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
zhanwenchen and qgallouedec authored Oct 24, 2024
1 parent 0de75b2 commit 57ba9b9
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,7 @@ def odds_ratio_loss(
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
)
sig_ratio = F.sigmoid(log_odds)
ratio = torch.log(sig_ratio)
ratio = F.logsigmoid(log_odds)
losses = self.beta * ratio

chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
Expand Down

0 comments on commit 57ba9b9

Please sign in to comment.