πŸ§‘β€πŸ€β€πŸ§‘ Proper metrics gathering across ranks before logging (#2474)
* dpo_trainer gather metrics across ranks before logging

according to #2468

* fix everywhere

* gather_for_metrics

---------

Co-authored-by: Quentin GallouΓ©dec <[email protected]>
Co-authored-by: Quentin GallouΓ©dec <[email protected]>
3 people authored Jan 7, 2025
1 parent 1d23ecc commit a50124d
Showing 10 changed files with 147 additions and 87 deletions.
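
Every hunk in this commit follows the same pattern: a metric tensor is gathered across all ranks with Accelerator.gather_for_metrics before being reduced and logged, where previously either Accelerator.gather was used or no gather happened at all. The sketch below is not part of the diff; it only illustrates the two calls, assumes accelerate is installed, and runs on a single process as well as under accelerate launch.

    # Minimal sketch (illustration only, not code from the commit).
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Stand-in for a per-rank metric tensor computed inside a trainer step.
    chosen_rewards = torch.randn(8, device=accelerator.device)

    # Plain gather: concatenates every rank's tensor, including any samples that
    # were duplicated to keep ranks in sync on the last, uneven batch.
    naive_mean = accelerator.gather(chosen_rewards).mean().item()

    # gather_for_metrics: same concatenation, but duplicated samples coming from an
    # accelerate-prepared dataloader are dropped before the reduction, so the
    # logged value reflects only real data seen across all ranks.
    logged_mean = accelerator.gather_for_metrics(chosen_rewards).mean().item()

    print(f"gather: {naive_mean:.4f}  gather_for_metrics: {logged_mean:.4f}")

Under multi-GPU training (for example with accelerate launch), the logged metrics therefore reflect every rank rather than only the local batch of the process that writes the logs.
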
30 changes: 21 additions & 9 deletions trl/trainer/bco_trainer.py
@@ -1238,24 +1238,36 @@ def get_batch_loss_metrics(
             chosen_embeddings,
             rejected_embeddings,
         )
-        metrics["delta"] = delta.item()
+        metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()

         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+        all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+        all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

         if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item()
+            metrics["rewards/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
+            )
+            metrics["logits/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
+            )
             metrics["count/chosen"] = all_num_chosen

         if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item()
+            metrics["rewards/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
+            )
+            metrics["logits/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
+            )
             metrics["count/rejected"] = all_num_rejected

         loss = losses.nanmean()
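
Unlike the mean-style metrics in the other trainers, the BCO hunk above (and the matching KTO hunk below) logs gathered nansums together with sample counts. A downstream consumer can recover global averages by dividing sum by count; the helper below is a hypothetical sketch of that reduction using the key names from the hunk, not code from TRL.

    # Hypothetical reduction of the sum/count metrics logged above (illustration only).
    def reduce_sum_count_metrics(metrics: dict) -> dict:
        """Turn "<name>/<side>_sum" plus "count/<side>" entries into per-sample means."""
        reduced = dict(metrics)
        for side in ("chosen", "rejected"):
            count = metrics.get(f"count/{side}", 0)
            if not count:
                continue
            for name in ("rewards", "logps", "logits"):
                key = f"{name}/{side}_sum"
                if key in metrics:
                    reduced[f"{name}/{side}"] = metrics[key] / count
        return reduced

    print(reduce_sum_count_metrics({"rewards/chosen_sum": 12.0, "count/chosen": 4}))
    # {'rewards/chosen_sum': 12.0, 'count/chosen': 4, 'rewards/chosen': 3.0}
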
28 changes: 19 additions & 9 deletions trl/trainer/cpo_trainer.py
@@ -817,15 +817,25 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
+        )
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
+        )
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()

         if self.aux_loss_enabled:
             loss += self.aux_loss_coef * aux_loss
34 changes: 24 additions & 10 deletions trl/trainer/dpo_trainer.py
@@ -1249,18 +1249,32 @@ def get_batch_loss_metrics(
             losses = losses + self.aux_loss_coef * model_output["aux_loss"]

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = model_output["chosen_logps"].detach().mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = model_output["rejected_logps"].detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = model_output["mean_chosen_logits"].detach().cpu()
-        metrics[f"{prefix}logits/rejected"] = model_output["mean_rejected_logits"].detach().cpu()
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
+        )
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/chosen"] = (
+            self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
+        )
         if self.args.rpo_alpha is not None:
-            metrics[f"{prefix}nll_loss"] = model_output["nll_loss"].detach().mean().cpu()
+            metrics[f"{prefix}nll_loss"] = (
+                self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
+            )
         if self.aux_loss_enabled:
-            metrics[f"{prefix}aux_loss"] = model_output["aux_loss"].detach().cpu()
+            metrics[f"{prefix}aux_loss"] = (
+                self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
+            )

         return losses.mean(), metrics

30 changes: 21 additions & 9 deletions trl/trainer/kto_trainer.py
@@ -1140,7 +1140,7 @@ def kto_loss(
         """
         if self.calculate_KL:
             kl = (policy_KL_logps - reference_KL_logps).mean().detach()
-            kl = self.accelerator.gather(kl).mean().clamp(min=0)
+            kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
         else:
             kl = torch.zeros(1).to(policy_chosen_logps.device)

@@ -1249,19 +1249,31 @@ def get_batch_loss_metrics(
         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+        all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+        all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

         if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item()
+            metrics["rewards/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
+            )
+            metrics["logits/chosen_sum"] = (
+                self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
+            )
             metrics["count/chosen"] = all_num_chosen

         if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item()
+            metrics["rewards/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
+            )
+            metrics["logps/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
+            )
+            metrics["logits/rejected_sum"] = (
+                self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
+            )
             metrics["count/rejected"] = all_num_rejected

         loss = losses.nanmean()
2 changes: 1 addition & 1 deletion trl/trainer/nash_md_trainer.py
@@ -334,7 +334,7 @@ def _log_statistics(
     ):
         # Helper function to gather and compute mean
         def gather_mean(tensor):
-            return self.accelerator.gather(tensor).mean().item()
+            return self.accelerator.gather_for_metrics(tensor).mean().item()

         # Log score
         self.stats["loss/score"].append(gather_mean(score))
24 changes: 14 additions & 10 deletions trl/trainer/online_dpo_trainer.py
@@ -549,28 +549,32 @@ def training_step(
         # Log everything
         if self.reward_model is not None:
             scores_margin = scores[chosen_indices] - scores[rejected_indices]
-            self.stats["objective/scores_margin"].append(self.accelerator.gather(scores_margin.mean()).mean().item())
-            self.stats["objective/scores"].append(self.accelerator.gather(scores.mean()).mean().item())
+            self.stats["objective/scores_margin"].append(
+                self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
+            )
+            self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
         self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
-        self.stats["logps/chosen"].append(self.accelerator.gather(chosen_logprobs_sum).mean().item())
-        self.stats["logps/rejected"].append(self.accelerator.gather(rejected_logprobs_sum).mean().item())
+        self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
+        self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())

         kl = logprobs - ref_logprobs
         mean_kl = kl.sum(1).mean()
-        self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item())
+        self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
         non_score_reward = (-self.beta * kl).sum(1)
         mean_non_score_reward = non_score_reward.mean()
-        self.stats["objective/non_score_reward"].append(self.accelerator.gather(mean_non_score_reward).mean().item())
+        self.stats["objective/non_score_reward"].append(
+            self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+        )
         if self.reward_model is not None:
             rlhf_reward = scores + non_score_reward
-            self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item())
+            self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
         mean_entropy = -logprobs.sum(1).mean()
-        self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item())
+        self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
         chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
-        gathered_chosen_rewards = self.accelerator.gather(chosen_rewards)
+        gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
         self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
         rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
-        gathered_rejected_rewards = self.accelerator.gather(rejected_rewards)
+        gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
         self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
         margin = gathered_chosen_rewards - gathered_rejected_rewards
         self.stats["rewards/margins"].append(margin.mean().item())
26 changes: 15 additions & 11 deletions trl/trainer/orpo_trainer.py
@@ -833,17 +833,21 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
-        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
-        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
+            chosen_rewards - rejected_rewards
+        ).mean()
+        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
+        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
+        )
+        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
+        metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
         if is_torch_xla_available():
             xm.mark_step()  # needed because .item() calls
         for k, v in metrics.items():
30 changes: 16 additions & 14 deletions trl/trainer/ppo_trainer.py
@@ -615,19 +615,21 @@ def repeat_generator():
                 eps = int(self.state.episode / (time.time() - start_time))
                 metrics = {}
                 metrics["eps"] = eps
-                metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
-                metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
-                metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
-                metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
-                metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
-                metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
-                metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
-                metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
-                metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
-                metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
-                metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
-                metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
-                metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
+                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+                metrics["objective/non_score_reward"] = (
+                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+                )
+                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+                metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
+                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                 metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                 metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                 metrics["episode"] = self.state.episode
@@ -715,7 +717,7 @@ def generate_completions(self, sampling: bool = False):
                     _, score, _ = get_reward(
                         self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )
-                    table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
+                    table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

                 if sampling:
                     break
(The remaining 2 of the 10 changed files were not loaded in this view.)
