Fix reported KL in PPO trainer (#1180)
* Fix reported KL in PPO trainer

Previously, the trainer always reported the estimated KL (the plain logprob difference), even when using `kl_penalty = 'full'` (or `'abs'`, etc.).
Now the actual KL computed in `compute_rewards()` is returned and reported instead (see the sketch below).

* fix test
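
For context, a minimal sketch of the mismatch. It assumes the penalty variants that `_kl_penalty` is understood to implement (only the dispatch on `"kl"` is visible in this diff); the previously reported value was the plain logprob difference, which matches the applied penalty only when `kl_penalty = "kl"`:

    import torch
    import torch.nn.functional as F

    def kl_penalty_sketch(logprob, ref_logprob, kind="kl"):
        # Per-token penalty variants; an assumption about what _kl_penalty computes.
        if kind == "kl":    # plain logprob difference (the value that was always reported before)
            return logprob - ref_logprob
        if kind == "abs":   # absolute difference
            return (logprob - ref_logprob).abs()
        if kind == "full":  # exact KL over the full vocabulary distribution
            # here logprob / ref_logprob have shape (response_length, vocab_size)
            return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
        raise ValueError(f"unknown kind: {kind}")

    # For kind != "kl", the penalty applied in compute_rewards() is not the plain
    # difference, so reporting (logprobs - ref_logprobs) misstated the KL.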
mgerstgrasser authored Jan 9, 2024
1 parent 4ae35af commit a236c57
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/test_ppo_trainer.py
@@ -579,7 +579,7 @@ def test_loss_trainer(self):
 logits = torch.exp(all_logprobs)
 vpreds = values + 0.1

-score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
+score, non_score, kls = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
 values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask)

 # just make sure a dummy loss is computed
18 changes: 13 additions & 5 deletions trl/trainer/ppo_trainer.py
@@ -733,11 +733,11 @@ def step(
     active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
     ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

-    rewards, non_score_reward = self.compute_rewards(
+    rewards, non_score_reward, kls = self.compute_rewards(
         scores, active_full_logprobs, ref_full_logprobs, masks
     )
 else:
-    rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
+    rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
 timing["time/ppo/compute_rewards"] = time.time() - t

 t = time.time()
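
A side note on the `gather=False` branch above: the exact `'full'` KL needs log-probabilities over the entire vocabulary rather than the per-token gathered values. A rough shape sketch, with made-up sizes and with `logprobs_from_logits(..., gather=False)` assumed to return the full log-softmax:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 100)                # (batch_size, response_length, vocab_size)
    full_logprobs = F.log_softmax(logits, dim=-1)  # (2, 5, 100): input for the 'full' KL penalty
    labels = torch.randint(0, 100, (2, 5))
    per_token = torch.gather(full_logprobs, 2, labels.unsqueeze(-1)).squeeze(-1)  # (2, 5)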
@@ -831,6 +831,7 @@ def step(
     masks=masks,
     queries=queries,
     responses=responses,
+    kls=kls,
 )
 # Gather/Reduce stats from all processes
 if self.is_distributed:
@@ -1091,11 +1092,17 @@ def compute_rewards(
             Log probabilities of the model, shape (`batch_size`, `response_length`)
         ref_logprobs (`torch.FloatTensor`):
             Log probabilities of the reference model, shape (`batch_size`, `response_length`)
+    Returns:
+        `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
+        `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
+        `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
     """
-    rewards, non_score_rewards = [], []
+    rewards, non_score_rewards, kls = [], [], []
     for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
         # compute KL penalty (from difference in logprobs)
         kl = self._kl_penalty(logprob, ref_logprob)
+        kls.append(kl)
         non_score_reward = -self.kl_ctl.value * kl
         non_score_rewards.append(non_score_reward)
         reward = non_score_reward.clone()
@@ -1104,7 +1111,7 @@
         # reward is preference model score + KL penalty
         reward[last_non_masked_index] += score
         rewards.append(reward)
-    return torch.stack(rewards), torch.stack(non_score_rewards)
+    return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)

 def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
     if self.config.kl_penalty == "kl":
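
A hypothetical call against the updated method, to illustrate the third return value (the variable names and the masked reduction here are illustrative, not taken from this diff):

    rewards, non_score_rewards, kls = ppo_trainer.compute_rewards(scores, logprobs, ref_logprobs, masks)
    # kls is the per-token penalty actually used to build non_score_rewards,
    # shape (batch_size, response_length); its masked sum per sequence is what gets reported.
    mean_kl = (kls * masks).sum(dim=-1).mean()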
@@ -1256,7 +1263,8 @@ def record_step_stats(self, kl_coef: float, **data):
     """
     mask = data.pop("masks")

-    kl_list = ((data["logprobs"] - data["ref_logprobs"]) * mask).sum(axis=-1)
+    kls = data.pop("kls")
+    kl_list = ((kls) * mask).sum(axis=-1)
     mean_kl = kl_list.mean()
     mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
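
Downstream, the corrected number surfaces through the stats returned by `step()`. A sketch of how a training loop might read it, assuming the mean KL is still logged under the `objective/kl` key (the key itself is not shown in this diff):

    stats = ppo_trainer.step(query_tensors, response_tensors, reward_scores)
    # With this change the reported KL matches the penalty computed in compute_rewards(),
    # including for kl_penalty = 'full' or 'abs'.
    print("mean KL:", stats["objective/kl"])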

