Skip to content

Commit

Permalink
[KTO] fix various bugs (#1402)
Browse files Browse the repository at this point in the history
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
  • Loading branch information
3 people authored Mar 8, 2024
1 parent 22b4f54 commit 4d862da
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
9 changes: 8 additions & 1 deletion docs/source/kto_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,18 @@ The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that

## Using the `KTOTrainer`

For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.

The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).

The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3.

```py
training_args = KTOConfig(
beta=0.1,
desirable_weight=1.0,
undesirable_weight=1.0,
)

kto_trainer = KTOTrainer(
Expand Down
42 changes: 22 additions & 20 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,18 @@ def make_inputs_require_grad(module, input, output):
self.undesirable_weight = args.undesirable_weight

# get KL datasets
train_KL_dataset = train_dataset.map(self.get_KL_dataset, batched=True, batch_size=1000)
total_batch_size = (
max(torch.cuda.device_count(), 1) * args.per_device_train_batch_size * args.gradient_accumulation_steps
)
if total_batch_size <= 1:
raise ValueError(
"Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
)
# note: for best results, mismatched outputs y' used to estimate the KL term for a batch should be the
# same as the matched outputs y used to estimate the rewards in that batch, just paired with different x
train_KL_dataset = train_dataset.map(self.get_KL_dataset, batched=True, batch_size=total_batch_size)
if eval_dataset is not None:
eval_KL_dataset = eval_dataset.map(self.get_KL_dataset, batched=True, batch_size=1000)
eval_KL_dataset = eval_dataset.map(self.get_KL_dataset, batched=True, batch_size=total_batch_size)

# tokenize the datasets
train_dataset = train_dataset.map(
Expand Down Expand Up @@ -669,7 +678,7 @@ def build_tokenized_answer(self, prompt, answer):

def get_KL_dataset(self, batch) -> Dict:
"""Creates mismatched pairs of prompts and completions for the KL dataset."""
batch["completion"] = random.sample(batch["completion"], len(batch["completion"]))
batch["completion"] = batch["completion"][::-1]
return batch

def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None, prefix="") -> Dict:
Expand Down Expand Up @@ -933,6 +942,7 @@ def kto_loss(
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
0,
)

return losses, chosen_rewards, rejected_rewards, KL

def get_batch_loss_metrics(
Expand Down Expand Up @@ -990,26 +1000,18 @@ def get_batch_loss_metrics(
reference_KL_logps,
)

# lists can't be empty -- if they are, then accelerate.gather will hang
if policy_chosen_logps.shape[0] == 0:
policy_chosen_logps = torch.Tensor([torch.nan]).to(self.accelerator.device)

if policy_rejected_logps.shape[0] == 0:
policy_rejected_logps = torch.Tensor([torch.nan]).to(self.accelerator.device)

mean_chosen_reward = self.accelerator.gather(chosen_rewards.detach()).nanmean().nan_to_num(0)
mean_rejected_reward = self.accelerator.gather(rejected_rewards.detach()).nanmean().nan_to_num(0)
mean_margin = mean_chosen_reward - mean_rejected_reward
mean_logps_chosen = self.accelerator.gather(policy_chosen_logps.detach()).nanmean().nan_to_num(0)
mean_logps_rejected = self.accelerator.gather(policy_rejected_logps.detach()).nanmean().nan_to_num(0)
mean_chosen_reward = chosen_rewards.nanmean().detach()
mean_rejected_reward = rejected_rewards.nanmean().detach()
mean_chosen_logps = policy_chosen_logps.nanmean().detach()
mean_rejected_logps = policy_rejected_logps.nanmean().detach()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = mean_chosen_reward.cpu()
metrics[f"{prefix}rewards/rejected"] = mean_rejected_reward.cpu()
metrics[f"{prefix}rewards/margins"] = mean_margin.cpu()
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(mean_chosen_reward).nanmean().cpu()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(mean_rejected_reward).nanmean().cpu()
metrics[f"{prefix}rewards/margins"] = metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"]
metrics[f"{prefix}kl"] = kl.item() # has already been gathered in kto_loss
metrics[f"{prefix}logps/rejected"] = mean_logps_chosen.cpu()
metrics[f"{prefix}logps/chosen"] = mean_logps_rejected.cpu()
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(mean_chosen_logps).nanmean().cpu()
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(mean_rejected_logps).nanmean().cpu()

loss = (
losses.mean()
Expand Down

0 comments on commit 4d862da

Please sign in to comment.