🕊️ DPO padding free #2520
Changes from all commits
```diff
@@ -394,6 +394,18 @@ def make_inputs_require_grad(module, input, output):
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
         self.use_num_logits_to_keep = args.use_num_logits_to_keep

+        if args.padding_free:
+            if model.config._attn_implementation != "flash_attention_2":
+                warnings.warn(
+                    "Padding-free training is enabled, but the attention implementation is not set to "
+                    "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
+                    "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
+                    "other implementations may lead to unexpected behavior. To ensure compatibility, set "
+                    "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
+                    "attention mechanism can handle flattened sequences."
+                )
+        self.padding_free = args.padding_free
+
         # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
         # keep track of first called to avoid computation of future calls
         self._precomputed_train_ref_log_probs = False
```
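For context, a minimal sketch (not part of this diff) of how a user would opt in to the new flag; the model name is a placeholder, and `padding_free` is the `DPOConfig` field this PR adds:

```python
import torch
from transformers import AutoModelForCausalLM
from trl import DPOConfig

# Padding-free batching relies on flash_attention_2 (see the warning above)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # flash-attn kernels require fp16/bf16
)
training_args = DPOConfig(output_dir="dpo-model", padding_free=True)
```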
```diff
@@ -1149,15 +1161,26 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.Tensor]]):
             # [0, 0, 0, x, x, x, 0]]
             #           ^ start computing logits from here ([:, -(7-3+1):])
             first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
-            num_logits_to_keep = loss_mask.shape[1] - first_compute_index
-            model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1  # +1 for the first label
+            num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
+            model_kwargs["num_logits_to_keep"] = num_logits_to_keep
+
+        if self.padding_free:
+            # Flatten the input_ids, position_ids, and loss_mask
+            # input_ids = [[a, b, c, 0],  ->  input_ids = [[a, b, c, d, e, f, g]]
+            #              [d, e, f, g]]      position_ids = [[0, 1, 2, 0, 1, 2, 3]]
+            input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
+            loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
+            position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
+            model_kwargs["position_ids"] = position_ids
+        else:
+            model_kwargs["attention_mask"] = attention_mask

-        outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)
+        outputs = model(input_ids, **model_kwargs)
+        logits = outputs.logits

         # Offset the logits by one to align with the labels
-        logits = outputs.logits[:, :-1, :]
-        labels = input_ids[:, 1:].clone()
-        loss_mask = loss_mask[:, 1:].bool()
```
**Comment on lines -1158 to -1160:**

Rolling works in both cases (flattened tensors and batched):

```python
# Padding case
# input_ids = [[1, 2, 3, 4],
#              [5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels = [[2, 3, 4],
#           [6, 7, 8]]

# But
# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = input_ids[:, 1:].clone()
# labels = [[2, 3, 4, 5, 6, 7, 8]]
```

The first token of the first sequence (1) is removed, but not the first token of the second sequence (5). To align the labels while keeping a consistent behaviour across sequences in the batch, we use `roll` instead. The only difference is that the first token, instead of being discarded, is appended to the end:

```python
# Padding case
# input_ids = [[1, 2, 3, 4],
#              [5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels = [[2, 3, 4, 1],
#           [6, 7, 8, 5]]

# And
# Padding-free case
# input_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
labels = torch.roll(input_ids, shifts=-1, dims=1)
# labels = [[2, 3, 4, 5, 6, 7, 8, 1]]
```
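For reference, the padding-free case above is easy to check directly (assumes only PyTorch):

```python
import torch

input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # one flattened row
labels = torch.roll(input_ids, shifts=-1, dims=1)
print(labels)  # tensor([[2, 3, 4, 5, 6, 7, 8, 1]])
```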
```diff
+        labels = torch.roll(input_ids, shifts=-1, dims=1)
+        loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()

         if self.use_num_logits_to_keep:
             # Align labels with logits
```
```diff
@@ -1178,6 +1201,17 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.Tensor]]):
         labels[~loss_mask] = 0  # dummy token; we'll ignore the losses on these tokens later
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
         per_token_logps[~loss_mask] = 0
+        per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
+
+        if self.padding_free:
+            # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len])
+            batch_size, seq_len = attention_mask.shape
+            per_token_logps_ = torch.zeros(
+                batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype
+            )
+            per_token_logps_[attention_mask.bool()] = per_token_logps
+            per_token_logps = per_token_logps_

         all_logps = per_token_logps.sum(-1)

         output = {}
```
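The unflatten step above is the inverse scatter; a toy version (PyTorch assumed, values invented):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])
# One log-prob per real token, in flattened order (toy numbers)
per_token_logps = torch.tensor([[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]])

batch_size, seq_len = attention_mask.shape
unflat = torch.zeros(batch_size, seq_len, dtype=per_token_logps.dtype)
unflat[attention_mask.bool()] = per_token_logps  # padding positions stay 0
# tensor([[-0.1000, -0.2000, -0.3000,  0.0000],
#         [-0.4000, -0.5000, -0.6000, -0.7000]])
```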
@@ -1208,8 +1242,22 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to | |
|
||
output["chosen_logps"] = all_logps[:num_examples] | ||
output["rejected_logps"] = all_logps[num_examples:] | ||
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean() | ||
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean() | ||
|
||
# Compute the mean logits | ||
if self.padding_free: | ||
# position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). | ||
# There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, | ||
# and the second half to the rejected tokens. | ||
# To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. | ||
split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] | ||
mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() | ||
mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() | ||
else: | ||
mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() | ||
mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() | ||
|
||
output["mean_chosen_logits"] = mean_chosen_logits | ||
output["mean_rejected_logits"] = mean_rejected_logits | ||
|
||
if self.aux_loss_enabled: | ||
output["aux_loss"] = outputs.aux_loss | ||
|
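A toy illustration of how `split_idx` is recovered from `position_ids` in the padding-free branch (PyTorch assumed; `num_examples = 1`, i.e. one chosen and one rejected sequence):

```python
import torch

# Chosen tokens first, then rejected; each range restarts at 0
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
num_examples = 1

# Column index of the (num_examples + 1)-th zero = start of the rejected tokens
split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples]
print(split_idx)  # tensor(3)
```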
**Comment:**

Read https://github.com/huggingface/trl/pull/2520/files#r1905992872 before this one: since we have an additional token at the end, we need to keep one additional token. This update makes things nicer at this point, imo.
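For context, a toy version of the `num_logits_to_keep` computation this comment refers to (PyTorch assumed; the trailing `+1` keeps the extra position needed once labels are rolled):

```python
import torch

# loss_mask marks completion tokens; logits are only needed from the first one
loss_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 0],
                          [0, 0, 1, 1, 1, 0, 0]]).bool()

first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()  # tensor(2)
num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
print(num_logits_to_keep)  # 6
```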