diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index ea3f24a39d..fe43b133d9 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -130,6 +130,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
             pixel_values = [torch.tensor(example["pixel_values"]) for example in examples]
         if "pixel_attention_mask" in examples[0]:
             pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples]
+        if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
+            ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
+            ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])
 
         # Pad
         output = {}
@@ -145,6 +148,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
             output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
         if "image_sizes" in examples[0]:
             output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
+        if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
+            output["ref_chosen_logps"] = ref_chosen_logps
+            output["ref_rejected_logps"] = ref_rejected_logps
 
         return output
 
@@ -162,7 +168,7 @@ class DPOTrainer(Trainer):
         args (`DPOConfig`):
             The DPO config arguments to use for training.
         data_collator (`transformers.DataCollator`):
-            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
+            The data collator to use for training. If None is specified, the default data collator (`PreferenceCollator`) will be used
             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
         train_dataset (`datasets.Dataset`):
             The dataset to use for training.
@@ -672,9 +678,16 @@ def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
         # By default, this method sets `self._signature_columns` to the model's expected inputs.
         # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
-        # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
+        # Instead, we set them to the columns expected by `PreferenceCollator`, hence the override.
         if self._signature_columns is None:
-            self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes"]
+            self._signature_columns = [
+                "prompt_input_ids",
+                "chosen_input_ids",
+                "rejected_input_ids",
+                "image_sizes",
+                "ref_chosen_logps",
+                "ref_rejected_logps",
+            ]
 
     def get_train_dataloader(self) -> DataLoader:
         """