Adding FSDP Support to Training Library #213
Conversation
Force-pushed from 560c2ec to 0b4d516
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Oleg S <[email protected]>
…ining_backend to TrainingArgs.distributed_backend and DistributedTrainingBackend to DistributedBackend
Signed-off-by: Oleg S <[email protected]>
Force-pushed from e2b4ae4 to 95eb2c0
Signed-off-by: Mustafa Eyceoz <[email protected]>
Signed-off-by: Mustafa Eyceoz <[email protected]>
@@ -157,6 +181,12 @@ class TrainingArgs(BaseModel):
            cpu_offload_optimizer_pin_memory=False,
        )
    )
    fsdp_options: FSDPOptions = Field(
does this need to be a factory? I think it can just be an assignment
I'm following the current convention set by DeepSpeedOptions in this file, so IMO if we want to change this, we should make a follow-up PR that updates both of them.
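For context, a minimal sketch of the two patterns being discussed, assuming Pydantic models; the FSDPOptions fields shown are illustrative, not the PR's exact definition:

```python
from pydantic import BaseModel, Field


class FSDPOptions(BaseModel):
    # Illustrative field only.
    cpu_offload_params: bool = False


class TrainingArgs(BaseModel):
    # Current convention (mirrors DeepSpeedOptions): a default_factory builds a
    # fresh FSDPOptions instance for each TrainingArgs object.
    fsdp_options: FSDPOptions = Field(
        default_factory=lambda: FSDPOptions(cpu_offload_params=False)
    )

    # Suggested alternative: a plain default assignment. Pydantic copies
    # mutable defaults per instance, so this is generally safe as well.
    # fsdp_options: FSDPOptions = FSDPOptions(cpu_offload_params=False)
```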
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
Do we ever want to expose this? It adds a bit of memory overhead in exchange for some performance; I think it's customarily left as the default.
This is a good point. I think it's fine for now, but I will open an issue to track it, since I'm not sure how much of a performance hit versus memory gain this option would be for us. It might be a nice bonus trick to avoid offloading in some configurations if the performance cost isn't horrendous.
Tracked in #228
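If this does get exposed later (per #228), one way it could look, as a sketch with a hypothetical option name:

```python
from torch.distributed.fsdp import BackwardPrefetch


def resolve_backward_prefetch(prefetch_before_backward: bool) -> BackwardPrefetch:
    """Map a hypothetical boolean option onto FSDP's prefetch modes."""
    # BACKWARD_PRE prefetches the next parameter shard before the current
    # gradient computation finishes: better overlap, slightly higher peak memory.
    # BACKWARD_POST prefetches afterwards, trading overlap for lower memory.
    if prefetch_before_backward:
        return BackwardPrefetch.BACKWARD_PRE
    return BackwardPrefetch.BACKWARD_POST
```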
src/instructlab/training/main_ds.py (outdated)
    }
    return ds_config

def setup_optimizer(args, model):
    if args.distributed_training_framework == "fsdp":
The typical way to do this is via this pattern:
- if args.distributed_training_framework == "fsdp":
+ if DistributedBackend(args.distributed_training_framework) == DistributedBackend.FSDP:
This collects "magic strings" like "fsdp" into the Enum object.
Note: it actually has to be DistributedBackend.FSDP.value, since by this point the args have gone through the main_ds argparse post-torchrun and args.distributed_training_framework is just a string
Fixed in latest commit
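A minimal sketch of the enum pattern being suggested; the exact DistributedBackend definition in the codebase may differ:

```python
from enum import Enum


class DistributedBackend(Enum):
    FSDP = "fsdp"
    DEEPSPEED = "deepspeed"


# After torchrun + argparse, the backend arrives as a plain string.
framework = "fsdp"  # stand-in for args.distributed_training_framework

# Option 1: parse the string into the enum, then compare members.
if DistributedBackend(framework) == DistributedBackend.FSDP:
    print("using FSDP")

# Option 2: compare the raw string against the member's .value, which is what
# the note about DistributedBackend.FSDP.value refers to.
if framework == DistributedBackend.FSDP.value:
    print("using FSDP")
```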
src/instructlab/training/main_ds.py (outdated)
        model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
    )

    accelerator = setup_accelerator(args, model, grad_accum)
    if args.distributed_training_framework == "fsdp":
Same enum trick here
Note: it actually has to be DistributedBackend.FSDP.value, since by this point the args have gone through the main_ds argparse post-torchrun and args.distributed_training_framework is just a string
Fixed in latest commit
        ),
        lr_scheduler=lr_scheduler,
        dist_init_required=True,
    model, optimizer, _, lr_scheduler = accelerator.prepare(
I see here that we're "double preparing" the model. Is that okay? Is Accelerate smart enough to handle this?
Yes, I have verified that it is. Originally I had some conditionals to avoid it, but Accelerate was one step ahead.
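Roughly the flow under discussion, as a sketch; the helper name and prepare() arguments are assumptions rather than the PR's literal code:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)

# First preparation (e.g. inside a setup_accelerator()-style helper), so the
# optimizer can be built on the possibly FSDP-wrapped parameters.
model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

# Second preparation together with optimizer and scheduler. Per the discussion
# above, Accelerate tolerates the model being prepared a second time.
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
```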
        global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
I haven't seen this here conventionally, only at the top of the training loop. I guess it can go in either place. I also see that this is where they put it in the docs.
If it ain't broke 🤷🏻‍♂️
++
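For reference, the two placements being compared, in a generic training-loop sketch (not this repo's actual loop):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(10):
    # Placement A: clear gradients at the top of the iteration.
    # optimizer.zero_grad()

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Placement B (as in this PR and the docs mentioned above): clear gradients
    # right after the optimizer step. Either spot works as long as gradients
    # are cleared exactly once per iteration.
    optimizer.zero_grad()
```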
IMO nothing that I noticed is blocking an approval. The only thing that I really want is for this PR to be rebased as a single commit so the history is a bit neater. Once that's done I'll approve!
Signed-off-by: Mustafa Eyceoz <[email protected]>
lgtm!
Adds support for FSDP and FSDP w/ CPU offloading.
Introduces Accelerate as a distributed backend abstraction (for FSDP/DeepSpeed).
Also fixes the Mistral template and cleans up data processing.
-Mustafa
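For readers unfamiliar with the Accelerate-based FSDP path, a minimal sketch of how FSDP with CPU offload can be wired through Accelerate; the configuration shown is illustrative and may not match the PR's actual setup code:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, MixedPrecision

# Intended to run under `torchrun` or `accelerate launch`.
fsdp_plugin = FullyShardedDataParallelPlugin(
    mixed_precision_policy=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    cpu_offload=CPUOffload(offload_params=True),  # the "FSDP w/ CPU offloading" piece
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
# The model, optimizer, and scheduler are then passed through accelerator.prepare(...).
```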