call orpo_loss_fn with shifted inputs
kashif committed Dec 19, 2024
Commit 5c6744f (parent: f4979b0)
1 changed file with 5 additions and 5 deletions: trl/trainer/orpo_trainer.py
```diff
@@ -366,9 +366,7 @@ def make_inputs_require_grad(module, input, output):
                     "You set `use_liger_loss=True` but the liger kernel is not available. "
                     "Please install liger-kernel first: `pip install liger-kernel`"
                 )
-            self.orpo_loss_fn = LigerFusedLinearORPOLoss(
-                ignore_index=self.label_pad_token_id, beta=self.beta, is_encoder_decoder=self.is_encoder_decoder
-            )
+            self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
 
     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
@@ -782,8 +780,10 @@ def concatenated_forward(
             # return the final loss and aux_outputs tuple
             loss, aux_outputs = self.orpo_loss_fn(
                 lm_head.weight,
-                outputs.last_hidden_state,
-                concatenated_batch["concatenated_labels"],
+                outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+                concatenated_batch["concatenated_labels"][:, 1:]
+                if not self.is_encoder_decoder
+                else concatenated_batch["concatenated_labels"],
                 lm_head.bias if hasattr(lm_head, "bias") else None,
             )
```
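
For context, here is a minimal sketch of why the decoder-only inputs are now shifted at the call site: the hidden state at position t is the one that predicts the token at position t + 1, so the trainer drops the last hidden state and the first label before handing them to the fused linear + ORPO loss. The shapes, the toy lm_head, and the plain cross-entropy stand-in below are illustrative assumptions, not the Liger fused kernel.

```python
# Minimal sketch (not the TRL/Liger implementation) of the shift applied in this commit.
import torch
import torch.nn.functional as F

batch, seq_len, hidden, vocab = 2, 6, 8, 11
label_pad_token_id = -100  # padded/prompt label positions are ignored by the loss

hidden_states = torch.randn(batch, seq_len, hidden)  # stand-in for outputs.last_hidden_state
labels = torch.randint(0, vocab, (batch, seq_len))   # stand-in for concatenated_labels
labels[:, :2] = label_pad_token_id                   # e.g. prompt tokens are masked out
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

# Decoder-only (causal) case: shift so position t lines up with the token it predicts.
shifted_hidden = hidden_states[:, :-1]  # drop the last position
shifted_labels = labels[:, 1:]          # drop the first label

logits = shifted_hidden @ lm_head.weight.T  # (batch, seq_len - 1, vocab)
nll = F.cross_entropy(
    logits.reshape(-1, vocab),
    shifted_labels.reshape(-1),
    ignore_index=label_pad_token_id,
)
print(nll)  # mean NLL over the shifted, unmasked positions
```

For encoder-decoder models the decoder already consumes labels shifted right, so those tensors are passed through unchanged. With the shift handled by the caller, the loss itself no longer needs to know the model type, which is presumably why the `is_encoder_decoder` argument is dropped from the `LigerFusedLinearORPOLoss` constructor above.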

