Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8 #1469
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1469. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit a3b1245 with merge base f7f20e9. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D67758184
@pytorchbot label "topic: not user facing"
@@ -36,6 +36,9 @@
     WeightWithDynamicFloat8CastTensor,
     WeightWithStaticFloat8CastTensor,
 )
+from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns
+
+register_fp8_delayed_scaling_patterns()
I would prefer that we have an explicit API to enable this instead of unexpectedly registering a call on module import. How about something like below?

def _prototype_register_float8_delayed_scaling_inductor_passes():
    """
    Note: this is a prototype API and is subject to change.
    TODO writeme
    """
    ...

Then, we can mention this in the delayed scaling section of README.md.
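For context, a hypothetical usage sketch under the explicit API proposed above; the function name comes from the suggestion, but the import path and the surrounding setup are assumptions for illustration, not the merged code:

```python
import torch
import torch.nn as nn

# Assumed import path for illustration only; the final location may differ.
from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes

# Opt in explicitly, instead of the patterns being registered on module import.
_prototype_register_float8_delayed_scaling_inductor_passes()

model = nn.Linear(4096, 4096).cuda()
# ... convert the model to float8 training with delayed scaling as usual ...
model = torch.compile(model)
```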
def register_fp8_delayed_scaling_patterns() -> bool:
    # To make the fp8 delayed scaling pattern work, we need a fix PR from inductor, https://github.com/pytorch/pytorch/pull/139321
    # Added the try-catch block to ignore the failed pattern if the current torch version doesn't include the fix PR.
    try:
can this check for something specific instead of a blanket try-catch?
I modified the code to explicitly log the error and ask the user to update their PyTorch version when the duplicate-pattern error is caught. Calling register_fp8_delayed_scaling_patterns without the inductor fix PR will now produce:
Caught duplicated patterns in register_fp8_delayed_scaling_patterns: Traceback (most recent call last):
File "/data/users/shuqiyang/fbsource/buck-out/v2/gen/fbcode/d22fb8d26889267e/scripts/shuqiyang/test_inductor/__test_float8__/test_float8#link-tree/torchao/float8/inductor_utils.py", line 100, in register_fp8_delayed_scaling_patterns
register_fp8_delayed_scaling_patterns_inner()
File "/data/users/shuqiyang/fbsource/buck-out/v2/gen/fbcode/d22fb8d26889267e/scripts/shuqiyang/test_inductor/__test_float8__/test_float8#link-tree/torchao/float8/inductor_utils.py", line 79, in register_fp8_delayed_scaling_patterns_inner
register_replacement(
File "/data/users/shuqiyang/fbsource/buck-out/v2/gen/fbcode/d22fb8d26889267e/scripts/shuqiyang/test_inductor/__test_float8__/test_float8#link-tree/torch/_inductor/pattern_matcher.py", line 1398, in register_replacement
assert pattern_repr not in _seen_patterns
AssertionError
Please update your pytorch dependency to the latest main branch to fix it.
I didn't find a good way to check whether a specific commit is present in the installed pytorch package. If the commit is later included in a released torch version, I can add a torch version check here instead of checking the exception message.
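A minimal sketch of the error handling described above (illustrative; the PR's actual code may differ). It assumes an inner helper `register_fp8_delayed_scaling_patterns_inner` that performs the `register_replacement` calls, as seen in the traceback:

```python
import logging
import traceback

logger = logging.getLogger(__name__)


def register_fp8_delayed_scaling_patterns() -> bool:
    try:
        # Performs the torch._inductor.pattern_matcher.register_replacement calls;
        # assumed to be defined elsewhere in inductor_utils.py (per the traceback above).
        register_fp8_delayed_scaling_patterns_inner()
    except AssertionError:
        # On older PyTorch builds without pytorch/pytorch#139321, register_replacement
        # hits `assert pattern_repr not in _seen_patterns`; surface a hint instead of failing.
        logger.warning(
            "Caught duplicated patterns in register_fp8_delayed_scaling_patterns: %s\n"
            "Please update your pytorch dependency to the latest main branch to fix it.",
            traceback.format_exc(),
        )
        return False
    return True
```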
torchao/float8/inductor_utils.py
)


def run_once(f):
can delete this if you change the API to a function the user has to call
torchao/float8/inductor_utils.py
import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
can we also handle torch.float8_e4m3fnuz and torch.float8_e5m2fnuz?
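A minimal sketch of how the fnuz variants could be covered alongside the existing constants (the constant names are illustrative, not necessarily the PR's final code):

```python
import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
# fnuz variants (used on ROCm), added per the review comment above.
E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
```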
torchao/float8/inductor_utils.py
@run_once
def register_fp8_delayed_scaling_patterns() -> bool:
can we add a docblock describing that this replaces max(x) with max(max(x, dim=-1)) for cases relevant to delayed scaling?
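One possible wording for the requested docblock, as a sketch (not the PR's final text):

```python
@run_once
def register_fp8_delayed_scaling_patterns() -> bool:
    """
    Register Inductor pattern replacements for float8 delayed scaling.

    The key rewrite replaces the single full reduction ``max(x)`` with the
    two-stage reduction ``max(max(x, dim=-1))`` in the amax computation, so
    that the reduction can fuse with the surrounding float8 cast kernel.
    Dynamic scaling is not affected.

    Returns True if the patterns were registered successfully.
    """
    ...
```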
@vkuzo thanks for your comments. I updated the PR to use an explicit API; added the README entry and docblock; added the two additional fp8 dtypes; and added an explicit check plus a hint to update the PyTorch version for the inductor fix PR.
With the pattern replacement vs. baseline (kernel trace screenshots omitted here), we can get about a 5% speedup.
Summary:
The original PR was in inductor: pytorch/pytorch#143464. As suggested in the comments there, we moved the patterns to the torchao repo.
Note that this needs the pytorch repo to contain pytorch/pytorch#139321 to work.
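For readers unfamiliar with `register_replacement`, here is a conceptual sketch of the kind of rewrite being registered. The pattern shape below is simplified and an assumption for illustration; the real patterns live in torchao/float8/inductor_utils.py and are registered via torch._inductor.pattern_matcher.register_replacement.

```python
import torch


def amax_search_pattern(x: torch.Tensor) -> torch.Tensor:
    # What delayed scaling computes today: a single full reduction,
    # which Inductor has trouble fusing with the float8 cast kernel.
    return torch.max(torch.abs(x))


def amax_replacement(x: torch.Tensor) -> torch.Tensor:
    # Equivalent two-stage reduction: reduce over the last dim first,
    # then reduce the per-row partial maxima; this shape fuses with the cast.
    return torch.amax(torch.abs(x), dim=-1).max()
```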
Test Plans:
Run the float8 training script. Amax and cast are fused in delayed scaling; dynamic scaling is not affected. Also tested with preceding ops.
The delayed scaling kernel also looks reasonable to me: https://fburl.com/phabricator/iqmlollk
A simple performance test: D67517255
(It needs some temporary changes to force recomputing the weight and activation in the backward pass.)
| Configuration | Baseline (ms) | With pattern replacement (ms) | Speedup |
|---|---|---|---|
| Linear - 4096x4096x4096 | 0.7000332674418591 | 0.6707110309278329 | 4.4% |
| With layer_norm | 0.753488252631584 | 0.7374823118279574 | 2.2% |
| With sigmoid | 0.7153208260869579 | 0.6845765714285723 | 4.5% |
Differential Revision: D67758184