Merge pull request #138 from opentensor/dpo_penalty_update
Dpo penalty update
Eugene-hu authored Aug 25, 2023
2 parents 9a2d704 + b156a6d commit bd315ec
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion openvalidators/reward/dpo.py
@@ -33,17 +33,23 @@ def name(self) -> str: return RewardModelType.dpo.value
def __init__(self, device: str):
super().__init__()
self.device = device
self.penalty = 1.2  # Same penalty as the original CTRL paper (https://arxiv.org/pdf/1909.05858.pdf).
self.tokenizer = AutoTokenizer.from_pretrained(DirectPreferenceRewardModel.reward_model_name)
self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
trust_remote_code=True,
torch_dtype=torch.float16).to(self.device)

def reward_single(self, prompt: str, completion: str, name: str) -> float:
def reward_single(self, prompt: str, completion: str, name: str, with_penalty=True) -> float:
r""" Calculates a direct preference optimization (DPO) style reward for a completion,
which is a reference model's average log-probability for completion tokens given a prompt.
Uses guidance from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py.
"""
with torch.no_grad():

# Check if the completion is empty or too short.
if completion.strip() == '' or len(completion) <= 5:
return -11 # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)

# Tokenize the combined prompt + completion.
combined = self.tokenizer(prompt + completion, return_tensors="pt").input_ids[0].to(self.device) # [seq_len]
# Tokenize only the prompt, to help determine prompt token length.
@@ -73,6 +79,13 @@ def reward_single(self, prompt: str, completion: str, name: str) -> float:
# Predict only where labels are available.
logits = logits[:, :-1, :] # [batch_size=1, seq_len-1, vocab_len]

if with_penalty:
# Apply penalty for repeated generation
for i in range(len(prompt_part)+1, len(combined)-1):
logit = logits[:,i,:].clone()
inputs = combined[len(prompt_part):i].clone()
logits[:,i,:] = self.logit_penalty(input_ids=inputs, logit=logit)

# Rescale via log(softmax(logits)).
logits = logits.log_softmax(-1)
# Calculate the model's log-probability for each actual completion token.
@@ -91,3 +104,14 @@ def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.F
dtype=torch.float32).to(self.device)
bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
return rewards

def logit_penalty(self, input_ids: torch.LongTensor, logit: torch.FloatTensor) -> torch.FloatTensor:
# Counts the unique tokens within each generation
uniques, counts = input_ids.unique(return_counts=True)
score = torch.gather(logit, 1, uniques.unsqueeze(0))

# Negative logits are multiplied by the penalty (pushed further from zero); positive logits are divided (pulled toward zero), so repeated tokens always become less likely.
score = torch.where(score < 0, score * (self.penalty**counts), score / (self.penalty**counts))

logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
return logit
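
For reviewers who want to sanity-check the penalty arithmetic outside the reward model, the short sketch below replays the logit_penalty logic on a toy example. The 5-token vocabulary, the tensor values, and the print are illustrative assumptions; the penalty constant and the multiply-for-negative / divide-for-positive rule mirror the diff above.

import torch

penalty = 1.2  # same constant the diff sets in __init__

# Toy inputs (assumed for illustration): one logit row over a 5-token vocabulary,
# and a completion so far in which token 2 appears twice and token 4 once.
logit = torch.tensor([[1.5, -0.5, 2.0, 0.0, -1.0]])
input_ids = torch.tensor([2, 4, 2])

uniques, counts = input_ids.unique(return_counts=True)  # uniques=[2, 4], counts=[2, 1]
score = torch.gather(logit, 1, uniques.unsqueeze(0))    # logits of already-generated tokens

# Repeated tokens are always made less likely: negative logits are multiplied
# by penalty**count, positive logits are divided by it.
score = torch.where(score < 0, score * (penalty ** counts), score / (penalty ** counts))

logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
print(logit)  # token 2: 2.0 / 1.2**2 ≈ 1.39, token 4: -1.0 * 1.2 = -1.2

The effect compounds with the repeat count, so a token already generated n times has its logit scaled by penalty**n.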
