Dpo penalty update #138

Merged 3 commits on Aug 25, 2023
openvalidators/reward/dpo.py (26 changes: 25 additions, 1 deletion)
@@ -33,17 +33,23 @@ def name(self) -> str: return RewardModelType.dpo.value
def __init__(self, device: str):
super().__init__()
self.device = device
self.penalty = 1.2  # Same repetition penalty as in the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf).
self.tokenizer = AutoTokenizer.from_pretrained(DirectPreferenceRewardModel.reward_model_name)
self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
trust_remote_code=True,
torch_dtype=torch.float16).to(self.device)

def reward_single(self, prompt: str, completion: str, name: str) -> float:
def reward_single(self, prompt: str, completion: str, name: str, with_penalty=True) -> float:
r""" Calculates a direct preference optimization (DPO) style reward for a completion,
which is a reference model's average log-probability for completion tokens given a prompt.
Uses guidance from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py.
"""
with torch.no_grad():

# Check if the completion is empty or too short.
if completion.strip() == '' or len(completion) <= 5:
return -11 # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)

Contributor:
I'm not sure I understand why it's -11; could you please elaborate a bit more so I can better understand it?

Contributor Author:
exp(-11) is the base value given to empty or very short responses; -11 is the nearest integer whose exponential is less than the uniform probability across all logits (1/50257).
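
For reference, the arithmetic behind that choice can be checked directly (a quick sketch, not from the thread):

import math

print(math.exp(-11))  # ≈ 1.67e-05
print(1 / 50257)      # ≈ 1.99e-05, the uniform probability over a 50257-token vocabulary
print(math.exp(-10))  # ≈ 4.54e-05, already above the uniform probability

So -11 is the largest integer whose exponential stays below 1/50257.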

Contributor:
Would it be feasible to calculate this at runtime with something like 1 / model.vocab_size? That way the code would be independent of the model used, since the value would be computed dynamically.

Collaborator:
Yes this can be done in a future update, and will be necessary if the DPO model tokenizer is changed to something non-standard.
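
For illustration, a minimal sketch of that future change, assuming the floor is derived from the tokenizer's vocabulary size at init time (gpt2 is used here only as a stand-in; the actual code would presumably use DirectPreferenceRewardModel.reward_model_name):

import math
from transformers import AutoTokenizer

# Hypothetical: compute the empty-completion floor from the vocab size instead of hard-coding -11.
tokenizer = AutoTokenizer.from_pretrained("gpt2")        # vocab_size == 50257
uniform_logprob = math.log(1.0 / tokenizer.vocab_size)   # ≈ -10.82
empty_completion_floor = math.floor(uniform_logprob)     # -11 for this vocab size
print(empty_completion_floor)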

Contributor:
Got it, I will create an issue for that so we don't lose track of this.


# Tokenize the combined prompt + completion.
combined = self.tokenizer(prompt + completion, return_tensors="pt").input_ids[0].to(self.device) # [seq_len]
# Tokenize only the prompt, to help determine prompt token length.
@@ -73,6 +79,13 @@ def reward_single(self, prompt: str, completion: str, name: str) -> float:
# Predict only where labels are available.
logits = logits[:, :-1, :] # [batch_size=1, seq_len-1, vocab_len]

if with_penalty:
# Apply a repetition penalty: at each completion position, down-weight tokens already generated.
for i in range(len(prompt_part)+1, len(combined)-1):
logit = logits[:,i,:].clone()
inputs = combined[len(prompt_part):i].clone()
logits[:,i,:] = self.logit_penalty(input_ids=inputs, logit=logit)

# Rescale via log(softmax(logits)).
logits = logits.log_softmax(-1)
# Calculate the model's log-probability for each actual completion token.
@@ -91,3 +104,14 @@ def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.F
dtype=torch.float32).to(self.device)
bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
return rewards

def logit_penalty(self, input_ids: torch.LongTensor, logit: torch.FloatTensor) -> torch.FloatTensor:
# Counts the unique tokens within each generation
uniques, counts = input_ids.unique(return_counts=True)
score = torch.gather(logit, 1, uniques.unsqueeze(0))

# If score < 0, multiply by the repetition penalty (rather than divide) so the token's probability is still reduced.
score = torch.where(score < 0, score * (self.penalty**counts), score / (self.penalty**counts))

logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
return logit
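
As a standalone illustration of what logit_penalty does (a toy sketch over an assumed 4-token vocabulary, not part of the diff): logits of tokens that already appear in the completion are scaled by penalty**count, in the spirit of the CTRL repetition penalty.

import torch

# Toy example: token 0 has already been generated twice, token 1 once.
penalty = 1.2
logit = torch.tensor([[2.0, -1.0, 0.5, 3.0]])            # next-token scores
input_ids = torch.tensor([0, 0, 1])                       # previously generated completion tokens

uniques, counts = input_ids.unique(return_counts=True)    # tokens [0, 1], counts [2, 1]
score = torch.gather(logit, 1, uniques.unsqueeze(0))      # [[2.0, -1.0]]
# Positive scores are divided, negative scores multiplied, so both move toward "less likely".
score = torch.where(score < 0, score * (penalty ** counts), score / (penalty ** counts))
logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
print(logit)  # tensor([[ 1.3889, -1.2000,  0.5000,  3.0000]])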