[KTO] fix interleaving, reporting, hanging bugs (#1499)
* add warning for imbalanced data (see the worked example after the commit message)

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/kto_trainer.py (×11)

Co-authored-by: Kashif Rasul <[email protected]>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <[email protected]>

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

* don't report nan metrics

* don't report nan metrics and remove data interleaving

* fix bugs in calculating metrics

* no need to gather KL term

* minor changes

* use nanmean for losses

* remove disabling of wandb

* revert changes

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
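
The imbalanced-data warning and the weight description mentioned above follow the guidance in the KTO paper: desirable_weight * n_desirable divided by undesirable_weight * n_undesirable should land roughly in [1, 4/3]. A quick back-of-the-envelope check with made-up counts (illustrative only, not the trainer's exact warning logic):

# Illustrative sketch of the KTO weighting heuristic; the counts and variable
# names here are hypothetical and not part of the trainer's API.
n_desirable, n_undesirable = 100, 300          # hypothetical label counts
desirable_weight, undesirable_weight = 1.0, 0.33

ratio = (desirable_weight * n_desirable) / (undesirable_weight * n_undesirable)
print(f"weighted ratio = {ratio:.2f}")         # ~1.01, inside [1, 4/3], so no warning expected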
3 people authored Apr 3, 2024
1 parent ab0d11d commit 4f8057a
109 changes: 57 additions & 52 deletions trl/trainer/kto_trainer.py
@@ -28,7 +28,7 @@
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset, concatenate_datasets, interleave_datasets
from datasets import Dataset, concatenate_datasets
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
AutoModelForCausalLM,
@@ -485,6 +485,10 @@ def make_inputs_require_grad(module, input, output):
self.undesirable_weight = args.undesirable_weight

with PartialState().local_main_process_first():
# Shuffle the datasets
train_dataset = train_dataset.shuffle(seed=args.data_seed)
if eval_dataset is not None:
eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
# Tokenize and prepare the training datasets
train_dataset = train_dataset.map(
_tokenize,
@@ -500,8 +504,8 @@ def make_inputs_require_grad(module, input, output):
raise ValueError(
"Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
)
# Note: for best results, mismatched outputs y' used to estimate the KL term for a batch should be the
# same as the matched outputs y used to estimate the rewards in that batch, just paired with different x
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
train_kl_dataset = train_dataset.map(
_get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting KL train dataset"
)
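
The comment above describes how the KL dataset is built: within each batch of size total_batch_size the prompts keep their positions and the completions are reversed, so every prompt is paired with a completion that was written for a different prompt. A standalone sketch of that flip, not part of the patch (the column names prompt/completion are placeholders, not necessarily the keys the trainer uses):

from datasets import Dataset

def flip_completions(batch):
    # (x_1, y_1), ..., (x_n, y_n) -> (x_1, y_n), ..., (x_n, y_1)
    batch["completion"] = batch["completion"][::-1]
    return batch

ds = Dataset.from_dict({
    "prompt": ["p1", "p2", "p3", "p4"],
    "completion": ["c1", "c2", "c3", "c4"],
})
kl_ds = ds.map(flip_completions, batched=True, batch_size=4)
print(kl_ds["completion"])  # ['c4', 'c3', 'c2', 'c1'] while the prompts stay in order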
@@ -601,30 +605,12 @@ def make_inputs_require_grad(module, input, output):
UserWarning,
)

# split the dataset and interleave them together with equal probability of choosing chosen or rejected
interleaved_train_dataset = interleave_datasets(
[desirable, undesirable],
stopping_strategy="all_exhausted",
)
interleaved_train_dataset = interleaved_train_dataset.shuffle(seed=args.data_seed)

if eval_dataset is not None:
interleaved_eval_dataset = interleave_datasets(
[
eval_dataset.filter(lambda x: x["label"], num_proc=args.dataset_num_proc),
eval_dataset.filter(lambda x: not x["label"], num_proc=args.dataset_num_proc),
],
stopping_strategy="all_exhausted",
)
else:
interleaved_eval_dataset = None

super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=interleaved_train_dataset,
eval_dataset=interleaved_eval_dataset,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
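
The block removed above had split the data into desirable and undesirable subsets and recombined them with interleave_datasets(..., stopping_strategy="all_exhausted"), which keeps drawing from the smaller subset until the larger one is exhausted and therefore duplicates minority examples; the replacement shuffles the original dataset once (the shuffle added earlier in __init__) and passes it to the Trainer unchanged. A toy comparison of the two behaviours, with illustrative data only:

from datasets import Dataset, interleave_datasets

desirable = Dataset.from_dict({"completion": ["d1", "d2", "d3"], "label": [True, True, True]})
undesirable = Dataset.from_dict({"completion": ["u1"], "label": [False]})

# Old behaviour: alternate between the splits; "all_exhausted" repeats the
# smaller split, so 'u1' shows up several times.
mixed = interleave_datasets([desirable, undesirable], stopping_strategy="all_exhausted")
print(mixed["completion"])  # e.g. ['d1', 'u1', 'd2', 'u1', 'd3', 'u1']

# New behaviour: every example appears exactly once, in a shuffled order.
full = Dataset.from_dict({"completion": ["d1", "d2", "d3", "u1"], "label": [True, True, True, False]})
print(full.shuffle(seed=0)["completion"])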
@@ -974,11 +960,11 @@ def kto_loss(
"""Compute the KTO loss for a batch of policy and reference model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
Returns:
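
For reference, the quantities documented here feed the KTO objective from the paper (Ethayarajh et al., 2024); writing it out makes the code in the next hunk easier to follow. With the per-example log-ratio r_theta and the batch-level KL estimate z_0 (the paper clamps it at zero), the unreduced losses computed below are, in LaTeX:

r_\theta(x, y) = \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}, \qquad
z_0 = \max\!\Big(0,\ \frac{1}{m} \sum_{j=1}^{m} \log \frac{\pi_\theta(y'_j \mid x_j)}{\pi_{\mathrm{ref}}(y'_j \mid x_j)}\Big)

\mathcal{L}_{\mathrm{chosen}}(x, y) = \lambda_D \big(1 - \sigma(\beta\,(r_\theta(x, y) - z_0))\big), \qquad
\mathcal{L}_{\mathrm{rejected}}(x, y) = \lambda_U \big(1 - \sigma(\beta\,(z_0 - r_\theta(x, y)))\big)

where lambda_D and lambda_U are desirable_weight and undesirable_weight.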
@@ -996,17 +982,17 @@ def kto_loss(
chosen_rewards = self.beta * chosen_logratios.detach()
else:
# lists can't be empty -- if they are, then accelerate.gather will hang
chosen_losses = torch.Tensor([torch.nan]).to(self.accelerator.device)
chosen_rewards = torch.Tensor([torch.nan]).to(self.accelerator.device)
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)

if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
rejected_logratios = policy_rejected_logps - reference_rejected_logps
rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
# lists can't be empty -- if they are, then accelerate.gather will hang
rejected_losses = torch.Tensor([torch.nan]).to(self.accelerator.device)
rejected_rewards = torch.Tensor([torch.nan]).to(self.accelerator.device)
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)

losses = torch.cat(
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
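
The switch above from nan placeholders to empty tensors works together with the nanmean-style reductions used later: concatenating an empty tensor adds nothing, whereas a nan placeholder would poison a plain .mean(). A standalone illustration, not part of the patch:

import torch

chosen_losses = torch.tensor([0.3, 0.7])
rejected_losses = torch.tensor([])           # no rejected examples in this microbatch

losses = torch.cat((chosen_losses, rejected_losses))
print(losses.nanmean())                      # tensor(0.5000): the empty part contributes nothing

with_nan = torch.cat((chosen_losses, torch.tensor([float("nan")])))
print(with_nan.mean())                       # tensor(nan): the old placeholder breaks .mean()
print(with_nan.nanmean())                    # tensor(0.5000)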
@@ -1061,7 +1047,7 @@ def get_batch_loss_metrics(
reference_KL_logps,
) = self.forward(self.ref_model, batch)

losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
losses, chosen_rewards, rejected_rewards, KL = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
@@ -1070,25 +1056,38 @@ def get_batch_loss_metrics(
reference_KL_logps,
)

mean_chosen_reward = chosen_rewards.nanmean().detach()
mean_rejected_reward = rejected_rewards.nanmean().detach()
mean_chosen_logps = policy_chosen_logps.nanmean().detach()
mean_rejected_logps = policy_rejected_logps.nanmean().detach()
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

all_num_chosen = self.accelerator.gather(num_chosen)
all_num_rejected = self.accelerator.gather(num_rejected)

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(mean_chosen_reward).nanmean().cpu()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(mean_rejected_reward).nanmean().cpu()
metrics[f"{prefix}rewards/margins"] = metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"]
metrics[f"{prefix}kl"] = kl.item() # has already been gathered in kto_loss
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(mean_chosen_logps).nanmean().cpu()
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(mean_rejected_logps).nanmean().cpu()

loss = (
losses.mean()
if losses.shape[0] != 0
else torch.tensor(float("nan"), requires_grad=True).to(self.accelerator.device)
)
return loss, metrics

if all_num_chosen.sum().item() > 0:
metrics[f"{prefix}rewards/chosen"] = (
(self.accelerator.gather(chosen_rewards.mean()) * all_num_chosen).nansum() / all_num_chosen.sum()
).item()
metrics[f"{prefix}logps/chosen"] = (
(self.accelerator.gather(policy_chosen_logps.mean()) * all_num_chosen).nansum() / all_num_chosen.sum()
).item()

if all_num_rejected.sum().item() > 0:
metrics[f"{prefix}rewards/rejected"] = (
(self.accelerator.gather(rejected_rewards.mean()) * all_num_rejected).nansum() / all_num_rejected.sum()
).item()
metrics[f"{prefix}logps/rejected"] = (
(self.accelerator.gather(policy_rejected_logps.mean()) * all_num_rejected).nansum()
/ all_num_rejected.sum()
).item()

metrics[f"{prefix}kl"] = KL.item()
if all_num_chosen.sum().item() > 0 and all_num_rejected.sum().item() > 0:
metrics[f"{prefix}rewards/margins"] = (
metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"]
)

return losses.nanmean(), metrics

def compute_loss(
self,
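
The reporting above gathers a fixed-size count and a per-process mean from every rank and then forms a count-weighted average, so each rank participates in the collective with identically shaped tensors even when it happens to hold zero chosen (or rejected) examples. The arithmetic, simulated without a distributed launch and with made-up numbers:

import torch

# Pretend accelerator.gather() returned these per-rank statistics (3 ranks).
all_num_chosen = torch.tensor([2.0, 0.0, 3.0])           # rank 1 had no chosen examples
per_rank_mean = torch.tensor([0.4, float("nan"), 0.8])   # mean of an empty tensor is nan

# Count-weighted global mean: nan * 0 is nan, but nansum drops it, so the
# empty rank simply falls out of the average.
global_mean = (per_rank_mean * all_num_chosen).nansum() / all_num_chosen.sum()
print(global_mean)  # tensor(0.6400) == (2 * 0.4 + 3 * 0.8) / 5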
@@ -1101,9 +1100,13 @@ def compute_loss(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
loss = loss.to(self.args.device)
# force log the metrics
if self.accelerator.is_main_process:
self.store_metrics(metrics, train_eval="train")
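
The context-manager selection added above keeps one code path for both cases: torch.cuda.amp.autocast when the PEFT model has been cast to bf16, a do-nothing nullcontext otherwise. Stripped of the trainer, the pattern looks like this (use_amp stands in for self._peft_has_been_casted_to_bf16, and the loss is a placeholder):

from contextlib import nullcontext

import torch

use_amp = False  # stand-in for self._peft_has_been_casted_to_bf16
compute_loss_context_manager = torch.cuda.amp.autocast if use_amp else nullcontext

with compute_loss_context_manager():
    loss = (torch.randn(4, 8) ** 2).mean()  # placeholder for get_batch_loss_metrics(...)
print(loss)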
@@ -1114,10 +1117,12 @@ def compute_loss(

def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
if isinstance(value, list):
self._stored_metrics[train_eval][key].extend(value)
else:
self._stored_metrics[train_eval][key].append(value)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
# We use a sequential sampler for training as the order of the interleaved dataset is important
if self.train_dataset is None or not has_length(self.train_dataset):
return None
return SequentialSampler(self.train_dataset)
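
The store_metrics change above keeps the stored history flat: a value that arrives as a list is extended into it, whereas the old unconditional append would have nested the whole list as a single element. A toy illustration:

history = []
history.append([0.1, 0.2])  # old behaviour with a list value -> [[0.1, 0.2]]

history = []
history.extend([0.1, 0.2])  # new behaviour -> [0.1, 0.2]
history.append(0.3)         # scalar values are still appended -> [0.1, 0.2, 0.3]
print(history)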
