Fix TE HF checkpoint saving #1280
Conversation
Can I load a TE checkpoint into a non-TE model?
LGTM, but also @dakinggg wdyt?
Force-pushed from fbf0967 to 72e1356
@mvpatel2000 loading from fp8 and training with bf16 seems to work with the test run example here: Curious what the use case is in which you would do that, though?
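For reference, a minimal sketch of the load path being discussed (the checkpoint path is a placeholder; this illustrates the general Hugging Face API, not the test run's exact config):

```python
import torch
from transformers import AutoModelForCausalLM

# Load a checkpoint that was saved from an FP8 (TransformerEngine) run,
# materializing the weights in bf16 for continued training.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/fp8-saved-checkpoint",  # placeholder path
    torch_dtype=torch.bfloat16,
)
```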
will review fully once CI passes
LGTM, but same comment on waiting for CI/CD to pass.
Description
Fixes the HF Checkpointer callback for TransformerEngine FP8 saving. This PR ensures we serialize the io.BytesIO extra_state tensors as regular tensors in save_pretrained so the code does not error.
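For context, a minimal sketch of the conversion described above (the helper name and its placement are assumptions, not the PR's exact code; TransformerEngine stores FP8 metadata in `*._extra_state` entries as io.BytesIO buffers, which save_pretrained cannot serialize):

```python
import io
import torch

def _serialize_extra_state(state_dict: dict) -> dict:
    """Hypothetical helper: rewrite io.BytesIO extra_state buffers
    (TransformerEngine FP8 metadata) as plain uint8 tensors so that
    save_pretrained can serialize the state dict without erroring."""
    for key, value in state_dict.items():
        if isinstance(value, io.BytesIO):
            value.seek(0)
            # Wrap the raw bytes in a writable buffer and copy them
            # into a regular tensor that the serializer accepts.
            state_dict[key] = torch.frombuffer(
                bytearray(value.getvalue()), dtype=torch.uint8
            ).clone()
    return state_dict
```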
Tests
failed-hf-checkpointer-fp8-llama3-8b-metamath-4ep-KOTaOP 🔴
success-hf-checkpointer-fp8-llama3-8b-metamath-4ep-yxNFTK ✅
Issues
Closes https://databricks.atlassian.net/browse/RGENAI-255