
Dpo penalty update #138

Merged
merged 3 commits into from
Aug 25, 2023
Conversation

Eugene-hu
Contributor

@@ -33,17 +33,23 @@ def name(self) -> str: return RewardModelType.dpo.value
def __init__(self, device: str):
super().__init__()
self.device = device
self.penalty = 1.2
Contributor

Do we have any data for why we picked 1.2 for the penalty?

Contributor Author

It is the same default parameter used by Hugging Face, and it was taken from this paper (https://arxiv.org/pdf/2305.14314.pdf).

Contributor

I think adding this reference to the code as a one-line comment could help clarify future doubts.
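For context on what the 1.2 does, here is a minimal sketch of how a repetition penalty is typically applied to logits, mirroring the behavior of Hugging Face's `RepetitionPenaltyLogitsProcessor` (the function name and plain-list representation here are illustrative, not the PR's actual code):

```python
def apply_repetition_penalty(logits, prev_token_ids, penalty=1.2):
    """Penalize tokens already generated, Hugging Face style.

    Positive logits are divided by the penalty and negative logits are
    multiplied by it, so previously seen tokens always become less likely.
    """
    logits = list(logits)
    for token_id in set(prev_token_ids):
        if logits[token_id] > 0:
            logits[token_id] /= penalty
        else:
            logits[token_id] *= penalty
    return logits
```

With penalty > 1, repeating a token is discouraged; penalty = 1.0 leaves the logits unchanged.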


# Check if the completion is empty or too short
if completion.strip() == '' or len(completion) <= 5:
    return -11  # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)
Contributor

I'm not sure I understand why it is -11; could you elaborate so I can better understand it?

Contributor Author

-11 is the base log-probability assigned to empty or very short responses; it is the nearest integer whose exponential, exp(-11) ≈ 1.67e-5, is smaller than the uniform probability across all logits (1/50257 ≈ 2e-5).
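The arithmetic behind the -11 floor can be checked in a couple of lines (50257 is the GPT-2-style vocabulary size assumed in the comment):

```python
import math

VOCAB_SIZE = 50257  # typical GPT-2-family tokenizer vocab size

# Log-probability of a token under a uniform distribution over the vocab.
uniform_logprob = math.log(1.0 / VOCAB_SIZE)  # ≈ -10.82

# -11 is the nearest integer strictly below the uniform log-probability,
# so exp(-11) is guaranteed to be smaller than any uniform token probability.
FLOOR = -11
assert math.exp(FLOOR) < 1.0 / VOCAB_SIZE
```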

Contributor

Would it be feasible to calculate this at runtime with something like 1 / model.vocab_size? That way the code would be independent of the model used, since the value would be calculated dynamically.

Collaborator

Would it be feasible to calculate this at runtime with something like 1 / model.vocab_size? That way the code would be independent of the model used, since the value would be calculated dynamically.

Yes, this can be done in a future update, and it will be necessary if the DPO model tokenizer is changed to something non-standard.
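A minimal sketch of the dynamic version suggested above, assuming the vocab size is available on the model or tokenizer (the helper name is illustrative):

```python
import math

def empty_completion_logprob(vocab_size: int) -> int:
    """Floor of the log-probability of a token under a uniform
    distribution over the vocabulary; guarantees exp(result) is
    below 1/vocab_size for any model."""
    return math.floor(math.log(1.0 / vocab_size))

# For a GPT-2-style vocab of 50257 tokens this recovers the hard-coded -11.
```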

Contributor

Got it, I will create an issue for that so we don't lose track of this.

4 participants