Skip to content

Commit

Permalink
🔮 Fix unused precomputed ref log probs in DPO (#2431)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakru012 authored Dec 3, 2024
1 parent 9001a86 commit 9ff79a6
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
pixel_values = [torch.tensor(example["pixel_values"]) for example in examples]
if "pixel_attention_mask" in examples[0]:
pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples]
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])

# Pad
output = {}
Expand All @@ -145,6 +148,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
if "image_sizes" in examples[0]:
output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
output["ref_chosen_logps"] = ref_chosen_logps
output["ref_rejected_logps"] = ref_rejected_logps

return output

Expand All @@ -162,7 +168,7 @@ class DPOTrainer(Trainer):
args (`DPOConfig`):
The DPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
The data collator to use for training. If None is specified, the default data collator (`PreferenceCollator`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
Expand Down Expand Up @@ -672,9 +678,16 @@ def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
# Instead, we set them to the columns expected by `PreferenceCollator`, hence the override.
if self._signature_columns is None:
self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes"]
self._signature_columns = [
"prompt_input_ids",
"chosen_input_ids",
"rejected_input_ids",
"image_sizes",
"ref_chosen_logps",
"ref_rejected_logps",
]

def get_train_dataloader(self) -> DataLoader:
"""
Expand Down

0 comments on commit 9ff79a6

Please sign in to comment.