
Autotp training #6922

Open: inkcherry wants to merge 49 commits into base master
Conversation

inkcherry (Contributor) commented Jan 2, 2025

FYI @tjruwase @GuanhuaWang @delock @skyshine102 (context: #5445)
Changes/support:

  • Auto tensor parallel (AutoTP) training for HF models (ZeRO compatible; I have only tested ZeRO-1 so far).
  • Distributed checkpoint save (UCP is not supported).
  • HF model file save (set gather_16bit_weights_on_model_save=True in the DeepSpeed config; see the config sketch after this list).
  • Dataloader check.
  • Unit tests.
  • TP layer refactor via an abstract layer design.
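
For concreteness, here is a minimal sketch of what a DeepSpeed config enabling these features might look like, written as a Python dict. Only gather_16bit_weights_on_model_save is taken from the list above; the tensor_parallel / autotp_size key names are my assumption about how the TP degree would be configured, not a confirmed API.

```python
# Hypothetical sketch of a DeepSpeed config for ZeRO-1 + AutoTP training.
# Only gather_16bit_weights_on_model_save comes from the PR description;
# the "tensor_parallel"/"autotp_size" key names are assumptions.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "tensor_parallel": {"autotp_size": 2},       # assumed key names for the TP degree
    "gather_16bit_weights_on_model_save": True,  # enables HF model file save
}
```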

HF Trainer dependencies:
transformers: https://github.com/inkcherry/transformers/tree/ds_tp
accelerate: https://github.com/inkcherry/accelerate/tree/ds_tp
I can send these upstream once DeepSpeed supports these APIs.

Usage:
Users do not need to modify their client code; they only need to adjust settings in the DeepSpeed config file to enable this functionality.
Below is an example of fine-tuning a LLaMA 2 model (SFT). It supports ZeRO-3/FSDP training and enables TP training simply by adjusting the configuration.

https://github.com/inkcherry/stanford_alpaca/commits/tp_demo_1127/
This branch contains three commits; the last two were added for quick experiments and logging purposes.
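
To illustrate the "no client code changes" point, a standard HF Trainer script only needs to point at the DeepSpeed config; whether ZeRO-3 or ZeRO-1 + TP is used is decided entirely in that JSON. This is a hypothetical minimal sketch (the model name and train_dataset are placeholders; the real demo lives in the branch above).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Placeholders: swap in your model and a tokenized SFT dataset.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
train_dataset = ...  # a tokenized dataset, e.g. prepared as in the alpaca demo branch

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    bf16=True,
    deepspeed="ds_config.json",  # ZeRO-3 vs. ZeRO-1 + TP is selected here, not in Python
)

Trainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer).train()
```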
Results
Loss curves (gbs=16):
  • ZeRO-3 (baseline): [image]
  • TP (this PR): [image]
  • ZeRO-1 vs. ZeRO-1 + TP (ZeRO compatible): [image]

Performance (for reference only):
  • ZeRO-3 (no acceleration enabled): 18 GB, 2.3 s/it
  • ZeRO-1: 38 GB, 1.30 s/it
  • ZeRO-1 + TP: 24 GB, 1.66 s/it
Extension:
I think async-TP, Domino, etc. can be implemented by inheriting a class and overriding the fwd/bwd methods; the gather/partition logic can be reused for this (please correct me if I am wrong).

Complex sharding can also be achieved through independent partitioning and gathering. Partitioning is mandatory, while gathering is only required for training. See the sketch below for the inheritance idea.
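
To make the inheritance idea concrete, here is a rough sketch using my own class and method names (not the PR's actual abstractions): a base class owns the partition/gather helpers, and an async-TP/Domino-style variant only overrides forward while reusing them.

```python
import torch
import torch.distributed as dist

class TensorParallelLayer(torch.nn.Module):
    """Hypothetical base class that owns the partition/gather logic."""

    def __init__(self, tp_group):
        super().__init__()
        self.tp_group = tp_group

    def partition(self, full_weight: torch.Tensor) -> torch.Tensor:
        # Keep only this rank's shard along the output dimension.
        rank = dist.get_rank(self.tp_group)
        world = dist.get_world_size(self.tp_group)
        return torch.chunk(full_weight, world, dim=0)[rank].contiguous()

    def gather(self, shard: torch.Tensor) -> torch.Tensor:
        # Reassemble the full tensor, e.g. for consolidated checkpoint/model saves.
        world = dist.get_world_size(self.tp_group)
        out = [torch.empty_like(shard) for _ in range(world)]
        dist.all_gather(out, shard, group=self.tp_group)
        return torch.cat(out, dim=0)

class AsyncColumnParallelLinear(TensorParallelLayer):
    """Variant that only overrides forward; partition/gather are reused as-is."""

    def __init__(self, weight_shard, tp_group):
        super().__init__(tp_group)
        self.weight = torch.nn.Parameter(weight_shard)

    def forward(self, x):
        # An async-TP/Domino-style layer would launch its communication here and
        # overlap it with the matmul; a plain matmul is shown for brevity.
        return x @ self.weight.t()
```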
TODO:
Embedding vocab parallelism
Currently, the parallelism for embeddings is primarily hidden-dim parallelism combined with an allreduce. This approach takes advantage of efficient reduction kernels and is not mandatory. In training, however, the more common method is vocab parallelism; enabling it by default could save a certain amount of GPU memory (a sketch follows below).
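
For reference, the standard (Megatron-style) vocab-parallel embedding splits the vocabulary across TP ranks, masks out-of-range tokens, and all-reduces the lookups. A minimal sketch, not taken from this PR:

```python
import torch
import torch.distributed as dist

class VocabParallelEmbedding(torch.nn.Module):
    """Each rank stores a contiguous slice of the vocabulary rows."""

    def __init__(self, vocab_size, hidden_dim, tp_group):
        super().__init__()
        self.tp_group = tp_group
        rank = dist.get_rank(tp_group)
        world = dist.get_world_size(tp_group)
        per_rank = vocab_size // world  # assume divisibility for brevity
        self.start = rank * per_rank
        self.end = self.start + per_rank
        self.weight = torch.nn.Parameter(torch.empty(per_rank, hidden_dim))
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, input_ids):
        # Mask tokens owned by other ranks, look up the local shard, then
        # all-reduce so every rank sees the full embedding output.
        mask = (input_ids < self.start) | (input_ids >= self.end)
        local_ids = (input_ids - self.start).clamp(min=0)
        local_ids[mask] = 0
        out = torch.nn.functional.embedding(local_ids, self.weight)
        out[mask] = 0.0
        # In a real implementation the all-reduce is wrapped in an autograd
        # function so gradients flow correctly; shown plainly here.
        dist.all_reduce(out, group=self.tp_group)
        return out
```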

Thanks to @delock for the guidance.
I also verified inference with CPU-inference workloads (the Optimized Model List in https://github.com/intel/intel-extension-for-pytorch/tree/main).
Many thanks to @xuguangxin, @ikurtchen, @rogerxfeng8, @Yejing-Lai, @ys950902, etc. for helping review and address matters related to inference.

delock (Collaborator) commented Jan 2, 2025

@tjruwase @GuanhuaWang
We had an internal review of @inkcherry's PR. This PR allows training HF models with tensor parallelism without needing Megatron, which is very user-friendly.

Let us know your plan for the Domino integration. @inkcherry's memory data looks good. With Domino, we think the performance impact can be smaller, since TP communication can overlap with computation.

@inkcherry, by design, should AutoTP training also work with ZeRO-3?
