Skip to content

Commit

Permalink
replaced the penalty function
Browse files Browse the repository at this point in the history
  • Loading branch information
isabella618033 committed Aug 31, 2023
1 parent 1806039 commit 93e04cd
Showing 1 changed file with 7 additions and 18 deletions.
25 changes: 7 additions & 18 deletions openvalidators/reward/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import List
from .config import RewardModelType
from .reward import BaseRewardModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, NoRepeatNGramLogitsProcessor


class DirectPreferenceRewardModel(BaseRewardModel):
Expand All @@ -38,6 +38,7 @@ def __init__(self, device: str):
self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
trust_remote_code=True,
torch_dtype=torch.float16).to(self.device)
self.ngram_logit_processor = NoRepeatNGramLogitsProcessor(ngram_size = 5)

def reward_single(self, prompt: str, completion: str, name: str ,with_penalty=True) -> float:
r""" Calculates a direct preference optimization (DPO) style reward for a completion,
Expand Down Expand Up @@ -80,11 +81,10 @@ def reward_single(self, prompt: str, completion: str, name: str ,with_penalty=Tr
logits = logits[:, :-1, :] # [batch_size=1, seq_len-1, vocab_len]

if with_penalty:
# Apply penalty for repeated generation
for i in range(len(prompt_part)+1, len(combined)-1):
logit = logits[:,i,:].clone()
inputs = combined[len(prompt_part):i].clone()
logits[:,i,:] = self.logit_penalty(input_ids=inputs, logit=logit)
org_logit = logits.clone()
logits = self.ngram_logit_processor(combined[len(prompt_part):].reshape(1, -1).clone(), logits.permute(0, 2, 1)).permute(0, 2, 1)
# ngram_logit_processor set punished tokens to -inf, resetting them to 10 std below instead
logits[logits == -float("Inf")] = org_logit.mean() - org_logit.std()*10

# Rescale via log(softmax(logits)).
logits = logits.log_softmax(-1)
Expand All @@ -103,15 +103,4 @@ def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.F
rewards = torch.tensor([self.reward_single(prompt, completion, name) for completion in completions],
dtype=torch.float32).to(self.device)
bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
return rewards

def logit_penalty(self, input_ids: torch.LongTensor, logit: torch.FloatTensor) -> torch.FloatTensor:
# Counts the unique tokens within each generation
uniques, counts = input_ids.unique(return_counts=True)
score = torch.gather(logit, 1, uniques.unsqueeze(0))

# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * (self.penalty**counts), score / (self.penalty**counts))

logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
return logit
return rewards

0 comments on commit 93e04cd

Please sign in to comment.