🕊️ DPO padding free #2520
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()
Rolling works in both cases (batched tensors and flattened padding-free tensors), whereas slicing misaligns the flattened case:
# Padding case
# input_ids = [[1, 2, 3, 4],
# [5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels = [[2, 3, 4],
# [6, 7, 8]]
# But
# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels = [[2, 3, 4, 5, 6, 7, 8]]
The first token of the first sequence (1) is removed, but not the first token of the second sequence (5). To align the labels while keeping a consistent behaviour across sequences in the batch, we use roll instead. The only difference is that the first token, instead of being discarded, is appended to the end, where the (equally rolled) loss mask keeps it out of the loss.
# Padding case
# input_ids = [[1, 2, 3, 4],
# [5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels = [[2, 3, 4, 1],
# [6, 7, 8, 5]]
# And
# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels = [[2, 3, 4, 5, 6, 7, 8, 1]]
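To see the whole thing end to end, here is a minimal runnable sketch of the roll-based shift. The shift_labels helper is hypothetical (not the PR's code); it assumes the loss mask is rolled together with the labels so that the wrapped-around position never contributes to the loss.
import torch

def shift_labels(input_ids, loss_mask):
    # Shift labels one step left with roll; identical behaviour for a batched
    # [B, T] tensor and a flattened padding-free [1, total_len] tensor.
    labels = torch.roll(input_ids, shifts=-1, dims=1)
    loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()
    loss_mask[:, -1] = False  # the rolled-in token is never a valid target
    return labels, loss_mask

# Padding case
input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
labels, mask = shift_labels(input_ids, torch.ones_like(input_ids))
# labels = [[2, 3, 4, 1], [6, 7, 8, 5]], last column masked out

# Padding-free case
flat_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
labels, mask = shift_labels(flat_ids, torch.ones_like(flat_ids))
# labels = [[2, 3, 4, 5, 6, 7, 8, 1]], last position masked out
# Note: position 3 gets label 5, which crosses the sequence boundary. In DPO
# the first token of a sequence is a masked prompt token, so after rolling
# the real loss mask this boundary position is excluded as well.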
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
model_kwargs["num_logits_to_keep"] = num_logits_to_keep
Read https://github.com/huggingface/trl/pull/2520/files#r1905992872 first.
Since the roll leaves an additional token at the end, we need to keep one additional logit.
This update makes things nicer at this point, imo.
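To make the interaction with the rolled labels concrete, a small sketch (the loss mask values are made up for illustration; first_compute_index is the column index of the first position that contributes to the loss):
import torch

# Toy loss mask over a [1, 8] row: only the last three positions are trained on.
loss_mask = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.bool)

# Column index of the first unmasked position.
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()  # tensor(5)

# Keep the logits from that position onward; the +1 accounts for the extra
# token that the roll appended at the end.
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # 4

# Models that support the kwarg then only materialize the trailing logits:
# outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep)
# outputs.logits.shape[1] == num_logits_to_keep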
What does this PR do?
Demo; further experiments in the next comment.
With and without padding-free (not sure why they don't match exactly; the logits do match precisely, though):
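For reference, one way the logits claim could be checked (a sketch only: the checkpoint is a placeholder, and it assumes a FlashAttention-2 model where position_ids that restart at 0 mark sequence boundaries, as used by transformers' padding-free collation):
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).cuda().eval()

seqs = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device="cuda")

with torch.no_grad():
    # Batched forward: one sequence per row.
    batched_logits = model(input_ids=seqs).logits  # [2, 4, vocab]

    # Padding-free forward: rows flattened, position_ids restarting per sequence.
    flat_ids = seqs.reshape(1, -1)  # [1, 8]
    position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], device="cuda")
    flat_logits = model(input_ids=flat_ids, position_ids=position_ids).logits

# Tolerances may need loosening for bf16 numerics.
torch.testing.assert_close(batched_logits.reshape(1, 8, -1), flat_logits)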
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.