-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
float8 delayed scaling: private API to fix user overriding buffers #1292
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1292
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit faa1593 with merge base 56bf2e8 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Summary: Context: pytorch/torchtitan#654 If the user has delayed scaling and FSDP float8 all-gather on, there is a subtle bug that can happen if the user calls `model.to_empty(device="cuda")`: 1. to_empty recreates the buffers for tracking weight amax and scale 2. (1) leaves the buffers pointed to by Float8Linear.weight._amax_buffer, etc orphaned, because they don't participate in `to_empty` I couldn't think of an easy and clean way to auto-fix this since we can't expect `torch.nn.Module` to know that our logic has multiple references to the same buffer, so exposing a private API for now until we can think of something better. With the current fix, the user can then call `_maybe_fixup_delayed_scaling_buffers` manually to relink the buffers to the correct new versions. Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
9457814
to
faa1593
Compare
assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer | ||
|
||
m_fp8.to_empty(device="cuda") | ||
m_fp8[0]._maybe_fixup_delayed_scaling_buffers() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we would need to call this inside torchtitan’s training loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, which is definitely not ideal
Summary:
Context: pytorch/torchtitan#654
If the user has delayed scaling and FSDP float8 all-gather on, there is a subtle bug that can happen if the user calls
model.to_empty(device="cuda")
:to_empty
I couldn't think of an easy and clean way to auto-fix this since we can't expect
torch.nn.Module
to know that our logic has multiple references to the same buffer, so exposing a private API for now until we can think of something better.With the current fix, the user can then call
_maybe_fixup_delayed_scaling_buffers
manually to relink the buffers to the correct new versions.Test Plan: CI
Reviewers:
Subscribers:
Tasks:
Tags: