From 4ade2eb83c2cb63da3621ea644cb4e2f68cf6681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Jan 2025 19:40:02 +0000 Subject: [PATCH] fix num logits to keep --- tests/test_dpo_trainer.py | 4 ++-- trl/trainer/dpo_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index ea330ce17a..a786f6a41e 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1112,7 +1112,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self): model=model, ref_model=None, args=training_args, - tokenizer=tokenizer, + processing_class=tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], ) @@ -1122,7 +1122,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self): model=model, ref_model=None, args=training_args, - tokenizer=tokenizer, + processing_class=tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index f0cc2dfc2a..d51e7f7f1a 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1161,8 +1161,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to # [0, 0, 0, x, x, x, 0]] # ^ start computing logits from here ([:, -(7-3+1):]) first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - num_logits_to_keep = loss_mask.shape[1] - first_compute_index - model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label + num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label + model_kwargs["num_logits_to_keep"] = num_logits_to_keep if self.padding_free: # Flatten the input_ids, position_ids, and loss_mask