diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 7ba14f0830..c1d7f2887a 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -320,6 +320,7 @@ def make_inputs_require_grad(module, input, output): self.model_adapter_name = model_adapter_name self.ref_adapter_name = ref_adapter_name + # Get the reference model if ref_model: self.ref_model = ref_model elif self.is_peft_model or False: @@ -328,48 +329,20 @@ def make_inputs_require_grad(module, input, output): else: self.ref_model = create_reference_model(model) - if processing_class is None: - raise ValueError( - "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" - ) - if args.max_length is None: - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init" - " it will be set to `512` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_length = 512 - if args.max_length is not None: - max_length = args.max_length - - if args.max_prompt_length is None: - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_prompt_length = 128 - if args.max_prompt_length is not None: - max_prompt_length = args.max_prompt_length - - max_completion_length = None - - if data_collator is None: - data_collator = DataCollatorForUnpairedPreference(pad_token_id=processing_class.pad_token_id) - + # Disable dropout if needed if args.disable_dropout: disable_dropout_in_model(model) if self.ref_model is not None: disable_dropout_in_model(self.ref_model) - self.max_length = max_length + # Define the data collator + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + if data_collator is None: + data_collator = DataCollatorForUnpairedPreference(pad_token_id=self.padding_value) + self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.max_completion_length = max_completion_length - self.processing_class = processing_class + self.max_length = args.max_length # metric self._stored_metrics = defaultdict(lambda: defaultdict(list)) @@ -388,7 +361,7 @@ def make_inputs_require_grad(module, input, output): # issued. model.warnings_issued["estimate_tokens"] = True - # 4. Handle the dataset - UNCOMMENT WHEN _prepare_dataset READY + # Dataset preparation train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") if eval_dataset is not None: if isinstance(eval_dataset, dict): @@ -399,7 +372,7 @@ def make_inputs_require_grad(module, input, output): else: eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - # Calculate dataset desirability balance + # Calculate dataset desirability balance and display warning if weights are not in an ideal range num_desirable = max(sum(train_dataset["label"]), 1) num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary @@ -780,22 +753,28 @@ def get_batch_loss_metrics( all_num_rejected = self.accelerator.gather(num_rejected).sum().item() if all_num_chosen > 0: - metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() - metrics["logps/chosen_sum"] = ( - self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() + metrics["rewards/chosen"] = ( + self.accelerator.gather(chosen_rewards.nansum()).nansum().item() / all_num_chosen + ) + metrics["logps/chosen"] = ( + self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() / all_num_chosen + ) + metrics["logits/chosen"] = ( + self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item() / all_num_chosen ) - metrics["logits/chosen_sum"] = self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item() - metrics["count/chosen"] = all_num_chosen if all_num_rejected > 0: - metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() - metrics["logps/rejected_sum"] = ( - self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() + metrics["rewards/rejected"] = ( + self.accelerator.gather(rejected_rewards.nansum()).nansum().item() / all_num_rejected ) - metrics["logits/rejected_sum"] = ( - self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item() + metrics["logps/rejected"] = ( + self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() / all_num_rejected ) - metrics["count/rejected"] = all_num_rejected + metrics["logits/rejected"] = ( + self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item() / all_num_rejected + ) + + metrics["rewards/margins"] = metrics["rewards/chosen"] - metrics["rewards/rejected"] loss = losses.nanmean() @@ -845,7 +824,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) # if ref_output in batch use that otherwise use the reference model @@ -859,7 +838,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) else: ref_output = self.ref_model.generate( @@ -867,13 +846,13 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - ref_output = pad_to_length(ref_output, self.max_length, self.processing_class.pad_token_id) + ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) return policy_output_decoded, ref_output_decoded @@ -971,37 +950,11 @@ def evaluation_loop( return initial_output def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float` or `None`, *optional*, defaults to `None`): - Start time of the training. - """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" - # train metrics should have no prefix, eval should have 'eval_' - prefix = "eval_" if train_eval == "eval" else "" - # accumulate average metrics from sums and lengths - for split in ["chosen", "rejected"]: - if f"count/{split}" in self._stored_metrics[train_eval]: - count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() - for metric in ["rewards", "logps", "logits"]: - logs[f"{prefix}{metric}/{split}"] = ( - torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() - / count_sum - ) - # delete obsolete metric - del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] - del self._stored_metrics[train_eval][f"count/{split}"] - # calculate reward margin - if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: - logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] # Add averaged stored metrics to logs for key, metrics in self._stored_metrics[train_eval].items(): - logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + logs[key] = sum(metrics) / len(metrics) del self._stored_metrics[train_eval] if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):