diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 1016c124e6..bdb16b92fb 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -1238,24 +1238,36 @@ def get_batch_loss_metrics( chosen_embeddings, rejected_embeddings, ) - metrics["delta"] = delta.item() + metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item() num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) - all_num_chosen = self.accelerator.gather(num_chosen).sum().item() - all_num_rejected = self.accelerator.gather(num_rejected).sum().item() + all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() if all_num_chosen > 0: - metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() - metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item() - metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item() + metrics["rewards/chosen_sum"] = ( + self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() + ) + metrics["logps/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() + ) metrics["count/chosen"] = all_num_chosen if all_num_rejected > 0: - metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() - metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item() - metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item() + metrics["rewards/rejected_sum"] = ( + self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() + ) + metrics["logps/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() + ) metrics["count/rejected"] = all_num_rejected loss = losses.nanmean() diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 5042363274..f6c8dfd5c6 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -817,15 +817,25 @@ def get_batch_loss_metrics( reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() - metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() - metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() - metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() - metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() - metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() - metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() - metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() - metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu() + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() + metrics[f"{prefix}rewards/margins"] = ( + self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() + ) + metrics[f"{prefix}logps/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item() + ) + metrics[f"{prefix}logps/chosen"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item() + ) + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item() + ) + metrics[f"{prefix}logits/chosen"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item() + ) + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item() if self.aux_loss_enabled: loss += self.aux_loss_coef * aux_loss diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index cbfb994c10..ef56e15922 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1249,18 +1249,32 @@ def get_batch_loss_metrics( losses = losses + self.aux_loss_coef * model_output["aux_loss"] prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() - metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() - metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() - metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() - metrics[f"{prefix}logps/chosen"] = model_output["chosen_logps"].detach().mean().cpu() - metrics[f"{prefix}logps/rejected"] = model_output["rejected_logps"].detach().mean().cpu() - metrics[f"{prefix}logits/chosen"] = model_output["mean_chosen_logits"].detach().cpu() - metrics[f"{prefix}logits/rejected"] = model_output["mean_rejected_logits"].detach().cpu() + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() + metrics[f"{prefix}rewards/margins"] = ( + self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() + ) + metrics[f"{prefix}logps/chosen"] = ( + self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() + ) + metrics[f"{prefix}logps/rejected"] = ( + self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() + ) + metrics[f"{prefix}logits/chosen"] = ( + self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() + ) + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() + ) if self.args.rpo_alpha is not None: - metrics[f"{prefix}nll_loss"] = model_output["nll_loss"].detach().mean().cpu() + metrics[f"{prefix}nll_loss"] = ( + self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() + ) if self.aux_loss_enabled: - metrics[f"{prefix}aux_loss"] = model_output["aux_loss"].detach().cpu() + metrics[f"{prefix}aux_loss"] = ( + self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() + ) return losses.mean(), metrics diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index bfedfc6fae..9406e276ad 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1140,7 +1140,7 @@ def kto_loss( """ if self.calculate_KL: kl = (policy_KL_logps - reference_KL_logps).mean().detach() - kl = self.accelerator.gather(kl).mean().clamp(min=0) + kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) else: kl = torch.zeros(1).to(policy_chosen_logps.device) @@ -1249,19 +1249,31 @@ def get_batch_loss_metrics( num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) - all_num_chosen = self.accelerator.gather(num_chosen).sum().item() - all_num_rejected = self.accelerator.gather(num_rejected).sum().item() + all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() if all_num_chosen > 0: - metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() - metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item() - metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item() + metrics["rewards/chosen_sum"] = ( + self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() + ) + metrics["logps/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() + ) metrics["count/chosen"] = all_num_chosen if all_num_rejected > 0: - metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() - metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item() - metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item() + metrics["rewards/rejected_sum"] = ( + self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() + ) + metrics["logps/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() + ) metrics["count/rejected"] = all_num_rejected loss = losses.nanmean() diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index 1d714e2c1d..a5767f9580 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -334,7 +334,7 @@ def _log_statistics( ): # Helper function to gather and compute mean def gather_mean(tensor): - return self.accelerator.gather(tensor).mean().item() + return self.accelerator.gather_for_metrics(tensor).mean().item() # Log score self.stats["loss/score"].append(gather_mean(score)) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index ebab5cdcfc..5c3fc260ca 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -549,28 +549,32 @@ def training_step( # Log everything if self.reward_model is not None: scores_margin = scores[chosen_indices] - scores[rejected_indices] - self.stats["objective/scores_margin"].append(self.accelerator.gather(scores_margin.mean()).mean().item()) - self.stats["objective/scores"].append(self.accelerator.gather(scores.mean()).mean().item()) + self.stats["objective/scores_margin"].append( + self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item() + ) + self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item()) self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item()) - self.stats["logps/chosen"].append(self.accelerator.gather(chosen_logprobs_sum).mean().item()) - self.stats["logps/rejected"].append(self.accelerator.gather(rejected_logprobs_sum).mean().item()) + self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item()) + self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item()) kl = logprobs - ref_logprobs mean_kl = kl.sum(1).mean() - self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item()) + self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) non_score_reward = (-self.beta * kl).sum(1) mean_non_score_reward = non_score_reward.mean() - self.stats["objective/non_score_reward"].append(self.accelerator.gather(mean_non_score_reward).mean().item()) + self.stats["objective/non_score_reward"].append( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) if self.reward_model is not None: rlhf_reward = scores + non_score_reward - self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item()) + self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item()) mean_entropy = -logprobs.sum(1).mean() - self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item()) + self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item()) chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) - gathered_chosen_rewards = self.accelerator.gather(chosen_rewards) + gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards) self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item()) rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) - gathered_rejected_rewards = self.accelerator.gather(rejected_rewards) + gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards) self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item()) margin = gathered_chosen_rewards - gathered_rejected_rewards self.stats["rewards/margins"].append(margin.mean().item()) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 51f81c7775..bd26b34c50 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -833,17 +833,21 @@ def get_batch_loss_metrics( reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() - metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() - metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() - metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() - metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean() - metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean() - metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean() - metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean() - metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean() - metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio - metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() + metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( + chosen_rewards - rejected_rewards + ).mean() + metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() + metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean() + ) + metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean() + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() + metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean() + metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean() if is_torch_xla_available(): xm.mark_step() # needed because .item() calls for k, v in metrics.items(): diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 51897eeb44..7e0dc635b2 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -615,19 +615,21 @@ def repeat_generator(): eps = int(self.state.episode / (time.time() - start_time)) metrics = {} metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() - metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() - metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() - metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() - metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() + metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = ( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode @@ -715,7 +717,7 @@ def generate_completions(self, sampling: bool = False): _, score, _ = get_reward( self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) if sampling: break diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 23ea1ca21f..4b736bd270 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -460,18 +460,20 @@ def repeat_generator(): eps = int(self.state.episode / (time.time() - start_time)) metrics = {} metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() - metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() - metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() - metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() + metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = ( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode @@ -538,7 +540,7 @@ def generate_completions(self, sampling: bool = False): _, score, _ = get_reward( self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) if sampling: break diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 1be32ab1de..1d8c4ae1ac 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -360,7 +360,7 @@ def _log_statistics( ): # Helper function to gather and compute mean def gather_mean(tensor): - return self.accelerator.gather(tensor).mean().item() + return self.accelerator.gather_for_metrics(tensor).mean().item() # Log losses self.stats["loss/dpo"].append(gather_mean(dpo_losses))