
Implement efficient packing without cross-contamination attention #4224

Merged: 13 commits merged into hiyouga:main on Jul 3, 2024

Conversation

@chuan298 chuan298 commented Jun 11, 2024

What does this PR do?

Update 15/6/2024: Added packing support for eager and sdpa attention.

Fixes #2289

Implement efficient packing without cross-contamination attention.
Taking inspiration from repositories such as axolotl and functionary, I implemented sequence packing more effectively, enabling the model to learn from each sample without attending to other samples within the same pack. Currently this implementation only supports SFT with flash_attention_2.

Example training config:

### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
efficient_packing: true

### output
output_dir: saves/llama3-8b/lora/sft
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500


@hiyouga hiyouga added the pending This problem is yet to be addressed label Jun 12, 2024
@hiyouga hiyouga mentioned this pull request Jun 15, 2024
@AlongWY (Contributor) commented Jun 20, 2024

Should we consider using a varlen_flash_attn implementation?

@@ -33,6 +33,9 @@ def run_sft(
    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if data_args.efficient_packing:
        configure_packing(model.config, model_args)
hiyouga (Owner):

could we do configure_packing in llamafactory.model.patcher?

chuan298 (Author):

Sure, I just edited it

@@ -66,6 +66,21 @@

SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

SUPPORTED_CLASS_FOR_MULTIPACK = [
hiyouga (Owner):

is it "efficient_packing" rather than "multipack"?

chuan298 (Author):

Yes, I just fixed it.

@chuan298 (Author)

Hi @AlongWY, the models in transformers already use flash_attn_varlen_func by default when an attention_mask is passed. I just made a slight change to the attention_mask when packing sequences and returned the indices, cu_seqlens, and max_seqlen_in_batch corresponding to the modified attention_mask.
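
For reference, here is a minimal sketch of that kind of patched helper, assuming the packed attention_mask stores per-sample segment ids (1, 2, 3, ...) with 0 for padding; it is an illustration rather than the exact code in this PR:

import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask: torch.Tensor):
    """attention_mask holds segment ids per packed sample, e.g. [[1, 1, 2, 2, 2, 0]]."""
    # length of every packed sample, row by row and in left-to-right order
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, int(attention_mask.max()) + 1)],
        dim=-1,
    ).flatten()
    seqlens_in_batch = counts[counts > 0].to(torch.int32)
    # flat indices of all non-padding tokens, matching the order of the lengths above
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative sequence lengths in the format expected by flash_attn_varlen_func
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch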

@chuan298 (Author) commented Jul 2, 2024

Hi @hiyouga,
It was my mistake for not testing thoroughly. I moved efficient_packing to ModelArguments to minimize code changes (the old code required passing data_args to every load_model call, which I found unreasonable and which led to errors elsewhere), and I have now retested everything thoroughly.

@hiyouga (Owner) commented Jul 2, 2024

Hi @chuan298,
Thank you for your efforts in integrating efficient packing into LLaMA Factory. We will merge this PR in the coming days.

)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
@hiyouga (Owner) commented Jul 3, 2024:

_prepare_4d_causal_attention_mask has never been used in Llama's forward pass, so this patch will not affect training.

hiyouga (Owner):

We need to construct the 4d attention mask during get_dataset

chuan298 (Author):

Oh, in an older version of transformers I saw code using _prepare_4d_causal_attention_mask in eager mode. I just checked again, and it has been removed and reworked. For sdpa and eager mode, we need to convert the attention mask into something like this:

For example, if batch = 3 and seqlen = 6, the old attention_mask is:
     
        [
          [1, 1, 2, 2, 2, 0],
          [1, 1, 1, 2, 2, 0],
          [1, 1, 1, 1, 1, 1]
        ]
        
Convert to new 4D-attention mask:
        
        [
          [
            [
                [0, -inf, -inf, -inf, -inf, -inf],
                [0, 0, -inf, -inf, -inf, -inf],
                [-inf, -inf, 0, -inf, -inf, -inf],
                [-inf, -inf, 0, 0, -inf, -inf],
                [-inf, -inf, 0, 0, 0, -inf],
                [-inf, -inf, -inf, -inf, -inf, 0]
            ]
          ],
          [
            [
                [0, -inf, -inf, -inf, -inf, -inf],
                [0, 0, -inf, -inf, -inf, -inf],
                [0, 0, 0, -inf, -inf, -inf],
                [-inf, -inf, -inf, 0, -inf, -inf],
                [-inf, -inf, -inf, 0, 0, -inf],
                [-inf, -inf, -inf, -inf, -inf, 0]
            ]
          ],
          [
            [
                [0, -inf, -inf, -inf, -inf, -inf],
                [0, 0, -inf, -inf, -inf, -inf],
                [0, 0, 0, -inf, -inf, -inf],
                [0, 0, 0, 0, -inf, -inf],
                [0, 0, 0, 0, 0, -inf],
                [0, 0, 0, 0, 0, 0]
            ]
          ]
        ]
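
To make the conversion concrete, the sketch below reproduces the example above under the same segment-id convention; it is an illustration, not the exact collator code added in this PR:

import torch

def build_4d_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Expand a packed segment-id mask of shape (bsz, seq_len) into an additive
    (bsz, 1, seq_len, seq_len) mask that is causal within each segment and
    blocks attention across segments."""
    seq_len = attention_mask.size(1)
    min_value = torch.finfo(dtype).min
    # a query may attend to a key only if both positions carry the same segment id ...
    same_segment = attention_mask[:, None, :, None] == attention_mask[:, None, None, :]
    # ... and the key is not in the future (causal constraint)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device))
    allowed = same_segment & causal
    zero = torch.zeros((), dtype=dtype, device=attention_mask.device)
    neg_inf = torch.full((), min_value, dtype=dtype, device=attention_mask.device)
    return torch.where(allowed, zero, neg_inf)

Calling build_4d_attention_mask(torch.tensor([[1, 1, 2, 2, 2, 0]])) reproduces the first 6x6 block shown above, including the padded position that attends only to itself.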

chuan298 (Author):

I will fix that part right now

hiyouga (Owner):

oops, I am fixing it

@hiyouga (Owner) commented Jul 3, 2024:

We are designing a new data collator for the SFT trainer that converts attention masks with indices into 4D attention masks with the correct dtype.
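
The sketch below is hypothetical (the class name PackedSFTDataCollator and its wiring are illustrations, and it reuses the build_4d_attention_mask sketch from earlier); it only shows how such a collator could slot into the usual DataCollatorForSeq2Seq pipeline:

import torch
from dataclasses import dataclass
from transformers import DataCollatorForSeq2Seq

@dataclass
class PackedSFTDataCollator(DataCollatorForSeq2Seq):
    """Pad the batch as usual, then swap the segment-id attention_mask for an
    additive 4D mask cast to the model's compute dtype."""
    compute_dtype: torch.dtype = torch.bfloat16

    def __call__(self, features, return_tensors=None):
        batch = super().__call__(features, return_tensors)
        batch["attention_mask"] = build_4d_attention_mask(batch["attention_mask"], self.compute_dtype)
        return batch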

chuan298 (Author):

Thank you, I have learned a lot from the way you designed the system.

hiyouga (Owner):

glad to hear any valuable advice from you about the implementation

@hiyouga (Owner) left a review comment:

This PR is now ready to be merged. Patches for MoE models are left for future work. We will update the patch for eager/sdpa attention soon.

@hiyouga hiyouga merged commit 87d9b2d into hiyouga:main Jul 3, 2024
1 check passed
@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Jul 3, 2024
@Leo-T-Zang

Thanks for doing this! One quick question: are position ids re-initialized for packed examples?

@chuan298 (Author)

Almost all models use RoPE instead of absolute positional embeddings, so we don't need to reinitialize position ids

@YeQiuO (Contributor) commented Sep 3, 2024

Hello, after using this method with Qwen2-7B-Instruct, the results are far worse than training without packing. Do you have any ideas on how to resolve this?

@FoolMark commented Sep 5, 2024

The differences between packing and the original non-packing setup are:

  1. After packing, a single entry may contain many samples (depending on cutoff_len), so the effective batch size becomes much larger than before. The learning_rate may need to be adjusted accordingly.
  2. Originally every sample carried the same weight in the loss, but after packing the gradient updates may be dominated by longer samples, effectively down-weighting shorter ones. In that case you could compute and adjust each sample's loss weight to align with the non-packing setup (see the sketch below).

These are the only likely causes I can think of for now. In theory there should not be a large gap versus non-packing; I am just not sure whether the current implementation is entirely correct.
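
A minimal sketch of the reweighting idea in item 2 (a hypothetical helper, not part of this PR), assuming labels use -100 for ignored tokens and segment_ids marks each packed sample following the attention-mask convention above:

import torch
import torch.nn.functional as F

def per_sample_balanced_loss(logits: torch.Tensor, labels: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    """Average the token loss within each packed sample first, then across samples,
    so long samples no longer dominate short ones."""
    sample_losses = []
    for b in range(logits.size(0)):
        # causal shift: position t predicts token t + 1
        token_loss = F.cross_entropy(
            logits[b, :-1], labels[b, 1:], ignore_index=-100, reduction="none"
        )
        valid = labels[b, 1:] != -100
        for seg in torch.unique(segment_ids[b, 1:][valid]):
            if seg == 0:  # skip padding
                continue
            mask = valid & (segment_ids[b, 1:] == seg)
            sample_losses.append(token_loss[mask].mean())
    return torch.stack(sample_losses).mean()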

@YeQiuO (Contributor) commented Sep 5, 2024

@FoolMark Thanks for your reply.

In my experiments so far, the evaluation results rank: no packing > packing >>> neat_packing.

This trend is especially pronounced when SFT fine-tuning Llama3_8B_Instruct: the first two methods reach a B-4 of 13, while neat_packing only reaches 3, which looks very abnormal.

The main work of this PR is get_unpad_data and the 4D mask, which prevent packed samples from attending to other samples during training. In principle this should not have such a large side effect; could there be a bug in the code?

Or perhaps the method has prerequisites and cannot be used directly? @chuan298 could the author help explain? Many thanks!

@chuan298 (Author) commented Sep 5, 2024

Thank you for reporting the issue. I will test it again and get back to you as soon as possible.

@bao-xiaoyi

In theory neat_packing should be at least as good as plain packing, since packing suffers from cross-contamination. Can you confirm the experimental results were not mixed up?

@YeQiuO (Contributor) commented Sep 12, 2024

@bao-xiaoyi After multiple runs, I can confirm that when fine-tuning Qwen2_7B-Instruct on the COIG-CQIA dataset:
no packing > packing >>> neat_packing

@bao-xiaoyi commented Sep 13, 2024

@hiyouga @chuan298 I don't have a machine to debug this in LLaMA-Factory yet, but I ran the verification script from https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py and it fails on starcoderv2. You may want to check whether the method is currently implemented incorrectly.

@YeQiuO (Contributor) commented Sep 13, 2024

@bao-xiaoyi This PR was based on axolotl and functionary; see the first line of the PR conversation and the comments in the files.

@YeQiuO (Contributor) commented Sep 13, 2024

@bao-xiaoyi I looked at the starcoderv2 source code and the _get_unpad_data() method is gone; it seems the latest version of transformers no longer supports it and goes through _prepare_4d_causal_attention_mask_with_cache_position() directly.
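
For illustration, a version-aware monkey patch could probe for the helper before replacing it; the module paths below are assumptions about where different transformers releases keep _get_unpad_data, and a silent miss here would explain an ineffective patch:

import importlib

def patch_get_unpad_data(patched_fn) -> bool:
    """Try to install patched_fn wherever this transformers version defines _get_unpad_data."""
    for module_name in (
        "transformers.modeling_flash_attention_utils",  # newer, consolidated location
        "transformers.models.llama.modeling_llama",     # older, per-model location
    ):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        if hasattr(module, "_get_unpad_data"):
            setattr(module, "_get_unpad_data", patched_fn)
            return True
    return False  # helper not found: the patch would have no effect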

@bao-xiaoyi

https://github.com/MeetKai/functionary/blob/2ade2578017077c9f1b061140a41f8c2349dbc71/functionary/train/packing/monkey_patch_packing.py#L147C64-L148C9
Are you saying that this part is actually ineffective?

@YeQiuO (Contributor) commented Sep 13, 2024

In my humble opinion, the flash attention handling seems to have been consolidated into a single file now, so the current implementation does not match recent versions of transformers, as the linked code shows.

@YeQiuO (Contributor) commented Sep 13, 2024

But even with eager and sdpa attention, the results are still poor.

@bao-xiaoyi

starcoderv2 indeed has this problem, but I tested deepseek-v2-lite and it was fine; the average loss difference was tiny.

@chuan298 (Author)

Hi @YeQiuO @bao-xiaoyi, I'm sorry for the late reply.
I just tested the two methods above with the Qwen2-7B-Instruct model and the mahiatlinux/Reflection-Dataset-ShareGPT-v2 dataset. Below are some results.

  • With no packing
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json
flash_attn: fa2
cutoff_len: 4096
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 2.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

[training results screenshot: no packing]

  • With neat_packing
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json
flash_attn: fa2
cutoff_len: 4096
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 4.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

[training results screenshot: neat_packing]

I found that the results of neat_packing were relatively higher compared to no packing. I then reviewed the latest version of transformers and quickly implemented changes resembling DataCollatorWithPadding and the _flash_attention_forward function in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py#L272, removing the attention_mask and returning position_ids from preprocess_packed_supervised_dataset instead. However, this implementation only runs with batch_size=1.

stage: sft
do_train: true
finetuning_type: lora
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json
flash_attn: fa2
cutoff_len: 4096
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 4.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

[training results screenshot: modified implementation with position_ids]

The results achieved with this were equivalent to no packing, so I believe there have been changes in the latest version of transformers, and I will update the code as soon as possible.
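
To make the position_ids change concrete, here is a minimal hypothetical helper (not the code in this PR) that restarts positions for every sample in a pack; recent transformers versions can derive cu_seqlens from such position_ids inside _flash_attention_forward when no packed attention_mask is passed, currently only for batch_size=1, matching the limitation mentioned above:

from typing import List

def packed_position_ids(sample_lengths: List[int]) -> List[int]:
    """Build position_ids that restart at 0 for every sample in a pack,
    e.g. lengths [2, 3, 1] -> [0, 1, 0, 1, 2, 0]."""
    position_ids: List[int] = []
    for length in sample_lengths:
        position_ids.extend(range(length))
    return position_ids

# e.g. returned alongside input_ids and labels from preprocess_packed_supervised_dataset
print(packed_position_ids([2, 3, 1]))  # [0, 1, 0, 1, 2, 0]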

@hiyouga (Owner) commented Sep 13, 2024

@chuan298 I think we should adopt the same token batch size and number of training steps for a fair comparison. I find the model trained with neat_packing is underfitted.

@bao-xiaoyi

[screenshot of a note from the functionary repository]
That note from the functionary side seems worth thinking about.

@bao-xiaoyi

https://research.ibm.com/blog/hugging-face-training-flash-attention
The current implementation probably has some issues; you may want to refer to this.

Labels: solved (This problem has been already solved)

Successfully merging this pull request may close these issues: Issues with the sft_packing implementation (sft_packing实现的问题)

7 participants