diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index f94522923b..50392526db 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -774,18 +774,12 @@ def cross_entropy_loss(logits, labels):
             loss = loss_fct(logits, labels)
             return loss
 
-        if self.is_encoder_decoder:
-            labels = concatenated_batch["concatenated_labels"].clone()
-        else:
-            labels = concatenated_batch["concatenated_input_ids"].clone()
-            attention_mask = concatenated_batch["concatenated_attention_mask"]
-            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
-
+        labels = concatenated_batch["concatenated_labels"].clone()
         chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         all_logps = self.get_batch_logps(
             all_logits,
-            concatenated_batch["concatenated_labels"],
+            labels,
             average_log_prob=True,
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
@@ -794,8 +788,12 @@ def cross_entropy_loss(logits, labels):
         chosen_logps = all_logps[:len_chosen]
         rejected_logps = all_logps[len_chosen:]
 
-        chosen_logits = all_logits[:len_chosen]
-        rejected_logits = all_logits[len_chosen:]
+        if not self.is_encoder_decoder:
+            chosen_logits = all_logits[:len_chosen, :-1, :]
+            rejected_logits = all_logits[len_chosen:, :-1, :]
+        else:
+            chosen_logits = all_logits[:len_chosen]
+            rejected_logits = all_logits[len_chosen:]
 
         if self.aux_loss_enabled:
             return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
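
For context, a minimal sketch (not the trainer code itself) of why the decoder-only branch now returns `all_logits[..., :-1, :]`: in a causal LM the logit at position t predicts the token at position t + 1, so dropping the final logit step keeps the returned logits aligned with the shifted labels that `cross_entropy_loss` trains against. The labels change is complementary: `concatenated_labels` already has prompt/padding positions masked with `label_pad_token_id`, so both the NLL loss and `get_batch_logps` now consume the same masked tensor instead of rebuilding it from input IDs and the attention mask. Tensor names and shapes below are illustrative assumptions, not values from the trainer.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only (assumptions, not taken from the trainer).
batch_size, seq_len, vocab_size = 2, 6, 10
all_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Causal-LM shift: the logit at step t scores the token at step t + 1,
# so drop the last logit and the first label to align the two tensors.
shift_logits = all_logits[:, :-1, :]  # (batch, seq_len - 1, vocab)
shift_labels = labels[:, 1:]          # (batch, seq_len - 1)
assert shift_logits.shape[:2] == shift_labels.shape

nll = F.cross_entropy(shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1))
print(nll)
```

Encoder-decoder models already emit one logit per label position, which is why the `else` branch keeps the unsliced logits.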