diff --git a/openvalidators/reward/dpo.py b/openvalidators/reward/dpo.py
index 3a5860b..6cc7d59 100644
--- a/openvalidators/reward/dpo.py
+++ b/openvalidators/reward/dpo.py
@@ -33,17 +33,23 @@ def name(self) -> str: return RewardModelType.dpo.value
     def __init__(self, device: str):
         super().__init__()
         self.device = device
+        self.penalty = 1.2  # Same penalty as the original [paper](https://arxiv.org/pdf/1909.05858.pdf).
         self.tokenizer = AutoTokenizer.from_pretrained(DirectPreferenceRewardModel.reward_model_name)
         self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
                                                           trust_remote_code=True,
                                                           torch_dtype=torch.float16).to(self.device)
 
-    def reward_single(self, prompt: str, completion: str, name: str) -> float:
+    def reward_single(self, prompt: str, completion: str, name: str, with_penalty=True) -> float:
         r"""Calculates a direct preference optimization (DPO) style reward for a completion,
         which is a reference model's average log-probability for completion tokens given a prompt.
         Uses guidance from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py.
         """
         with torch.no_grad():
+
+            # Check if the completion is empty or too short; if so, return a fixed low reward.
+            if completion.strip() == '' or len(completion) <= 5:
+                return -11  # exp(-11) = 1.67e-5 < 2e-5 = 1/50257 (typical vocab size)
+
             # Tokenize the combined prompt + completion.
             combined = self.tokenizer(prompt + completion, return_tensors="pt").input_ids[0].to(self.device)  # [seq_len]
             # Tokenize only the prompt, to help determine prompt token length.
@@ -73,6 +79,13 @@ def reward_single(self, prompt: str, completion: str, name: str) -> float:
             # Predict only where labels are available.
             logits = logits[:, :-1, :]  # [batch_size=1, seq_len-1, vocab_len]
 
+            if with_penalty:
+                # Apply a repetition penalty to the completion positions.
+                for i in range(len(prompt_part) + 1, len(combined) - 1):
+                    logit = logits[:, i, :].clone()
+                    inputs = combined[len(prompt_part):i].clone()
+                    logits[:, i, :] = self.logit_penalty(input_ids=inputs, logit=logit)
+
             # Rescale via log(softmax(logits)).
             logits = logits.log_softmax(-1)
             # Calculate the model's log-probability for each actual completion token.
@@ -91,3 +104,14 @@ def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.F
                                dtype=torch.float32).to(self.device)
         bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
         return rewards
+
+    def logit_penalty(self, input_ids: torch.LongTensor, logit: torch.FloatTensor) -> torch.FloatTensor:
+        # Count the unique tokens already present in the generation.
+        uniques, counts = input_ids.unique(return_counts=True)
+        score = torch.gather(logit, 1, uniques.unsqueeze(0))
+
+        # If score < 0, the penalty must be multiplied (not divided) so repeated tokens always lose probability.
+        score = torch.where(score < 0, score * (self.penalty ** counts), score / (self.penalty ** counts))
+
+        logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
+        return logit
\ No newline at end of file
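
A minimal, standalone sketch of the counts-scaled repetition penalty this patch adds, so the math in logit_penalty can be sanity-checked in isolation. It mirrors the CTRL-style penalty (https://arxiv.org/pdf/1909.05858.pdf) raised to the per-token repeat count; the function name toy_logit_penalty, the penalty value, and the toy tensors below are illustrative only and not part of the patch.

import torch

def toy_logit_penalty(input_ids: torch.LongTensor, logit: torch.FloatTensor, penalty: float = 1.2) -> torch.FloatTensor:
    # Unique tokens seen so far and how often each was generated.
    uniques, counts = input_ids.unique(return_counts=True)
    score = torch.gather(logit, 1, uniques.unsqueeze(0))
    # Negative logits are multiplied, positive logits divided, so repeated tokens always lose probability.
    score = torch.where(score < 0, score * (penalty ** counts), score / (penalty ** counts))
    logit.scatter_(1, uniques.unsqueeze(0), score.to(logit.dtype))
    return logit

# Toy example: vocab of 4 tokens; token 2 was already generated twice, token 0 once.
prior = torch.tensor([[1.0, -1.0, 2.0, 0.5]])
generated = torch.tensor([2, 0, 2])
penalized = toy_logit_penalty(generated, prior.clone())
# Token 2 is scaled by 1.2**2 (2.0 -> ~1.39), token 0 by 1.2 (1.0 -> ~0.83); unseen tokens are unchanged.
print(penalized)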