From 8febff8000da9dafa7dba5ac03a7e32fb3f46d56 Mon Sep 17 00:00:00 2001
From: zhc7
Date: Sat, 14 Dec 2024 01:26:17 +0800
Subject: [PATCH 1/3] dpo_trainer gather metrics across ranks before logging
 according to https://github.com/huggingface/trl/issues/2468

---
 trl/trainer/dpo_trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 7ed0ac387f..f26ae9b4b2 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1424,7 +1424,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
+            if isinstance(metrics[0], torch.Tensor):
+                gathered = self._nested_gather([m.cuda() for m in metrics])
+                metrics = [g.mean() for g in gathered]
+            meaned = torch.tensor(metrics).mean()
+            logs[key] = meaned.item()
         del self._stored_metrics[train_eval]

         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):

From 07a9ad8df0a0c1a0511b1ed8261a13c8416f9542 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 7 Jan 2025 10:16:14 +0000
Subject: [PATCH 2/3] fix everywhere

---
 trl/trainer/bco_trainer.py | 2 +-
 trl/trainer/cpo_trainer.py | 18 +++++++++---------
 trl/trainer/dpo_trainer.py | 33 ++++++++++++++++++---------------
 trl/trainer/orpo_trainer.py | 22 +++++++++++-----------
 4 files changed, 39 insertions(+), 36 deletions(-)

diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index c2d58ab3f2..abeed929a6 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -1236,7 +1236,7 @@ def get_batch_loss_metrics(
             chosen_embeddings,
             rejected_embeddings,
         )
-        metrics["delta"] = delta.item()
+        metrics["delta"] = self.accelerator.gather(delta).mean().item()

         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 6d236cfb37..b50f32409b 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -813,15 +813,15 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather(chosen_rewards - rejected_rewards).mean().item()
+        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(policy_rejected_logps).detach().mean().item()
+        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(policy_chosen_logps).detach().mean().item()
+        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather(policy_rejected_logits).detach().mean().item()
+        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather(policy_chosen_logits).detach().mean().item()
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather(policy_nll_loss).detach().mean().item()

         if self.aux_loss_enabled:
             loss += self.aux_loss_coef * aux_loss
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index a6e5f46caa..db4fedaa54 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1240,18 +1240,24 @@ def get_batch_loss_metrics(
             losses = losses + self.aux_loss_coef * model_output["aux_loss"]

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = model_output["chosen_logps"].detach().mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = model_output["rejected_logps"].detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = model_output["mean_chosen_logits"].detach().cpu()
-        metrics[f"{prefix}logits/rejected"] = model_output["mean_rejected_logits"].detach().cpu()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather(chosen_rewards - rejected_rewards).mean().item()
+        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(model_output["chosen_logps"]).detach().mean().item()
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather(model_output["rejected_logps"]).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/chosen"] = (
+            self.accelerator.gather(model_output["mean_chosen_logits"]).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather(model_output["mean_rejected_logits"]).detach().mean().item()
+        )
         if self.args.rpo_alpha is not None:
-            metrics[f"{prefix}nll_loss"] = model_output["nll_loss"].detach().mean().cpu()
+            metrics[f"{prefix}nll_loss"] = self.accelerator.gather(model_output["nll_loss"]).detach().mean().item()
         if self.aux_loss_enabled:
-            metrics[f"{prefix}aux_loss"] = model_output["aux_loss"].detach().cpu()
+            metrics[f"{prefix}aux_loss"] = self.accelerator.gather(model_output["aux_loss"]).detach().mean().item()

         return losses.mean(), metrics

@@ -1421,15 +1427,12 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
             start_time (`float` or `None`, *optional*, defaults to `None`):
                 Start time of the training.
         """
+        # This function is called either in the
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
-            if isinstance(metrics[0], torch.Tensor):
-                gathered = self._nested_gather([m.cuda() for m in metrics])
-                metrics = [g.mean() for g in gathered]
-            meaned = torch.tensor(metrics).mean()
-            logs[key] = meaned.item()
+            logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]

         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 50392526db..7b1173f649 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -829,17 +829,17 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
-        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
-        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(chosen_rewards).mean()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(rejected_rewards).mean()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather(reward_accuracies).mean()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather(chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(policy_rejected_logps).detach().mean()
+        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(policy_chosen_logps).detach().mean()
+        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather(policy_rejected_logits).detach().mean()
+        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather(policy_chosen_logits).detach().mean()
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather(policy_nll_loss).detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather(log_odds_ratio).mean()
+        metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather(log_odds_chosen).mean()
         if is_torch_xla_available():
             xm.mark_step()  # needed because .item() calls
         for k, v in metrics.items():

From 17383f9df1aa5a89b35be0be924812601eceb57f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 7 Jan 2025 13:44:13 +0000
Subject: [PATCH 3/3] gather_for_metrics

---
 trl/trainer/bco_trainer.py | 30 +++++++++++++++++++++---------
 trl/trainer/cpo_trainer.py | 28 +++++++++++++++++++---------
 trl/trainer/dpo_trainer.py | 29 ++++++++++++++++++-----------
 trl/trainer/kto_trainer.py | 30 +++++++++++++++++++++---------
 trl/trainer/nash_md_trainer.py | 2 +-
 trl/trainer/online_dpo_trainer.py | 24 ++++++++++++++----------
 trl/trainer/orpo_trainer.py | 26 +++++++++++++++-----------
 trl/trainer/ppo_trainer.py | 30 ++++++++++++++++--------------
 trl/trainer/rloo_trainer.py | 28 +++++++++++++++-------------
 trl/trainer/xpo_trainer.py | 2 +-
 10 files changed, 141 insertions(+), 88 deletions(-)

diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index abeed929a6..177144b596 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -1236,24 +1236,36 @@ def get_batch_loss_metrics(
             chosen_embeddings,
             rejected_embeddings,
         )
-        metrics["delta"] = self.accelerator.gather(delta).mean().item()
+        metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()

         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+        all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+        all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

         if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item()
+            metrics["rewards/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
+            )
+            metrics["logits/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
+            )
             metrics["count/chosen"] = all_num_chosen

         if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item()
+            metrics["rewards/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
+            )
+            metrics["logits/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
+            )
             metrics["count/rejected"] = all_num_rejected

         loss = losses.nanmean()
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index b50f32409b..6eee3b03ad 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -813,15 +813,25 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(chosen_rewards).mean().item()
-        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(rejected_rewards).mean().item()
-        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather(reward_accuracies).mean().item()
-        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather(chosen_rewards - rejected_rewards).mean().item()
-        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(policy_rejected_logps).detach().mean().item()
-        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(policy_chosen_logps).detach().mean().item()
-        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather(policy_rejected_logits).detach().mean().item()
-        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather(policy_chosen_logits).detach().mean().item()
-        metrics[f"{prefix}nll_loss"] = self.accelerator.gather(policy_nll_loss).detach().mean().item()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
+        )
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
+        )
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()

         if self.aux_loss_enabled:
             loss += self.aux_loss_coef * aux_loss
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index db4fedaa54..cc157567a0 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1240,24 +1240,32 @@ def get_batch_loss_metrics(
             losses = losses + self.aux_loss_coef * model_output["aux_loss"]

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(chosen_rewards).mean().item()
-        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(rejected_rewards).mean().item()
-        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather(reward_accuracies).mean().item()
-        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather(chosen_rewards - rejected_rewards).mean().item()
-        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(model_output["chosen_logps"]).detach().mean().item()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
+        )
         metrics[f"{prefix}logps/rejected"] = (
-            self.accelerator.gather(model_output["rejected_logps"]).detach().mean().item()
+            self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
         )
         metrics[f"{prefix}logits/chosen"] = (
-            self.accelerator.gather(model_output["mean_chosen_logits"]).detach().mean().item()
+            self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
         )
         metrics[f"{prefix}logits/rejected"] = (
-            self.accelerator.gather(model_output["mean_rejected_logits"]).detach().mean().item()
+            self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
         )
         if self.args.rpo_alpha is not None:
-            metrics[f"{prefix}nll_loss"] = self.accelerator.gather(model_output["nll_loss"]).detach().mean().item()
+            metrics[f"{prefix}nll_loss"] = (
+                self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
+            )
         if self.aux_loss_enabled:
-            metrics[f"{prefix}aux_loss"] = self.accelerator.gather(model_output["aux_loss"]).detach().mean().item()
+            metrics[f"{prefix}aux_loss"] = (
+                self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
+            )

         return losses.mean(), metrics

@@ -1427,7 +1435,6 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
             start_time (`float` or `None`, *optional*, defaults to `None`):
                 Start time of the training.
         """
-        # This function is called either in the
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index d054d97e7d..a7a71714d9 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1138,7 +1138,7 @@ def kto_loss(
         """
         if self.calculate_KL:
             kl = (policy_KL_logps - reference_KL_logps).mean().detach()
-            kl = self.accelerator.gather(kl).mean().clamp(min=0)
+            kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
         else:
             kl = torch.zeros(1).to(policy_chosen_logps.device)

@@ -1247,19 +1247,31 @@ def get_batch_loss_metrics(
         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+        all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+        all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

         if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item()
+            metrics["rewards/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
+            )
+            metrics["logits/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
+            )
             metrics["count/chosen"] = all_num_chosen

         if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item()
+            metrics["rewards/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
+            )
+            metrics["logits/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
+            )
             metrics["count/rejected"] = all_num_rejected

         loss = losses.nanmean()
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index 1d714e2c1d..a5767f9580 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -334,7 +334,7 @@ def _log_statistics(
     ):
         # Helper function to gather and compute mean
         def gather_mean(tensor):
-            return self.accelerator.gather(tensor).mean().item()
+            return self.accelerator.gather_for_metrics(tensor).mean().item()

         # Log score
         self.stats["loss/score"].append(gather_mean(score))
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 68008881f5..e4f6415dfa 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -547,28 +547,32 @@ def training_step(
         # Log everything
         if self.reward_model is not None:
             scores_margin = scores[chosen_indices] - scores[rejected_indices]
-            self.stats["objective/scores_margin"].append(self.accelerator.gather(scores_margin.mean()).mean().item())
-            self.stats["objective/scores"].append(self.accelerator.gather(scores.mean()).mean().item())
+            self.stats["objective/scores_margin"].append(
+                self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
+            )
+            self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
         self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
-        self.stats["logps/chosen"].append(self.accelerator.gather(chosen_logprobs_sum).mean().item())
-        self.stats["logps/rejected"].append(self.accelerator.gather(rejected_logprobs_sum).mean().item())
+        self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
+        self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())

         kl = logprobs - ref_logprobs
         mean_kl = kl.sum(1).mean()
-        self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item())
+        self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
         non_score_reward = (-self.beta * kl).sum(1)
         mean_non_score_reward = non_score_reward.mean()
-        self.stats["objective/non_score_reward"].append(self.accelerator.gather(mean_non_score_reward).mean().item())
+        self.stats["objective/non_score_reward"].append(
+            self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+        )
         if self.reward_model is not None:
             rlhf_reward = scores + non_score_reward
-            self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item())
+            self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
         mean_entropy = -logprobs.sum(1).mean()
-        self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item())
+        self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
         chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
-        gathered_chosen_rewards = self.accelerator.gather(chosen_rewards)
+        gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
         self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
         rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
-        gathered_rejected_rewards = self.accelerator.gather(rejected_rewards)
+        gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
         self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
         margin = gathered_chosen_rewards - gathered_rejected_rewards
         self.stats["rewards/margins"].append(margin.mean().item())
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 7b1173f649..fd9f557c47 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -829,17 +829,21 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(chosen_rewards).mean()
-        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(rejected_rewards).mean()
-        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather(reward_accuracies).mean()
-        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather(chosen_rewards - rejected_rewards).mean()
-        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(policy_rejected_logps).detach().mean()
-        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(policy_chosen_logps).detach().mean()
-        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather(policy_rejected_logits).detach().mean()
-        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather(policy_chosen_logits).detach().mean()
-        metrics[f"{prefix}nll_loss"] = self.accelerator.gather(policy_nll_loss).detach().mean()
-        metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather(log_odds_ratio).mean()
-        metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather(log_odds_chosen).mean()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
+            chosen_rewards - rejected_rewards
+        ).mean()
+        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
+        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
+        )
+        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
+        metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
         if is_torch_xla_available():
             xm.mark_step()  # needed because .item() calls
         for k, v in metrics.items():
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 51897eeb44..7e0dc635b2 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -615,19 +615,21 @@ def repeat_generator():
                 eps = int(self.state.episode / (time.time() - start_time))
                 metrics = {}
                 metrics["eps"] = eps
-                metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
-                metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
-                metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
-                metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
-                metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
-                metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
-                metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
-                metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
-                metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
-                metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
-                metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
-                metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
-                metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
+                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+                metrics["objective/non_score_reward"] = (
+                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+                )
+                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+                metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
+                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                 metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                 metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                 metrics["episode"] = self.state.episode
@@ -715,7 +717,7 @@ def generate_completions(self, sampling: bool = False):
                     _, score, _ = get_reward(
                         self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                     )
-                    table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
+                    table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

                 if sampling:
                     break
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 23ea1ca21f..4b736bd270 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -460,18 +460,20 @@ def repeat_generator():
                 eps = int(self.state.episode / (time.time() - start_time))
                 metrics = {}
                 metrics["eps"] = eps
-                metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
-                metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
-                metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
-                metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
-                metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
-                metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
-                metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
-                metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
-                metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
-                metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
-                metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
-                metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
+                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+                metrics["objective/non_score_reward"] = (
+                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+                )
+                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                 metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                 metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                 metrics["episode"] = self.state.episode
@@ -538,7 +540,7 @@ def generate_completions(self, sampling: bool = False):
                     _, score, _ = get_reward(
                         self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                     )
-                    table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
+                    table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

                 if sampling:
                     break
diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py
index 1be32ab1de..1d8c4ae1ac 100644
--- a/trl/trainer/xpo_trainer.py
+++ b/trl/trainer/xpo_trainer.py
@@ -360,7 +360,7 @@ def _log_statistics(
     ):
        # Helper function to gather and compute mean
        def gather_mean(tensor):
-            return self.accelerator.gather(tensor).mean().item()
+            return self.accelerator.gather_for_metrics(tensor).mean().item()

        # Log losses
        self.stats["loss/dpo"].append(gather_mean(dpo_losses))
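
Appendix (not part of the patches above): a minimal sketch of the gathered-metric pattern these commits converge on, for readers unfamiliar with the Accelerate API. `Accelerator.gather` concatenates a tensor across all processes, so the logged mean is global rather than rank-local; `Accelerator.gather_for_metrics` does the same but also drops samples that were duplicated to pad the last batch of a dataloader prepared by the Accelerator. The snippet assumes a script launched with `accelerate launch`; `batch_rewards` is a hypothetical stand-in for per-example values such as `chosen_rewards`.

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Hypothetical per-rank, per-example values (stand-in for `chosen_rewards` etc.).
    batch_rewards = torch.randn(8, device=accelerator.device)

    # Rank-local mean: what gets logged if the tensor is never gathered.
    local_mean = batch_rewards.mean().item()

    # Global mean: `gather` concatenates the tensor from every process first.
    global_mean = accelerator.gather(batch_rewards).mean().item()

    # Same, but duplicates added to pad the final batch of a prepared dataloader
    # are dropped, so evaluation metrics are not skewed by repeated samples.
    metric_mean = accelerator.gather_for_metrics(batch_rewards).mean().item()

    if accelerator.is_main_process:
        print({"local": local_mean, "global": global_mean, "for_metrics": metric_mean})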