Skip to content

Commit

Permalink
repeat penalty for generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugene-hu committed Aug 25, 2023
1 parent 64e9e82 commit 44ae35e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions openvalidators/reward/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def reward_single(self, prompt: str, completion: str, name: str ,with_penalty=Fa

 if with_penalty:
     # Apply penalty for repeated generation
-    for i in range(len(prompt_part), len(combined)-1):
+    for i in range(len(prompt_part)+1, len(combined)-1):
         logit = logits[:,i,:].clone()
-        inputs = combined[:i].clone()
+        inputs = combined[len(prompt_part):i].clone()
         logits[:,i,:] = self.logit_penalty(input_ids=inputs, logit=logit)

 # Rescale via log(softmax(logits)).
Expand Down

0 comments on commit 44ae35e

Please sign in to comment.