fix num logits to keep
qgallouedec committed Jan 7, 2025
1 parent d806e31 commit 4ade2eb
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/test_dpo_trainer.py
@@ -1112,7 +1112,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
      model=model,
      ref_model=None,
      args=training_args,
-     tokenizer=tokenizer,
+     processing_class=tokenizer,
      train_dataset=dummy_dataset["train"],
      eval_dataset=dummy_dataset["test"],
  )
@@ -1122,7 +1122,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
      model=model,
      ref_model=None,
      args=training_args,
-     tokenizer=tokenizer,
+     processing_class=tokenizer,
      train_dataset=dummy_dataset["train"],
      eval_dataset=dummy_dataset["test"],
  )
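For context, the test change reflects DPOTrainer taking a processing_class argument in place of the deprecated tokenizer keyword. A minimal sketch of the updated call, assuming the use_num_logits_to_keep flag exercised by this test; the model and dataset names below are illustrative assumptions, not part of this commit:

    # Sketch only: model/dataset names and the config flag are assumptions
    # based on the test name, not taken from this commit.
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    dummy_dataset = load_dataset("trl-lib/ultrafeedback_binarized")

    training_args = DPOConfig(output_dir="dpo_output", use_num_logits_to_keep=True)
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        processing_class=tokenizer,  # was tokenizer=tokenizer before the rename
        train_dataset=dummy_dataset["train"],
        eval_dataset=dummy_dataset["test"],
    )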
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -1161,8 +1161,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
  # [0, 0, 0, x, x, x, 0]]
  #           ^ start computing logits from here ([:, -(7-3+1):])
  first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
- num_logits_to_keep = loss_mask.shape[1] - first_compute_index
- model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1  # +1 for the first label
+ num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
+ model_kwargs["num_logits_to_keep"] = num_logits_to_keep

  if self.padding_free:
      # Flatten the input_ids, position_ids, and loss_mask
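The fix folds the + 1 into num_logits_to_keep itself, presumably so any later use of the variable inside concatenated_forward agrees with the value passed to the model. A toy walk-through of the fixed computation; the loss_mask values are made up to match the comment in the diff above:

    import torch

    # Toy loss mask for two length-7 sequences: 1 marks completion tokens
    # (the x's in the comment above), 0 marks prompt/padding tokens.
    loss_mask = torch.tensor([
        [0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 0],
    ])

    # Column index of the earliest completion token across the batch (here 3).
    first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()

    # Keep logits from that column onward, plus one extra position because the
    # logit at position i predicts the label at position i + 1.
    num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
    print(num_logits_to_keep)  # 5, i.e. slice the logits with [:, -(7-3+1):]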
