🧑‍🤝‍🧑 Proper metrics gathering across ranks before logging #2474

Merged · 6 commits · Jan 7, 2025
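The changes below replace rank-local metric computation, and the remaining plain `accelerator.gather` calls, with `accelerator.gather_for_metrics` before logging. `gather_for_metrics` concatenates a tensor from every process (and, for values tied to samples from a prepared dataloader, drops the entries duplicated to pad the final batch), so the logged numbers describe the whole batch rather than only the shard seen by the logging process. A minimal sketch of the difference, outside this PR and assuming an illustrative two-process `accelerate launch` run with a made-up `rewards` tensor:

```python
# Minimal sketch (not part of this PR): why metrics are gathered before logging.
# Assumes something like `accelerate launch --num_processes 2 sketch.py`;
# the per-rank `rewards` tensor is illustrative.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank sees a different shard of the batch, so a locally computed mean
# describes only that rank's slice.
rewards = torch.randn(4, device=accelerator.device) + accelerator.process_index
local_mean = rewards.mean().item()  # differs from rank to rank

# gather_for_metrics concatenates the tensor from every process (and, for tensors
# that come from a prepared dataloader, drops samples duplicated to pad the last
# batch), so the reduced value is identical on all ranks and covers the full batch.
global_mean = accelerator.gather_for_metrics(rewards).mean().item()

if accelerator.is_main_process:
    print(f"rank-0 local mean: {local_mean:.4f}  |  global mean: {global_mean:.4f}")
```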
30 changes: 21 additions & 9 deletions trl/trainer/bco_trainer.py
@@ -1238,24 +1238,36 @@ def get_batch_loss_metrics(
chosen_embeddings,
rejected_embeddings,
)
metrics["delta"] = delta.item()
metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()

num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

-all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

if all_num_chosen > 0:
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item()
metrics["rewards/chosen_sum"] = (
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
)
metrics["logps/chosen_sum"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
)
metrics["logits/chosen_sum"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
)
metrics["count/chosen"] = all_num_chosen

if all_num_rejected > 0:
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item()
metrics["rewards/rejected_sum"] = (
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
)
metrics["logps/rejected_sum"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
)
metrics["logits/rejected_sum"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
)
metrics["count/rejected"] = all_num_rejected

loss = losses.nanmean()
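Unlike the mean-based metrics in the other trainers, the BCO (and KTO) trainers log per-batch sums and counts, since each rank can end up with a different number of chosen and rejected examples; gathering sums and counts keeps the mean that is later derived from the logs exact. A hedged sketch of that sum-and-count pattern with illustrative tensors (not the trainer code itself):

```python
# Sketch of the sum-and-count pattern (illustrative tensors, not TRL code).
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Pretend this rank ended up with a rank-dependent number of chosen rewards.
chosen_rewards = torch.randn(2 + accelerator.process_index, device=accelerator.device)

num_chosen = torch.Tensor([len(chosen_rewards)]).to(accelerator.device)
all_num_chosen = accelerator.gather_for_metrics(num_chosen).sum().item()

metrics = {}
if all_num_chosen > 0:
    # nansum on each rank, then nansum again over the gathered per-rank sums.
    metrics["rewards/chosen_sum"] = accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
    metrics["count/chosen"] = all_num_chosen
    # A consumer of the logs can recover the exact global mean as sum / count.
    mean_chosen_reward = metrics["rewards/chosen_sum"] / metrics["count/chosen"]
```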
28 changes: 19 additions & 9 deletions trl/trainer/cpo_trainer.py
@@ -817,15 +817,25 @@ def get_batch_loss_metrics(
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
)
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()

if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
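In the CPO trainer (and the DPO trainer below) the change is also a type change: the logged values used to be rank-local CPU tensors built with `.mean().cpu()`, and are now plain Python floats computed from tensors gathered across every process via `gather_for_metrics(...).mean().item()`. A toy before/after with an illustrative tensor:

```python
# Before/after sketch (illustrative tensor, not TRL code).
import torch
from accelerate import Accelerator

accelerator = Accelerator()
chosen_rewards = torch.randn(8, device=accelerator.device)  # illustrative

# Old pattern: value reflects only this rank's shard and stays a CPU tensor.
old_value = chosen_rewards.mean().cpu()

# New pattern: gather across ranks first, then reduce to a plain float for logging.
new_value = accelerator.gather_for_metrics(chosen_rewards).mean().item()
```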
34 changes: 24 additions & 10 deletions trl/trainer/dpo_trainer.py
@@ -1249,18 +1249,32 @@ def get_batch_loss_metrics(
losses = losses + self.aux_loss_coef * model_output["aux_loss"]

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/chosen"] = model_output["chosen_logps"].detach().mean().cpu()
metrics[f"{prefix}logps/rejected"] = model_output["rejected_logps"].detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = model_output["mean_chosen_logits"].detach().cpu()
metrics[f"{prefix}logits/rejected"] = model_output["mean_rejected_logits"].detach().cpu()
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
)
if self.args.rpo_alpha is not None:
metrics[f"{prefix}nll_loss"] = model_output["nll_loss"].detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
)
if self.aux_loss_enabled:
metrics[f"{prefix}aux_loss"] = model_output["aux_loss"].detach().cpu()
metrics[f"{prefix}aux_loss"] = (
self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
)

return losses.mean(), metrics

30 changes: 21 additions & 9 deletions trl/trainer/kto_trainer.py
@@ -1140,7 +1140,7 @@ def kto_loss(
"""
if self.calculate_KL:
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
-kl = self.accelerator.gather(kl).mean().clamp(min=0)
+kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
else:
kl = torch.zeros(1).to(policy_chosen_logps.device)

@@ -1249,19 +1249,31 @@ def get_batch_loss_metrics(
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

-all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

if all_num_chosen > 0:
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item()
metrics["rewards/chosen_sum"] = (
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
)
metrics["logps/chosen_sum"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
)
metrics["logits/chosen_sum"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
)
metrics["count/chosen"] = all_num_chosen

if all_num_rejected > 0:
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item()
metrics["rewards/rejected_sum"] = (
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
)
metrics["logps/rejected_sum"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
)
metrics["logits/rejected_sum"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
)
metrics["count/rejected"] = all_num_rejected

loss = losses.nanmean()
2 changes: 1 addition & 1 deletion trl/trainer/nash_md_trainer.py
@@ -334,7 +334,7 @@ def _log_statistics(
):
# Helper function to gather and compute mean
def gather_mean(tensor):
-return self.accelerator.gather(tensor).mean().item()
+return self.accelerator.gather_for_metrics(tensor).mean().item()

# Log score
self.stats["loss/score"].append(gather_mean(score))
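The Nash-MD trainer routes all of its logged statistics through a small `gather_mean` helper, so the one-line change above covers every metric it records. A hedged sketch of that helper pattern in isolation, with made-up stats:

```python
# Sketch of the gather_mean helper pattern (illustrative values, not the trainer itself).
import torch
from accelerate import Accelerator

accelerator = Accelerator()

def gather_mean(tensor: torch.Tensor) -> float:
    # Gather the tensor from every process, then reduce to one float for logging.
    return accelerator.gather_for_metrics(tensor).mean().item()

stats = {"loss/score": []}
score = torch.randn(4, device=accelerator.device)  # illustrative per-rank scores
stats["loss/score"].append(gather_mean(score))
```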
24 changes: 14 additions & 10 deletions trl/trainer/online_dpo_trainer.py
@@ -549,28 +549,32 @@ def training_step(
# Log everything
if self.reward_model is not None:
scores_margin = scores[chosen_indices] - scores[rejected_indices]
self.stats["objective/scores_margin"].append(self.accelerator.gather(scores_margin.mean()).mean().item())
self.stats["objective/scores"].append(self.accelerator.gather(scores.mean()).mean().item())
self.stats["objective/scores_margin"].append(
self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
)
self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
self.stats["logps/chosen"].append(self.accelerator.gather(chosen_logprobs_sum).mean().item())
self.stats["logps/rejected"].append(self.accelerator.gather(rejected_logprobs_sum).mean().item())
self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())

kl = logprobs - ref_logprobs
mean_kl = kl.sum(1).mean()
self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item())
self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
non_score_reward = (-self.beta * kl).sum(1)
mean_non_score_reward = non_score_reward.mean()
self.stats["objective/non_score_reward"].append(self.accelerator.gather(mean_non_score_reward).mean().item())
self.stats["objective/non_score_reward"].append(
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
)
if self.reward_model is not None:
rlhf_reward = scores + non_score_reward
self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item())
self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
mean_entropy = -logprobs.sum(1).mean()
self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item())
self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
-gathered_chosen_rewards = self.accelerator.gather(chosen_rewards)
+gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
-gathered_rejected_rewards = self.accelerator.gather(rejected_rewards)
+gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
margin = gathered_chosen_rewards - gathered_rejected_rewards
self.stats["rewards/margins"].append(margin.mean().item())
26 changes: 15 additions & 11 deletions trl/trainer/orpo_trainer.py
@@ -833,17 +833,21 @@ def get_batch_loss_metrics(
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
)
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
if is_torch_xla_available():
xm.mark_step() # needed because .item() calls
for k, v in metrics.items():
30 changes: 16 additions & 14 deletions trl/trainer/ppo_trainer.py
@@ -615,19 +615,21 @@ def repeat_generator():
eps = int(self.state.episode / (time.time() - start_time))
metrics = {}
metrics["eps"] = eps
metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
metrics["objective/non_score_reward"] = (
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
)
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["episode"] = self.state.episode
@@ -715,7 +717,7 @@ def generate_completions(self, sampling: bool = False):
_, score, _ = get_reward(
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)
table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

if sampling:
break
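In `generate_completions` the gathered scores feed a completions table rather than a scalar metric, so the gathered tensor is kept whole and moved to CPU/NumPy instead of being reduced. A small sketch of that pattern, assuming an illustrative `score` tensor and a plain dict-of-lists table:

```python
# Sketch of logging per-sample scores into a table (illustrative inputs, not TRL code).
from collections import defaultdict

import torch
from accelerate import Accelerator

accelerator = Accelerator()
table = defaultdict(list)

score = torch.rand(4, device=accelerator.device)  # illustrative per-completion rewards

# Keep every per-sample value: gather across ranks, then move to CPU for logging.
table["score"].extend(accelerator.gather_for_metrics(score).float().cpu().numpy())
```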