DPO support remove_unused_columns (#2233)
qgallouedec authored Oct 16, 2024
1 parent 2ba3005 commit 02f4e75
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions trl/trainer/dpo_trainer.py
@@ -714,19 +714,6 @@ def make_inputs_require_grad(module, input, output):
                 is_encoder_decoder=self.is_encoder_decoder,
             )
 
-            if args.remove_unused_columns:
-                args.remove_unused_columns = False
-                # warn users
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
-                    " we have set it for you, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-
-            self.use_dpo_data_collator = True
-        else:
-            self.use_dpo_data_collator = False
-
         if not disable_dropout:
             warnings.warn(
                 "You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
@@ -937,6 +924,23 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         model.eval()
         return model
 
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs.
+        # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+        # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
+        if self._signature_columns is None:
+            self._signature_columns = [
+                "chosen_input_ids",
+                "chosen_attention_mask",
+                "chosen_labels",
+                "rejected_input_ids",
+                "rejected_attention_mask",
+                "rejected_labels",
+                "prompt_input_ids",
+                "prompt_attention_mask",
+            ]
+
     def get_train_dataloader(self) -> DataLoader:
         """
         Returns the training [`~torch.utils.data.DataLoader`].
@@ -1544,12 +1548,6 @@ def compute_loss(
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
-
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
         with compute_loss_context_manager:
             loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
@@ -1616,11 +1614,6 @@ def prediction_step(
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
     ):
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
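Background on the Trainer internals this change hooks into: when `remove_unused_columns=True`, `transformers.Trainer` keeps only the dataset columns listed in `self._signature_columns`, which by default is inferred from the model's `forward()` signature. The datasets DPOTrainer consumes carry preprocessed `prompt_*`, `chosen_*`, and `rejected_*` columns that no model `forward()` declares, so the default inference would strip exactly the columns the collator needs; the override added above pins the list instead. A simplified sketch of the mechanism, not the actual transformers source:

import inspect

def remove_unused_columns(dataset, model, signature_columns=None):
    # Default Trainer behavior: infer the expected columns from model.forward().
    if signature_columns is None:
        signature_columns = list(inspect.signature(model.forward).parameters)
    # Any column not in the signature list is dropped before collation.
    ignored = [c for c in dataset.column_names if c not in signature_columns]
    return dataset.remove_columns(ignored)

Passing the explicit DPO column list as `signature_columns` is what the new override effectively does, which is why `remove_unused_columns=True` becomes safe.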

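The practical effect for users: after this commit, `remove_unused_columns` can stay at its Trainer default of `True`, and the old forced override plus warning are gone. A minimal usage sketch, assuming the TRL API around this commit; the model and dataset names are illustrative, not taken from the diff:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")  # illustrative dataset

# No need to force remove_unused_columns=False anymore: the trainer's
# _set_signature_columns_if_needed override protects the DPO columns.
training_args = DPOConfig(output_dir="dpo-out", remove_unused_columns=True)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # `tokenizer=...` on TRL versions before the processing_class rename
)
trainer.train()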