[DPO] Adding weighted preference optimization (WPO) (#2141)
* skeleton

* add weighting arg in config

* formatting

* fix doc

* do not compute gradients in weighting term

* fixed detach

* add WPO doc

---------

Co-authored-by: Kashif Rasul <[email protected]>
gaetanlop and kashif authored Oct 8, 2024
1 parent 511c92c commit ed9ea74
Showing 4 changed files with 103 additions and 19 deletions.
4 changes: 4 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -161,6 +161,10 @@ The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing th

The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
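A minimal configuration sketch (illustrative only, not part of this commit; the output directory and value are placeholders):

```python
from trl import DPOConfig

# rpo_alpha weights the SFT term on the chosen completions; the paper suggests 1.0.
training_args = DPOConfig(output_dir="dpo-rpo-model", rpo_alpha=1.0)
```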

### WPO loss

The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
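A minimal configuration sketch (illustrative only; it assumes the usual `DPOTrainer` setup with a preference dataset, which is not shown here, and placeholder values elsewhere):

```python
from trl import DPOConfig

# Enable WPO-style reweighting of the DPO loss; the other values are placeholders.
training_args = DPOConfig(
    output_dir="dpo-wpo-model",
    beta=0.1,
    loss_type="sigmoid",
    use_weighting=True,
)
```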

### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are the most efficient if the load is about equally distributed between experts.
41 changes: 41 additions & 0 deletions tests/test_dpo_trainer.py
@@ -281,6 +281,47 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
if param.sum() != 0:
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

def test_dpo_trainer_with_weighting(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
loss_type="sigmoid",
precompute_ref_log_probs=False,
use_weighting=True,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = DPOTrainer(
model=self.model,
ref_model=self.ref_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

@parameterized.expand(
[
[None, "Test when rpo_alpha is set to None"],
4 changes: 3 additions & 1 deletion trl/trainer/dpo_config.py
@@ -65,7 +65,8 @@ class DPOConfig(TrainingArguments):
- `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
use_weighting (`bool`, *optional*, defaults to `False`):
Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Label pad token id. This argument is required if you want to use the default data collator.
padding_value (`Optional[int]`, *optional*, defaults to `None`):
@@ -150,6 +151,7 @@ class DPOConfig(TrainingArguments):
"apo_zero",
"apo_down",
] = "sigmoid"
use_weighting: bool = False
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
73 changes: 55 additions & 18 deletions trl/trainer/dpo_trainer.py
@@ -810,6 +810,7 @@ def make_inputs_require_grad(module, input, output):
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
self.use_weighting = args.use_weighting
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
warnings.warn(
@@ -1346,17 +1347,22 @@ def get_batch_logps(
labels: torch.LongTensor,
label_pad_token_id: int = -100,
is_encoder_decoder: bool = False,
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
use_weighting: bool = False,
) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
label_pad_token_id: The label pad token id.
is_encoder_decoder: Whether the model is an encoder-decoder model.
use_weighting: Whether to apply weighting as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
Returns:
A Tuple of two tensors of shape ((batch_size,), (batch_size,)) containing the sum of log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
Returns:
A Tuple of three tensors of shape ((batch_size,), (batch_size,), Optional[(batch_size,)]) containing:
- The sum of log probabilities of the given labels under the given logits.
- The number of non-masked tokens.
- The WPO weighting (if use_weighting is True, otherwise None).
"""
if logits.shape[:-1] != labels.shape:
raise ValueError(
@@ -1373,7 +1379,17 @@ def get_batch_logps(

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
all_logps = (per_token_logps * loss_mask).sum(-1)

all_weights = None
if use_weighting:
# eqn (2) of the WPO paper: https://huggingface.co/papers/2406.11827
probs = F.softmax(logits, dim=-1)
weights_adjustment_factor = torch.log((probs**2).sum(-1))
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
all_weights = ((per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)).detach()

return all_logps, loss_mask.sum(-1), all_weights

def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
@@ -1419,12 +1435,13 @@ def concatenated_forward(
seq_len = concatenated_batch["concatenated_labels"].shape[1]
all_logits = all_logits[:, -seq_len:]

all_logps, size_completion = self.get_batch_logps(
all_logps, size_completion, all_weights = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
# average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
use_weighting=self.use_weighting,
)

def cross_entropy_loss(logits, labels):
@@ -1447,16 +1464,30 @@ def cross_entropy_loss(logits, labels):
if self.loss_type == "ipo":
all_logps = all_logps / size_completion

policy_weights = None
if self.use_weighting:
chosen_weights = all_weights[:len_chosen]
rejected_weights = all_weights[len_chosen:]
policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
nll_loss,
policy_weights,
outputs.aux_loss,
)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights)

def get_batch_loss_metrics(
self,
@@ -1474,9 +1505,10 @@
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
policy_weights,
) = forward_output[:6]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
aux_loss = forward_output[6]

# if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if (
@@ -1510,6 +1542,9 @@ def get_batch_loss_metrics(
# RPO loss from V3 of the paper:
losses = losses + policy_nll_loss * self.args.rpo_alpha

if self.use_weighting:
losses = losses * policy_weights

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
@@ -1740,15 +1775,17 @@ def create_model_card(
if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
year = 2023,
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
}""")
citation = textwrap.dedent(
"""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
year = 2023,
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
}"""
)

model_card = generate_model_card(
base_model=base_model,

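For readers skimming the diff: the weighting added above can be summarized in a short, self-contained sketch. This is an illustrative re-implementation that mirrors the `get_batch_logps` change, not the trainer's public API; the function name `wpo_sequence_weights` and its argument layout are assumptions made here.

```python
import torch
import torch.nn.functional as F


def wpo_sequence_weights(logits, labels, loss_mask):
    """Sequence-level WPO weights, mirroring the get_batch_logps change above.

    logits:    (batch, seq_len, vocab) unnormalized model outputs
    labels:    (batch, seq_len) target token ids
    loss_mask: (batch, seq_len) bool mask, True for non-padded label positions
    """
    labels = labels.clone()
    labels[~loss_mask] = 0  # any valid index; these positions are masked out below

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    # eqn (2) of the WPO paper: subtract log(sum_v p(v)^2) at each position
    probs = F.softmax(logits, dim=-1)
    adjustment = torch.log((probs**2).sum(-1))
    adjusted = per_token_logps - adjustment

    # average over the non-masked tokens and detach so the weight carries no gradient
    return ((adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)).detach()


# The trainer then combines the chosen/rejected weights for each preference pair,
#   policy_weights = torch.clamp(torch.exp(chosen_w + rejected_w), max=1)
# and multiplies the per-pair DPO losses by policy_weights.
```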
