refine warning
qgallouedec committed Jan 7, 2025
1 parent 9328aa6 commit 31d6e0b
Showing 1 changed file with 6 additions and 3 deletions.
trl/trainer/dpo_trainer.py (6 additions, 3 deletions)
@@ -397,9 +397,12 @@ def make_inputs_require_grad(module, input, output):
         if args.padding_free:
             if model.config._attn_implementation != "flash_attention_2":
                 warnings.warn(
-                    "Padding-free training is only supported with the `flash_attention_2` implementation. "
-                    "You're very likely to get unexpected results. Please set "
-                    "`attn_implementation='flash_attention_2'` in the model config."
+                    "Padding-free training is enabled, but the attention implementation is not set to "
+                    "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
+                    "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
+                    "other implementations may lead to unexpected behavior. To ensure compatibility, set "
+                    "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
+                    "attention mechanism can handle flattened sequences."
                 )
             self.padding_free = args.padding_free
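For reference, a minimal sketch of how a user would load a model that satisfies this check before passing it to the trainer. The model id, dtype, and the availability of flash-attn 2 on the target hardware are assumptions for illustration, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM

    from trl import DPOConfig

    # Hypothetical model id; any causal LM with flash-attn 2 support works.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-0.5B-Instruct",
        attn_implementation="flash_attention_2",  # what the warning asks for
        torch_dtype=torch.bfloat16,  # flash-attn 2 requires fp16 or bf16
    )

    # padding_free=True is what triggers the attention-implementation check
    # refined in this commit.
    training_args = DPOConfig(output_dir="dpo-output", padding_free=True)

With any other attention implementation, the refined warning above fires and training proceeds at the user's own risk.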
