🕊️ DPO padding free #2520

Merged Jan 8, 2025

Commits (29)
4980f09
padding free
qgallouedec Dec 26, 2024
c21d4ba
specify dtype
qgallouedec Dec 26, 2024
d3e2e19
test
qgallouedec Dec 26, 2024
1921a03
warnings when not flash attention
qgallouedec Dec 26, 2024
b451208
fix test
qgallouedec Dec 26, 2024
4384121
remove
qgallouedec Dec 26, 2024
854c282
docstring padding-free
qgallouedec Dec 26, 2024
223a336
flash-attn dep
qgallouedec Dec 26, 2024
47997d4
Stronger warning
qgallouedec Dec 26, 2024
a69a63c
require_flash_attn in test
qgallouedec Dec 26, 2024
c6e9be0
flash-attn in CI
qgallouedec Dec 26, 2024
963d0ca
rm flash-attn from dep
qgallouedec Dec 26, 2024
b60247e
Remove flash-attn dependency from test workflows
qgallouedec Dec 26, 2024
f753449
refactor
qgallouedec Dec 26, 2024
dadd028
Update .github/workflows/tests.yml
qgallouedec Dec 26, 2024
c68a316
Update trl/trainer/dpo_trainer.py
qgallouedec Dec 26, 2024
72df715
drop require flash-attn
qgallouedec Dec 26, 2024
31ba855
Merge branch 'padding_free' of https://github.com/huggingface/trl int…
qgallouedec Dec 26, 2024
098a773
fix dtype
qgallouedec Dec 26, 2024
9328aa6
Merge branch 'main' into padding_free
qgallouedec Jan 6, 2025
31d6e0b
refine warning
qgallouedec Jan 7, 2025
fe28812
Merge branch 'main' into padding_free
qgallouedec Jan 7, 2025
13a3250
Update trl/trainer/dpo_config.py
qgallouedec Jan 7, 2025
02d5399
Add logic to compute mean logits for chosen and rejected tokens with …
qgallouedec Jan 7, 2025
175c5e2
format
qgallouedec Jan 7, 2025
9f086e5
Update trl/trainer/dpo_trainer.py
qgallouedec Jan 7, 2025
f96b5d1
Update trl/trainer/dpo_trainer.py
qgallouedec Jan 7, 2025
d806e31
fix comment [ci skip]
qgallouedec Jan 7, 2025
4ade2eb
fix num logits to keep
qgallouedec Jan 7, 2025
43 changes: 41 additions & 2 deletions tests/test_dpo_trainer.py
@@ -1112,7 +1112,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
model=model,
ref_model=None,
args=training_args,
tokenizer=tokenizer,
processing_class=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)
@@ -1122,7 +1122,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
model=model,
ref_model=None,
args=training_args,
tokenizer=tokenizer,
processing_class=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)
@@ -1165,6 +1165,45 @@ def test_dpo_trainer_use_num_logits_to_keep(self):

trainer.train()

def test_padding_free(self):
model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Normally, we need `attn_implementation="flash_attention_2"` so that the model returns correct logits.
# Without it, the logits may be incorrect, but that's fine here. This test focuses only on the inner logic
# of padding_free.
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
learning_rate=9e-1,
per_device_train_batch_size=2,
padding_free=True,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = DPOTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dummy_dataset["train"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
12 changes: 12 additions & 0 deletions trl/trainer/dpo_config.py
@@ -145,6 +145,11 @@ class DPOConfig(TrainingArguments):
useful for saving memory and speeding up training by not computing the logits for all tokens, especially in
scenarios when working with very long prompts where labels are ignored (-100).
[Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
padding_free (`bool`, *optional*, defaults to `False`):
Whether forward passes are performed without padding by flattening all sequences in the batch
into a single continuous sequence. This approach requires associating a `position_ids` vector to track
positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it
can handle the flattened batch structure.
"""

learning_rate: float = field(
@@ -355,3 +360,10 @@ class DPOConfig(TrainingArguments):
"tokens, especially in scenarios when working with very long prompts where labels are ignored (-100)."
},
)
padding_free: bool = field(
default=False,
metadata={
"help": "Whether the forward passes are performed without padding, i.e. flattening all the samples in the "
"batch into a single sample, associated with a position_ids vector. Only possible with flash-attention."
},
)
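
For orientation, here is a minimal usage sketch of the new padding_free option (not part of the diff). It assumes flash-attn is installed; the tiny model and dataset are reused from the test above, and the output directory name is illustrative:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# padding_free flattens each batch, so the warning added in dpo_trainer.py expects flash_attention_2
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
)

training_args = DPOConfig(output_dir="dpo-padding-free", padding_free=True, report_to="none")
train_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

trainer = DPOTrainer(
    model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()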
64 changes: 56 additions & 8 deletions trl/trainer/dpo_trainer.py
@@ -394,6 +394,18 @@ def make_inputs_require_grad(module, input, output):
self.precompute_ref_log_probs = args.precompute_ref_log_probs
self.use_num_logits_to_keep = args.use_num_logits_to_keep

if args.padding_free:
if model.config._attn_implementation != "flash_attention_2":
warnings.warn(
"Padding-free training is enabled, but the attention implementation is not set to "
"'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
"'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
"other implementations may lead to unexpected behavior. To ensure compatibility, set "
"`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
"attention mechanism can handle flattened sequences."
)
self.padding_free = args.padding_free

# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
# keep track of first called to avoid computation of future calls
self._precomputed_train_ref_log_probs = False
@@ -1149,15 +1161,26 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label
model_kwargs["num_logits_to_keep"] = num_logits_to_keep
Comment on lines +1164 to +1165 (qgallouedec, Member, Author):

Read https://github.com/huggingface/trl/pull/2520/files#r1905992872 first.

Since we have an additional token at the end, we need to keep one additional token. This update makes things nicer at this point, IMO.
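
As a toy illustration of the two new lines this comment refers to (a sketch only; the loss_mask matches the ASCII diagram above, with the first row filled in as an assumption):

import torch

# Two concatenated sequences of length 7; 1 marks completion tokens, as in the diagram above
loss_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1],
                          [0, 0, 0, 1, 1, 1, 0]], dtype=torch.bool)

first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()              # tensor(3)
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1   # 7 - 3 + 1 = 5
# i.e. keep the logits for the last 5 positions ([:, -(7-3+1):]), the extra one covering the first label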


if self.padding_free:
# Flatten the input_ids, position_ids, and loss_mask
# input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]]
# [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]]
input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
model_kwargs["position_ids"] = position_ids
else:
model_kwargs["attention_mask"] = attention_mask

outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)
outputs = model(input_ids, **model_kwargs)
logits = outputs.logits

# Offset the logits by one to align with the labels
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()
Comment on lines -1158 to -1160 (@qgallouedec, Member, Author, Jan 7, 2025):

Rolling works in both cases (flattened tensors and batched):

# Padding case
# input_ids = [[1, 2, 3, 4],
#              [5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels =    [[2, 3, 4],
#              [6, 7, 8]]

# But

# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels =    [[2, 3, 4, 5, 6, 7, 8]]

The first token of the first sequence (1) is removed but not the first token of the second sequence (5). To align the labels while keeping a consistent behaviour across sequences in the batch, we use roll instead. The only difference is that the first token, instead of being discarded, is appended to the end.

# Padding case
# input_ids = [[1, 2, 3, 4],
#              [5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels =    [[2, 3, 4, 1],
#              [6, 7, 8, 5]]

# And

# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels =    [[2, 3, 4, 5, 6, 7, 8, 1]]

labels = torch.roll(input_ids, shifts=-1, dims=1)
loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()

if self.use_num_logits_to_keep:
# Align labels with logits
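
A small sketch of what the flattening and the roll above do on a toy batch (illustrative values, not the trainer's real tensors):

import torch

input_ids      = torch.tensor([[7, 8, 9, 0],
                               [3, 4, 5, 6]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])

# Flatten: drop the padding token and rebuild per-sequence positions with a cumsum
flat_ids     = input_ids[attention_mask.bool()].unsqueeze(0)                     # [[7, 8, 9, 3, 4, 5, 6]]
position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1  # [[0, 1, 2, 0, 1, 2, 3]]

# Labels are rolled rather than sliced, so both sequences keep the same one-token shift
labels = torch.roll(flat_ids, shifts=-1, dims=1)                                 # [[8, 9, 3, 4, 5, 6, 7]]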
@@ -1178,6 +1201,17 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)

if self.padding_free:
# Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len])
batch_size, seq_len = attention_mask.shape
per_token_logps_ = torch.zeros(
batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype
)
per_token_logps_[attention_mask.bool()] = per_token_logps
per_token_logps = per_token_logps_

all_logps = per_token_logps.sum(-1)

output = {}
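
A toy sketch of the un-flattening step above: the flat per-token log-probs are scattered back into a [batch_size, seq_len] tensor using the original attention_mask (values are made up):

import torch

attention_mask  = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 1, 1]])
# One log-prob per real (non-padding) token, shape [1, 7]
per_token_logps = torch.tensor([[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]])

batch_size, seq_len = attention_mask.shape
per_token_logps_ = torch.zeros(batch_size, seq_len)
per_token_logps_[attention_mask.bool()] = per_token_logps
# tensor([[-0.1000, -0.2000, -0.3000,  0.0000],
#         [-0.4000, -0.5000, -0.6000, -0.7000]])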
@@ -1208,8 +1242,22 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to

output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean()

# Compute the mean logits
if self.padding_free:
# position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]).
# There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens,
# and the second half to the rejected tokens.
# To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id.
split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples]
mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean()
mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean()
else:
mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean()
mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean()

output["mean_chosen_logits"] = mean_chosen_logits
output["mean_rejected_logits"] = mean_rejected_logits

if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss