From 14e0d788078be6406e580a2e8aa94cd451e5f909 Mon Sep 17 00:00:00 2001 From: Kawin Date: Thu, 29 Feb 2024 00:01:52 -0800 Subject: [PATCH] fix bugs in KTO implementation (#1380) * add warning for imbalanced data * update documentation * update script commands to be same as in dpo * use batch_size KL examples and batch_size target examples to calculate batch_size losses * fix deepspeed issue * speed up forward with no_grad for KL * add some removed metrics * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py add reference to paper Co-authored-by: lewtun * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul * add more detailed comments * convert assert to ValueError * Update kto_trainer.py * precommit formatting --------- Co-authored-by: Kashif Rasul Co-authored-by: lewtun --- docs/source/_toctree.yml | 4 +- docs/source/kto_trainer.mdx | 9 +- examples/scripts/kto.py | 53 +++++--- trl/trainer/kto_trainer.py | 261 +++++++++++++++++++++++++----------- 4 files changed, 228 insertions(+), 99 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index cf17799282..f5f9654c42 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -29,14 +29,14 @@ title: Best of N Sampling - local: dpo_trainer title: DPO Trainer + - local: kto_trainer + title: KTO Trainer - local: ddpo_trainer title: Denoising Diffusion Policy Optimization - local: iterative_sft_trainer title: Iterative Supervised Fine-Tuning - local: text_environments title: Text Environments - - local: kto_trainer - title: KTO Trainer title: API - sections: - local: example_overview diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 16c1f2c77b..9d6d57f29c 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -1,6 +1,10 @@ # KTO Trainer -TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for training language models from unpaired preference data, as described in the [report](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela. +TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela. +For a full example have a look at [`examples/scripts/kto.py`]. + +Depending on how good your base model is, you may or may not need to do SFT before KTO. +This is different from standard RLHF and DPO, which always require SFT. ## Expected dataset format @@ -44,7 +48,8 @@ kto_dataset_dict = { } ``` -where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired or undesired. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. +where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`). +A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. ## Expected model format The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index c327588206..5726434da8 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -13,24 +13,44 @@ # limitations under the License. """ -Run the KTO training script with the following command with some example arguments: +Run the KTO training script with the following command with some example arguments. +In general, the optimal configuration for KTO will be similar to that of DPO: +# regular: python examples/scripts/kto.py \ - --model_name_or_path "gpt2" \ - --per_device_train_batch_size 8 \ - --gradient_accumulation_steps 2 \ - --learning_rate 1e-4 \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ --max_steps 1000 \ - --report_to "wandb" \ - --gradient_checkpointing True \ - --output_dir="./test" \ - --use_peft True \ - --lora_r 64 \ - --lora_alpha 16 \ - --evaluation_strategy "steps" \ - --logging_first_step True \ + --learning_rate 1e-3 \ + --gradient_accumulation_steps 1 \ --logging_steps 10 \ - --eval_steps 500 + --eval_steps 500 \ + --output_dir="kto_anthropic_hh" \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns + +# peft: +python examples/scripts/kto.py \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 1e-3 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="kto_anthropic_hh" \ + --optim rmsprop \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r=16 \ + --lora_alpha=16 """ from dataclasses import dataclass, field @@ -57,7 +77,10 @@ def extract_anthropic_prompt(prompt_and_response): """Extract the anthropic prompt from a prompt and response pair.""" search_term = "\n\nAssistant:" search_term_idx = prompt_and_response.rfind(search_term) - assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" + + if search_term_idx == -1: + raise ValueError(f"Prompt and response does not contain '{search_term}'") + return prompt_and_response[: search_term_idx + len(search_term)] diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 4ce04d62a0..aa69641570 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -27,7 +27,7 @@ import torch.nn as nn import torch.nn.functional as F from accelerate.utils import is_deepspeed_available, tqdm -from datasets import Dataset, interleave_datasets +from datasets import Dataset, concatenate_datasets, interleave_datasets from torch.utils.data import DataLoader, SequentialSampler from transformers import ( AutoModelForCausalLM, @@ -312,25 +312,59 @@ def make_inputs_require_grad(module, input, output): self.desirable_weight = args.desirable_weight self.undesirable_weight = args.undesirable_weight - # tokenize the dataset - columns_to_remove = [col for col in train_dataset.column_names if col not in ["prompt", "completion", "label"]] + # get KL datasets + train_KL_dataset = train_dataset.map(self.get_KL_dataset, batched=True, batch_size=1000) + if eval_dataset is not None: + eval_KL_dataset = eval_dataset.map(self.get_KL_dataset, batched=True, batch_size=1000) + + # tokenize the datasets train_dataset = train_dataset.map( - self.shuffle_completion, batched=True, batch_size=128, remove_columns=columns_to_remove + lambda row: self.tokenize_row(row, prefix=""), remove_columns=train_dataset.column_names + ) + train_KL_dataset = train_KL_dataset.map( + lambda row: self.tokenize_row(row, prefix="KL_"), remove_columns=train_KL_dataset.column_names ) - train_dataset = train_dataset.map(self.tokenize_row) if eval_dataset is not None: - columns_to_remove = [ - col for col in eval_dataset.column_names if col not in ["prompt", "completion", "label"] - ] eval_dataset = eval_dataset.map( - self.shuffle_completion, batched=True, batch_size=128, remove_columns=columns_to_remove + lambda row: self.tokenize_row(row, prefix=""), remove_columns=eval_dataset.column_names + ) + eval_KL_dataset = eval_KL_dataset.map( + lambda row: self.tokenize_row(row, prefix="KL_"), remove_columns=eval_KL_dataset.column_names ) - eval_dataset = eval_dataset.map(self.tokenize_row) + + # merge the datasets + train_dataset = concatenate_datasets([train_dataset, train_KL_dataset], axis=1) + eval_dataset = concatenate_datasets([eval_dataset, eval_KL_dataset], axis=1) + + desirable = train_dataset.filter(lambda x: x["label"]) + undesirable = train_dataset.filter(lambda x: not x["label"]) + + if len(desirable) != len(undesirable): + # The lower and upper bounds come from Eq. (8) of https://arxiv.org/abs/2402.01306 + des_weight_lower_bound = (len(undesirable) * self.undesirable_weight / len(desirable)) * 1 + des_weight_upper_bound = (len(undesirable) * self.undesirable_weight / len(desirable)) * 1.33 + und_weight_lower_bound = (len(desirable) * self.desirable_weight / len(undesirable)) / 1.33 + und_weight_upper_bound = (len(desirable) * self.desirable_weight / len(undesirable)) / 1 + + des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound + und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound + + if not (des_weight_in_range or und_weight_in_range): + warnings.warn( + f""" + You have different amounts of desirable/positive and undesirable/negative examples but the \ + weights on the desirable and undesirable losses don't seem to be in an ideal range. Based \ + on your data, we recommend EITHER desirable_weight in \ + [{des_weight_lower_bound}, {des_weight_upper_bound}] or undesirable_weight in \ + [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). See the documentation \ + on how to optimally set these weights.""", + UserWarning, + ) # split the dataset and interleave them together with equal probability of choosing chosen or rejected interleaved_train_dataset = interleave_datasets( - [train_dataset.filter(lambda x: x["label"]), train_dataset.filter(lambda x: not x["label"])], + [desirable, undesirable], stopping_strategy="all_exhausted", ) interleaved_train_dataset = interleaved_train_dataset.shuffle(seed=args.data_seed) @@ -343,10 +377,6 @@ def make_inputs_require_grad(module, input, output): else: interleaved_eval_dataset = None - # Increase the effective batch size by 2x to account for the detached KL terms - args.per_device_train_batch_size = args.per_device_train_batch_size * 2 - args.per_device_eval_batch_size = args.per_device_eval_batch_size * 2 - super().__init__( model=model, args=args, @@ -437,14 +467,25 @@ def get_train_dataloader(self) -> DataLoader: # prepare dataloader data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - reference_logps = [] + reference_completion_logps = [] + reference_KL_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - reference_logp = self.compute_reference_log_probs(padded_batch) - reference_logp = self.accelerator.gather_for_metrics(reference_logp) - reference_logps.append(reference_logp.cpu()) + reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) + reference_KL_logps.append(reference_KL_logp.cpu()) + + self.train_dataset = self.train_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + self.train_dataset = self.train_dataset.add_column( + name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() + ) - all_reference_logps = torch.cat(reference_logps).float().numpy() - self.train_dataset = self.train_dataset.add_column(name="reference_logps", column=all_reference_logps) self._precomputed_train_ref_log_probs = True return super().get_train_dataloader() @@ -476,14 +517,24 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa # prepare dataloader data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - reference_logps = [] + reference_completion_logps = [] + reference_KL_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - reference_logp = self.compute_reference_log_probs(padded_batch) - reference_logp = self.accelerator.gather_for_metrics(reference_logp) - reference_logps.append(reference_logp.cpu()) + reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) - all_reference_logps = torch.cat(reference_logps).float().numpy() - eval_dataset = eval_dataset.add_column(name="reference_logps", column=all_reference_logps) + reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) + reference_KL_logps.append(reference_KL_logp.cpu()) + + eval_dataset = eval_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + eval_dataset = eval_dataset.add_column( + name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() + ) # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs if self.eval_dataset is not None: @@ -500,38 +551,72 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: self.model ).disable_adapter() if self.is_peft_model else nullcontext(): if self.is_encoder_decoder: - all_logits = self.model( + completion_logits = self.model( padded_batch["prompt_input_ids"], attention_mask=padded_batch["prompt_attention_mask"], decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), labels=padded_batch["completion_labels"], ).logits + + KL_logits = self.model( + padded_batch["KL_prompt_input_ids"], + attention_mask=padded_batch["KL_prompt_attention_mask"], + decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), + labels=padded_batch["KL_completion_labels"], + ).logits else: - all_logits = self.model( + completion_logits = self.model( padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"], ).logits + + KL_logits = self.model( + padded_batch["KL_completion_input_ids"], + attention_mask=padded_batch["KL_completion_attention_mask"], + ).logits else: if self.is_encoder_decoder: - all_logits = self.ref_model( + completion_logits = self.ref_model( padded_batch["prompt_input_ids"], attention_mask=padded_batch["prompt_attention_mask"], decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), labels=padded_batch["completion_labels"], ).logits + + KL_logits = self.ref_model( + padded_batch["KL_prompt_input_ids"], + attention_mask=padded_batch["KL_prompt_attention_mask"], + decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), + labels=padded_batch["KL_completion_labels"], + ).logits else: - all_logits = self.ref_model( + completion_logits = self.ref_model( padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] ).logits - return self.get_batch_logps( - all_logits, + KL_logits = self.ref_model( + padded_batch["KL_completion_input_ids"], + attention_mask=padded_batch["KL_completion_attention_mask"], + ).logits + + completion_logps = self.get_batch_logps( + completion_logits, padded_batch["completion_labels"], average_log_prob=False, is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, ) + KL_logps = self.get_batch_logps( + KL_logits, + padded_batch["KL_completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + return completion_logps, KL_logps + def build_tokenized_answer(self, prompt, answer): """ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`. @@ -582,14 +667,12 @@ def build_tokenized_answer(self, prompt, answer): attention_mask=answer_attention_mask, ) - def shuffle_completion(self, batch) -> Dict: - batch["kl"] = [False] * len(batch["prompt"]) + [True] * len(batch["prompt"]) - batch["prompt"] = batch["prompt"] + batch["prompt"] - batch["completion"] = batch["completion"] + random.sample(batch["completion"], len(batch["completion"])) - batch["label"] = batch["label"] + batch["label"] + def get_KL_dataset(self, batch) -> Dict: + """Creates mismatched pairs of prompts and completions for the KL dataset.""" + batch["completion"] = random.sample(batch["completion"], len(batch["completion"])) return batch - def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict: + def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None, prefix="") -> Dict: """Tokenize a single row from a KTO specific dataset. At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation @@ -600,10 +683,15 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) the sum of the length of the prompt and the completion response, with label_pad_token_id for the prompt tokens. """ - batch = {} prompt = feature["prompt"] completion = feature["completion"] + batch = { + f"{prefix}prompt": prompt, + f"{prefix}completion": completion, + f"{prefix}label": feature["label"], + } + if not self.is_encoder_decoder: # Check issues below for more details # 1. https://github.com/huggingface/trl/issues/907 @@ -670,7 +758,7 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) for type_key, tokens in toks.items(): if type_key == "token_type_ids": continue - batch[f"{k}{type_key}"] = tokens + batch[f"{prefix}{k}{type_key}"] = tokens else: completion_tokens = self.tokenizer( @@ -680,13 +768,13 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True ) - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + batch[f"{prefix}prompt_input_ids"] = prompt_tokens["input_ids"] + batch[f"{prefix}prompt_attention_mask"] = prompt_tokens["attention_mask"] - batch["completion_labels"] = completion_tokens["input_ids"] - batch["completion_attention_mask"] = completion_tokens["attention_mask"] + batch[f"{prefix}completion_labels"] = completion_tokens["input_ids"] + batch[f"{prefix}completion_attention_mask"] = completion_tokens["attention_mask"] if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): - batch["completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + batch[f"{prefix}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( labels=torch.tensor(batch["completion_labels"]) ) @@ -735,46 +823,65 @@ def get_batch_logps( def forward( self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - target_indicies = [i for i in range(len(batch["kl"])) if batch["kl"][i] is False] - kl_indicies = [i for i in range(len(batch["kl"])) if batch["kl"][i] is True] - if self.is_encoder_decoder: - all_logits = model( + with torch.no_grad(): + KL_logits = model( + batch["KL_prompt_input_ids"], + attention_mask=batch["KL_prompt_attention_mask"], + decoder_input_ids=batch.get("KL_completion_decoder_input_ids"), + labels=batch["KL_completion_labels"], + ).logits + + completion_logits = model( batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], decoder_input_ids=batch.get("completion_decoder_input_ids"), labels=batch["completion_labels"], ).logits else: - all_logits = model( + with torch.no_grad(): + KL_logits = model( + batch["KL_completion_input_ids"], + attention_mask=batch["KL_completion_attention_mask"], + ).logits + + completion_logits = model( batch["completion_input_ids"], attention_mask=batch["completion_attention_mask"], ).logits - all_logps = self.get_batch_logps( - all_logits, + completion_logps = self.get_batch_logps( + completion_logits, batch["completion_labels"], average_log_prob=False, is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, ) - target_logps = all_logps[target_indicies, ...] - target_logits = all_logits[target_indicies, ...] - kl_logps = all_logps[kl_indicies, ...] + KL_logps = self.get_batch_logps( + KL_logits, + batch["KL_completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) - train_label = [batch["label"][i] for i in target_indicies] + if completion_logps.shape[0] != len(batch["label"]): + raise ValueError( + "There is a mismatch between the number of examples in this batch and the number of " + "examples for which an output sequence was predicted." + ) - chosen_idx = [i for i in range(target_logps.shape[0]) if train_label[i] is True] - rejected_idx = [i for i in range(target_logps.shape[0]) if train_label[i] is False] + chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] - chosen_logps = target_logps[chosen_idx, ...] - rejected_logps = target_logps[rejected_idx, ...] + chosen_logps = completion_logps[chosen_idx, ...] + rejected_logps = completion_logps[rejected_idx, ...] - chosen_logits = target_logits[chosen_idx, ...] - rejected_logits = target_logits[rejected_idx, ...] + chosen_logits = completion_logits[chosen_idx, ...] + rejected_logits = completion_logits[rejected_idx, ...] - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps) + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps) def kto_loss( self, @@ -834,6 +941,7 @@ def get_batch_loss_metrics( ): """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} + batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} ( policy_chosen_logps, @@ -845,18 +953,12 @@ def get_batch_loss_metrics( # if reference_logps in batch use them, otherwise use the reference model if "reference_logps" in batch: - kl_indicies = [i for i in range(len(batch["kl"])) if batch["kl"][i] is True] - reference_KL_logps = batch["reference_logps"][kl_indicies, ...] + chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] - target_indicies = [i for i in range(len(batch["kl"])) if batch["kl"][i] is False] - target_logps = batch["reference_logps"][target_indicies, ...] - target_labels = [batch["label"][i] for i in target_indicies] - - chosen_idx = [i for i in range(target_logps.shape[0]) if target_labels[i] is True] - rejected_idx = [i for i in range(target_logps.shape[0]) if target_labels[i] is False] - - reference_chosen_logps = target_logps[chosen_idx, ...] - reference_rejected_logps = target_logps[rejected_idx, ...] + reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] + reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] + reference_KL_logps = batch["reference_KL_logps"] else: with torch.no_grad(): if self.ref_model is None: @@ -886,18 +988,17 @@ def get_batch_loss_metrics( reference_KL_logps, ) - reward_accuracies = (chosen_rewards.mean() > rejected_rewards.mean()).float() - prefix = "eval_" if train_eval == "eval" else "" metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() - metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu() - metrics[f"{prefix}rewards/margins"] = (chosen_rewards.mean() - rejected_rewards.mean()).cpu() + metrics[f"{prefix}rewards/margins"] = ( + chosen_rewards.mean().nan_to_num(0) - rejected_rewards.mean().nan_to_num(0) + ).cpu() metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() - metrics[f"{prefix}kl"] = kl.item() + metrics[f"{prefix}kl"] = kl.item() # has already been gathered in kto_loss loss = ( losses.mean()