add axiswise scaling to Float8Linear #920
Conversation
Summary:

This PR: support scaling of all arguments of all gemms to be axiswise, and ensure that training with axiswise scaling works e2e.

Future PR: support more granular configurability, optimize performance, and add docs.

Feel free to ignore the UX introduced in this PR, it's just an intermediate step. See the next PR for the real UX.

Test Plan:

```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane; performance isn't there yet
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```
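For readers unfamiliar with the distinction, here is a minimal sketch of tensorwise vs axiswise float8 scale computation. This is illustrative only: the function names, epsilon value, and shapes are made up and are not the torchao Float8Linear API.

```python
import torch

# Illustrative only: not the torchao Float8Linear API. Shows the conceptual
# difference between tensorwise and axiswise float8 scale computation.
FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max
EPS = 1e-12  # arbitrary epsilon to avoid division by zero

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # a single scale for the whole tensor
    amax = x.abs().amax()
    return FP8_MAX / torch.clamp(amax, min=EPS)

def axiswise_scale(x: torch.Tensor, dim: int) -> torch.Tensor:
    # one scale per slice along `dim` (e.g. per row of an activation matrix)
    amax = x.abs().amax(dim=dim, keepdim=True)
    return FP8_MAX / torch.clamp(amax, min=EPS)

x = torch.randn(16, 32)
x_fp8_tensorwise = (x * tensorwise_scale(x)).to(FP8_DTYPE)
x_fp8_axiswise = (x * axiswise_scale(x, dim=-1)).to(FP8_DTYPE)
```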
@@ -39,6 +44,7 @@
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend

# TODO(future PR): standardize IS_H100 with the rest of the codebase
What we want is SM89 for testing cublas matmuls, since in theory we have hardware with that capability on CI/CD.
Yeah, I meant the variable name, not what it's checking.
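For context, a capability check along those lines could look like the following sketch; the variable name is hypothetical, not the one used in the codebase:

```python
import torch

# Hypothetical variable name; SM 8.9 (Ada) is the minimum capability for
# cuBLAS float8 matmuls.
IS_SM89_OR_LATER = (
    torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
)
```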
# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)

else:
Why do we need a separate path for this?
This is choosing to move fast on a single use case, at the expense of taking on some temporary tech debt.
LGTM! Left some minor readability nits
and other granularities in a separate PR.
"""

# TODO(this PR): types of inputs
Done, right?
# the reshapes are needed in order to make the shapes compatible with
# torch.mm
orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
Is this equivalent to input_fp8.flatten(0, -2)? If so, I find this more self-descriptive.
orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
Conversely, res_bits.unflatten(0, orig_shape[:-1]).
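To make the suggested equivalence concrete, here is a small self-contained sketch comparing both formulations; the shapes and placeholder tensors are illustrative and not taken from the PR:

```python
import torch

# Stand-ins for the fp8 activation and transposed weight in the PR.
input_fp8 = torch.randn(2, 3, 8)
weight_fp8_t = torch.randn(8, 16)

orig_shape = input_fp8.shape
a = input_fp8.reshape(-1, orig_shape[-1])  # version used in the PR
b = input_fp8.flatten(0, -2)               # suggested equivalent
assert torch.equal(a, b)

res_bits = torch.mm(a, weight_fp8_t)
out_a = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])  # PR version
out_b = res_bits.unflatten(0, orig_shape[:-1])                  # suggested equivalent
assert torch.equal(out_a, out_b)
```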
# the reshapes are needed in order to make the shapes compatible with
# torch.mm
grad_output_orig_shape = grad_output.shape
grad_output_reshaped = grad_output.reshape(
Same here and below