
🕊️ DPO padding free #2520

Merged 29 commits from padding_free into main on Jan 8, 2025
Conversation

@qgallouedec (Member) commented on Dec 26, 2024

What does this PR do?

Demo below; further experiments are in the next comment.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
import torch

model_id = "Qwen/Qwen2-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:10%]")
training_args = DPOConfig(output_dir="Gemma2-2B-DPO-pf", max_prompt_length=128, max_completion_length=128, logging_steps=10, padding_free=True)
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()

Training curves with and without padding-free (not sure why they don't match exactly, though the logits do match precisely):

[Screenshot 2024-12-26 at 22:00:06: training curves with vs. without padding-free]

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Resolved (outdated) review threads: trl/trainer/dpo_config.py, trl/trainer/dpo_trainer.py, trl/trainer/dpo_trainer.py
@qgallouedec (Member Author) commented:

Regression test:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
import torch

model_id = "Qwen/Qwen2-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:10%]")
# dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO-no_pf", max_prompt_length=128, max_completion_length=128, logging_steps=10, padding_free=False)
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()

Is the new padding_free=False (no_pf in the screenshot) equivalent to DPO on the current main branch (main in the screenshot)? -> Yes

[Screenshot 2025-01-07 at 20:15:31: no_pf vs. main training curves]

Do the padding_free=True (pf in the screenshot) results match the padding_free=False (no_pf in the screenshot) results? -> Yes

[Screenshot 2025-01-07 at 20:19:41: pf vs. no_pf training curves]

(Note: the run names in the screenshots say "Gemma", but the model actually trained is Qwen.)

@qgallouedec requested review from lewtun and kashif on January 7, 2025 at 19:21
Comment on lines -1158 to -1160
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()
@qgallouedec (Member Author) commented on Jan 7, 2025:

Rolling works in both cases (batched and flattened tensors), whereas slicing does not:

# Padding case
# input_ids = [[1, 2, 3, 4],
#              [5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels =    [[2, 3, 4],
#              [6, 7, 8]]

# But

# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels =    [[2, 3, 4, 5, 6, 7, 8]]

With slicing, the first token of the first sequence (1) is removed, but the first token of the second sequence (5) is not. To align the labels while keeping consistent behaviour across all sequences in the batch, we use roll instead. The only difference is that the first token, instead of being discarded, is appended to the end (a runnable check of both strategies is sketched after the examples below).

# Padding case
# input_ids = [[1, 2, 3, 4],
#              [5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels =    [[2, 3, 4, 1],
#              [6, 7, 8, 5]]

# And

# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels =    [[2, 3, 4, 5, 6, 7, 8, 1]]
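
A quick runnable check of the two shifting strategies above (a toy sketch with made-up token ids, not the trainer code; it assumes the wrapped-around token in the last column is excluded from the loss via the loss mask):

import torch

input_ids = torch.tensor([[1, 2, 3, 4],
                          [5, 6, 7, 8]])

sliced = input_ids[:, 1:].clone()                  # [[2, 3, 4], [6, 7, 8]]
rolled = torch.roll(input_ids, shifts=-1, dims=1)  # [[2, 3, 4, 1], [6, 7, 8, 5]]

# Both strategies agree on every position that can contribute to the loss; the rolled
# version just parks the wrapped-around first token in the last column, which is
# presumably masked out.
assert torch.equal(sliced, rolled[:, :-1])

# Padding-free: the same roll shifts the flattened row uniformly, so every packed
# sequence gets its labels shifted the same way, which a single slice cannot do.
flat = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
print(torch.roll(flat, shifts=-1, dims=1))  # tensor([[2, 3, 4, 5, 6, 7, 8, 1]])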

Comment on lines +1164 to +1165
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label
model_kwargs["num_logits_to_keep"] = num_logits_to_keep
@qgallouedec (Member Author) commented:

Read https://github.com/huggingface/trl/pull/2520/files#r1905992872 first.

Since we have an additional token at the end, we need to keep one additional logit. This update makes things nicer at this point, imo.
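
A toy illustration of the arithmetic in the quoted lines (not the trainer code; first_compute_index is not shown in the excerpt, so it is assumed here to be the column index of the first position whose label contributes to the loss):

import torch

loss_mask = torch.tensor([[False, False, False, True, True, True, False, False]])
first_compute_index = loss_mask.nonzero()[:, 1].min()  # -> 3 (assumed definition, see lead-in)
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
print(num_logits_to_keep)  # 8 - 3 + 1 = 6
# Passing num_logits_to_keep to the model makes it materialize only the last 6 logits,
# so the prompt-only prefix before first_compute_index never goes through the lm_head.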

@August-murr (Collaborator) left a comment.

@qgallouedec merged commit 4516772 into main on Jan 8, 2025
14 checks passed
@qgallouedec deleted the padding_free branch on January 8, 2025 at 08:22