From 02f4e750c07c5a470f2d82a3a59e011401b5c63e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
 <45557362+qgallouedec@users.noreply.github.com>
Date: Wed, 16 Oct 2024 10:00:27 +0200
Subject: [PATCH] DPO support `remove_unused_columns` (#2233)

---
 trl/trainer/dpo_trainer.py | 41 ++++++++++++++++----------------------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 30dffc1f78..8b9843cd91 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -714,19 +714,6 @@ def make_inputs_require_grad(module, input, output):
                 is_encoder_decoder=self.is_encoder_decoder,
             )
 
-            if args.remove_unused_columns:
-                args.remove_unused_columns = False
-                # warn users
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
-                    " we have set it for you, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-
-            self.use_dpo_data_collator = True
-        else:
-            self.use_dpo_data_collator = False
-
         if not disable_dropout:
             warnings.warn(
                 "You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
@@ -937,6 +924,23 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         model.eval()
         return model
 
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs.
+        # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+        # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
+        if self._signature_columns is None:
+            self._signature_columns = [
+                "chosen_input_ids",
+                "chosen_attention_mask",
+                "chosen_labels",
+                "rejected_input_ids",
+                "rejected_attention_mask",
+                "rejected_labels",
+                "prompt_input_ids",
+                "prompt_attention_mask",
+            ]
+
     def get_train_dataloader(self) -> DataLoader:
         """
         Returns the training [`~torch.utils.data.DataLoader`].
@@ -1544,12 +1548,6 @@ def compute_loss(
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
-
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
         with compute_loss_context_manager:
             loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
@@ -1616,11 +1614,6 @@ def prediction_step(
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
     ):
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
         if ignore_keys is None:
             if hasattr(model, "config"):
                 ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
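
---

Editor's note, not part of the commit: a minimal usage sketch of what this
patch enables. It assumes the public trl API of this era; the model id
"Qwen/Qwen2-0.5B-Instruct" and the dataset "trl-lib/ultrafeedback_binarized"
are placeholders, and later trl releases rename `tokenizer=` to
`processing_class=`.

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder model id
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Placeholder preference dataset with "prompt"/"chosen"/"rejected" columns.
    train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

    # Before this patch, DPOTrainer forced remove_unused_columns=False and warned,
    # because Trainer would otherwise drop the preprocessed chosen_*/rejected_*/
    # prompt_* columns that DPODataCollatorWithPadding needs. With the
    # _set_signature_columns_if_needed override, those columns are declared as
    # signature columns, so the TrainingArguments default (True) now works.
    training_args = DPOConfig(output_dir="dpo-output", remove_unused_columns=True)

    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,  # `processing_class=tokenizer` in later trl versions
    )
    trainer.train()

The design point: rather than silently flipping a user-supplied flag, the
trainer now tells the base Trainer which columns are actually consumed, so
column filtering behaves consistently with the rest of transformers.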