From ed9ea74b62ea5aab3a2c597c4d73fadafc51f5dc Mon Sep 17 00:00:00 2001
From: Gaetan LOPEZ LATOUCHE <66413927+gaetanlop@users.noreply.github.com>
Date: Tue, 8 Oct 2024 13:52:54 -0400
Subject: [PATCH] [DPO] Adding weighted preference optimization (WPO) (#2141)

* skeleton

* add weighting arg in config

* formatting

* fix doc

* do not compute gradients in weighting term

* fixed detach

* add WPO doc

---------

Co-authored-by: Kashif Rasul
---
 docs/source/dpo_trainer.mdx |  4 ++
 tests/test_dpo_trainer.py   | 41 +++++++++++++++++++++
 trl/trainer/dpo_config.py   |  4 +-
 trl/trainer/dpo_trainer.py  | 73 ++++++++++++++++++++++++++++---------
 4 files changed, 103 insertions(+), 19 deletions(-)

diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index d5787b3eb3..72467c0b06 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -161,6 +161,10 @@ The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing th
 
 The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
 
+### WPO loss
+
+The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
+
 ### For Mixture of Experts Models: Enabling the auxiliary loss
 
 MOEs are the most efficient if the load is about equally distributed between experts.
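To make the new option concrete, here is a minimal usage sketch of the `use_weighting` flag added in the docs above. It is illustrative only: the model checkpoint and `output_dir` are placeholders, and the dataset is the same small preference set used by the test further down in this patch.

```python
# Hypothetical sketch: enabling WPO-style reweighting in DPO training.
# The checkpoint name and output_dir are placeholders; any causal LM plus a
# prompt/chosen/rejected preference dataset works the same way.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = DPOConfig(
    output_dir="dpo-wpo",
    loss_type="sigmoid",
    use_weighting=True,  # reweight each pair by its probability under the current policy (WPO)
)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()
```

With `use_weighting=False` (the default) the trainer behaves exactly as before; the flag only scales the per-pair loss, as shown in the `dpo_trainer.py` changes below.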
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index da36ce0f86..6dff6a3892 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -281,6 +281,47 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
                 if param.sum() != 0:
                     assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
 
+    def test_dpo_trainer_with_weighting(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                beta=0.1,
+                loss_type="sigmoid",
+                precompute_ref_log_probs=False,
+                use_weighting=True,
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
+
     @parameterized.expand(
         [
             [None, "Test when rpo_alpha is set to None"],
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 5cd8b649a9..b84dfb47dd 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -65,7 +65,8 @@ class DPOConfig(TrainingArguments):
             - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
             - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
             - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
-
+        use_weighting (`bool`, *optional*, defaults to `False`):
+            Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
         label_pad_token_id (`int`, *optional*, defaults to `-100`):
             Label pad token id. This argument is required if you want to use the default data collator.
         padding_value (`Optional[int]`, *optional*, defaults to `None`):
@@ -150,6 +151,7 @@ class DPOConfig(TrainingArguments):
         "apo_zero",
         "apo_down",
     ] = "sigmoid"
+    use_weighting: bool = False
     label_pad_token_id: int = -100
     padding_value: Optional[int] = None
     truncation_mode: str = "keep_end"
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index ce21afdc2e..8b7423e885 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -810,6 +810,7 @@ def make_inputs_require_grad(module, input, output):
         self.label_smoothing = args.label_smoothing
         self.loss_type = args.loss_type
         self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+        self.use_weighting = args.use_weighting
         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
             warnings.warn(
@@ -1346,7 +1347,8 @@ def get_batch_logps(
         labels: torch.LongTensor,
         label_pad_token_id: int = -100,
         is_encoder_decoder: bool = False,
-    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        use_weighting: bool = False,
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]:
         """Compute the log probabilities of the given labels under the given logits.
 
         Args:
@@ -1354,9 +1356,13 @@ def get_batch_logps(
             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
             label_pad_token_id: The label pad token id.
             is_encoder_decoder: Whether the model is an encoder-decoder model.
+            use_weighting: Whether to apply weighting as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
 
         Returns:
-            A Tuple of two tensor of shape ((batch_size,), (batch_size,)) containing the sum of log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
+            A Tuple of three tensors of shape ((batch_size,), (batch_size,), Optional[(batch_size,)]) containing:
+            - The sum of log probabilities of the given labels under the given logits.
+            - The number of non-masked tokens.
+            - The WPO weighting (if use_weighting is True, otherwise None).
         """
         if logits.shape[:-1] != labels.shape:
             raise ValueError(
@@ -1373,7 +1379,17 @@ def get_batch_logps(
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
 
-        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+        all_logps = (per_token_logps * loss_mask).sum(-1)
+
+        all_weights = None
+        if use_weighting:
+            # eqn (2) of the WPO paper: https://huggingface.co/papers/2406.11827
+            probs = F.softmax(logits, dim=-1)
+            weights_adjustment_factor = torch.log((probs**2).sum(-1))
+            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
+            all_weights = ((per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)).detach()
+
+        return all_logps, loss_mask.sum(-1), all_weights
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
@@ -1419,12 +1435,13 @@ def concatenated_forward(
             seq_len = concatenated_batch["concatenated_labels"].shape[1]
             all_logits = all_logits[:, -seq_len:]
 
-        all_logps, size_completion = self.get_batch_logps(
+        all_logps, size_completion, all_weights = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
             # average_log_prob=self.loss_type == "ipo",
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
+            use_weighting=self.use_weighting,
         )
 
         def cross_entropy_loss(logits, labels):
@@ -1447,6 +1464,12 @@ def cross_entropy_loss(logits, labels):
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
 
+        policy_weights = None
+        if self.use_weighting:
+            chosen_weights = all_weights[:len_chosen]
+            rejected_weights = all_weights[len_chosen:]
+            policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)
+
         chosen_logps = all_logps[:len_chosen]
         rejected_logps = all_logps[len_chosen:]
 
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
         if self.aux_loss_enabled:
-            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+            return (
+                chosen_logps,
+                rejected_logps,
+                chosen_logits,
+                rejected_logits,
+                nll_loss,
+                policy_weights,
+                outputs.aux_loss,
+            )
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights)
 
     def get_batch_loss_metrics(
         self,
@@ -1474,9 +1505,10 @@ def get_batch_loss_metrics(
             policy_chosen_logits,
             policy_rejected_logits,
             policy_nll_loss,
-        ) = forward_output[:5]
+            policy_weights,
+        ) = forward_output[:6]
         if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
+            aux_loss = forward_output[6]
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
         if (
@@ -1510,6 +1542,9 @@ def get_batch_loss_metrics(
             # RPO loss from V3 of the paper:
             losses = losses + policy_nll_loss * self.args.rpo_alpha
 
+        if self.use_weighting:
+            losses = losses * policy_weights
+
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
@@ -1740,15 +1775,17 @@ def create_model_card(
         if hasattr(self.model.config, "unsloth_version"):
             tags.append("unsloth")
 
-        citation = textwrap.dedent("""\
-            @inproceedings{rafailov2023direct,
-            title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
-            author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
-            year = 2023,
-            booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
-            url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
-            editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
-        }""")
+        citation = textwrap.dedent(
+            """\
+            @inproceedings{rafailov2023direct,
+            title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
+            author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
+            year = 2023,
+            booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
+            url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
+            editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
+            }"""
+        )
 
         model_card = generate_model_card(
             base_model=base_model,
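For readers who want to sanity-check the weighting term this patch adds to `get_batch_logps`, below is a minimal standalone sketch of eqn (2) of the WPO paper on toy tensors. It is an illustration only: the tensor names, shapes, and values are made up and nothing here is part of the TRL API.

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch_size=2, seq_len=4, vocab_size=8); -100 marks padded label positions.
logits = torch.randn(2, 4, 8)
labels = torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])

loss_mask = labels != -100
safe_labels = labels.clamp(min=0)  # placeholder index for masked positions
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=safe_labels.unsqueeze(2)).squeeze(2)

# Eqn (2): subtract log(sum_k p_k^2) from each token log-prob, then average over non-masked tokens.
probs = F.softmax(logits, dim=-1)
adjustment = torch.log((probs**2).sum(-1))
seq_weight = ((per_token_logps - adjustment) * loss_mask).sum(-1) / loss_mask.sum(-1)

# Pair weight used to scale the DPO loss: clamp(exp(w_chosen + w_rejected), max=1),
# treating the two toy sequences as the chosen/rejected halves of one pair.
chosen_w, rejected_w = seq_weight[0], seq_weight[1]
pair_weight = torch.clamp(torch.exp(chosen_w + rejected_w), max=1)
print(pair_weight)
```

Each sequence weight is the length-normalized, adjusted log-probability of the completion under the current policy; the chosen and rejected weights are combined and clamped exactly as in `concatenated_forward` above, and the resulting pair weight multiplies the per-pair DPO loss when `use_weighting=True`.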