Support resuming of deepspeed + Lora + offloading #29015
Changes from 4 commits: 4acce7a, c0d12fc, 84a8867, fd47c33, 3d62791
`src/transformers/trainer.py`, hunk `@@ -1734,6 +1734,15 @@ def _inner_training_loop(`:

```python
                )
            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

        # deepspeed ckpt loading
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
            if self.args.deepspeed_force_lr_scheduler_checkpointing and self.model_wrapped.lr_scheduler is None:
                if os.path.isfile(os.path.join(resume_from_checkpoint, SCHEDULER_NAME)):
                    with warnings.catch_warnings(record=True) as caught_warnings:
                        self.lr_scheduler.load_state_dict(torch.load(os.path.join(resume_from_checkpoint, SCHEDULER_NAME)))
                    reissue_pt_warnings(caught_warnings)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)
```

Review comment on the scheduler-loading lines: "loading scheduler is handled in …"
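For readers following along, here is a minimal standalone sketch of the resume-side fallback the hunk above implements: when the DeepSpeed engine is not tracking the scheduler (its `lr_scheduler` is `None`), the scheduler state is reloaded from the `scheduler.pt` file saved next to the checkpoint. The helper name and its arguments are illustrative, not part of the PR.

```python
import os
import warnings

import torch

SCHEDULER_NAME = "scheduler.pt"  # file name the Trainer uses for scheduler state


def maybe_restore_lr_scheduler(engine, lr_scheduler, checkpoint_dir):
    """Illustrative helper: reload scheduler state that DeepSpeed did not checkpoint.

    `engine` is assumed to be a DeepSpeed engine; `lr_scheduler` is the scheduler
    created by the training loop.
    """
    if engine.lr_scheduler is not None:
        # DeepSpeed owns the scheduler, so its state came back with the engine checkpoint.
        return

    scheduler_path = os.path.join(checkpoint_dir, SCHEDULER_NAME)
    if not os.path.isfile(scheduler_path):
        # Nothing was saved separately; keep the freshly created scheduler as-is.
        return

    # Collect load-time warnings the same way the Trainer does, instead of letting
    # them spam the log on every resume.
    with warnings.catch_warnings(record=True):
        lr_scheduler.load_state_dict(torch.load(scheduler_path))
```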
`src/transformers/trainer.py`, hunk `@@ -2416,6 +2425,12 @@ def _save_checkpoint(self, model, trial, metrics=None)`:

```python
        else:
            staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
        self.save_model(staging_output_dir, _internal_call=True)
        if self.is_deepspeed_enabled:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(staging_output_dir)
            if self.args.deepspeed_force_lr_scheduler_checkpointing and self.model_wrapped.lr_scheduler is None:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(staging_output_dir, SCHEDULER_NAME))

        if not self.args.save_only_model:
            # Save optimizer and scheduler
```

Review comment on the DeepSpeed `save_checkpoint` call: "this happens in …"
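The save side can be pictured with a similarly hedged sketch, assuming a DeepSpeed engine and a separately constructed scheduler; the helper name and arguments are made up for illustration:

```python
import os

import torch

SCHEDULER_NAME = "scheduler.pt"


def save_deepspeed_checkpoint_with_scheduler(engine, lr_scheduler, output_dir):
    """Illustrative mirror of the save-side change above."""
    # DeepSpeed writes its own sharded model/optimizer state; under ZeRO-3 a
    # consolidated model file only appears if `stage3_gather_16bit_weights_on_model_save`
    # is enabled in the DeepSpeed config.
    engine.save_checkpoint(output_dir)

    # If DeepSpeed is not managing the scheduler, persist its state alongside the
    # checkpoint so the resume-side fallback can find it later.
    if engine.lr_scheduler is None and lr_scheduler is not None:
        torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
```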
`src/transformers/training_args.py`, hunk `@@ -1316,6 +1316,18 @@ class TrainingArguments`:

```python
            "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instrcution fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes."
        },
    )

    deepspeed_force_lr_scheduler_checkpointing: bool = field(
        default=False,
        metadata={
            "help": (
                "Force saving and loading or checkpointing the lr_scheduler when deepspeed is enabled and it does not "
                "support the lr_scheduler type. "
                "Use this to force keeping track of lr_scheduler when the model lr_scheduler type does not fall into "
                "its supported lr_scheduler categories."
            )
        },
    )

    def __post_init__(self):
        # expand paths, if not os.makedirs("~/bar") will make directory
```
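If the field lands as proposed, turning it on would look like any other `TrainingArguments` flag. A usage sketch with placeholder paths and a placeholder DeepSpeed config; the flag itself may still change per the review below:

```python
from transformers import TrainingArguments

# Placeholder values; `deepspeed_force_lr_scheduler_checkpointing` is the field
# proposed in this PR and is not part of released transformers.
args = TrainingArguments(
    output_dir="out",
    deepspeed="ds_config_zero3.json",
    deepspeed_force_lr_scheduler_checkpointing=True,
    save_steps=500,
)
```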
Review comment: can we somehow revert this and just force-set it to `True` in our trainer?

Reply: That's a fair point! I'll push a change now.
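A possible shape of that follow-up, purely as a sketch of the reviewer's suggestion (decide inside the Trainer instead of exposing a user-facing flag); the helper and the condition it checks are assumptions, not the actual pushed change:

```python
def _scheduler_needs_manual_checkpointing(deepspeed_engine) -> bool:
    # Hypothetical internal check: if DeepSpeed did not take ownership of the
    # scheduler (e.g. its type is not one DeepSpeed can save/restore), the
    # Trainer has to save and reload the scheduler state itself.
    return deepspeed_engine.lr_scheduler is None
```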