diff --git a/openvalidators/reward/dpo.py b/openvalidators/reward/dpo.py
index 6cc7d59..21fba3b 100644
--- a/openvalidators/reward/dpo.py
+++ b/openvalidators/reward/dpo.py
@@ -20,7 +20,7 @@
 from typing import List
 from .config import RewardModelType
 from .reward import BaseRewardModel
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, NoRepeatNGramLogitsProcessor
 
 
 class DirectPreferenceRewardModel(BaseRewardModel):
@@ -38,6 +38,7 @@ def __init__(self, device: str):
         self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
                                                           trust_remote_code=True,
                                                           torch_dtype=torch.float16).to(self.device)
+        self.ngram_logit_processor = NoRepeatNGramLogitsProcessor(ngram_size=5)
 
     def reward_single(self, prompt: str, completion: str, name: str ,with_penalty=True) -> float:
         r""" Calculates a direct preference optimization (DPO) style reward for a completion,
@@ -80,11 +81,10 @@ def reward_single(self, prompt: str, completion: str, name: str ,with_penalty=Tr
             logits = logits[:, :-1, :]  # [batch_size=1, seq_len-1, vocab_len]
 
             if with_penalty:
-                # Apply penalty for repeated generation
-                for i in range(len(prompt_part)+1, len(combined)-1):
-                    logit = logits[:,i,:].clone()
-                    inputs = combined[len(prompt_part):i].clone()
-                    logits[:,i,:] = self.logit_penalty(input_ids=inputs, logit=logit)
+                org_logit = logits.clone()
+                logits = self.ngram_logit_processor(combined[len(prompt_part):].reshape(1, -1).clone(), logits.permute(0, 2, 1)).permute(0, 2, 1)
+                # ngram_logit_processor sets punished tokens to -inf; reset them to 10 std below the mean instead
+                logits[logits == -float("Inf")] = org_logit.mean() - org_logit.std()*10
 
             # Rescale via log(softmax(logits)).
             logits = logits.log_softmax(-1)
@@ -103,15 +103,4 @@ def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.F
         rewards = torch.tensor([self.reward_single(prompt, completion, name) for completion in completions],
                                dtype=torch.float32).to(self.device)
         bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
-        return rewards
-
-    def logit_penalty(self, input_ids: torch.LongTensor, logit: torch.FloatTensor) -> torch.FloatTensor:
-        # Counts the unique tokens within each generation
-        uniques, counts = input_ids.unique(return_counts=True)
-        score = torch.gather(logit, 1, uniques.unsqueeze(0))
-
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-        score = torch.where(score < 0, score * (self.penalty**counts), score / (self.penalty**counts))
-
-        logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
-        return logit
\ No newline at end of file
+        return rewards
\ No newline at end of file
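
For reviewers, here is a minimal, self-contained sketch (not part of the patch) of what the new penalty path does. It uses the documented 2-D (batch, vocab) calling convention of NoRepeatNGramLogitsProcessor rather than the patch's permuted 3-D call, and the toy vocabulary size and token ids are made up for illustration. The processor pushes any token that would complete an already-seen 5-gram to -inf; the patch then replaces those -inf entries with a finite floor (mean minus 10 standard deviations of the original logits) so that the subsequent log_softmax stays finite.

# Hypothetical illustration only -- not code from the patch.
import torch
from transformers import NoRepeatNGramLogitsProcessor

processor = NoRepeatNGramLogitsProcessor(ngram_size=5)

vocab_size = 50                                               # toy vocabulary
completion_ids = torch.tensor([[1, 2, 3, 4, 5, 1, 2, 3, 4]])  # [batch=1, seq_len]; "1 2 3 4" repeats
next_token_logits = torch.randn(1, vocab_size)                # [batch=1, vocab_size]

original = next_token_logits.clone()
# Token 5 would complete the repeated 5-gram "1 2 3 4 5", so the processor sets its logit to -inf.
penalized = processor(completion_ids, next_token_logits)

# Reset banned tokens to a large-but-finite penalty instead of -inf, as the patch does.
floor = original.mean() - 10 * original.std()
penalized[penalized == -float("inf")] = floor

print(penalized[0, 5].item(), floor.item())  # the banned token now sits 10 std below the mean logit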