Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8 #1469

Open
y-sq wants to merge 1 commit into main
Conversation

@y-sq (Contributor) commented Jan 2, 2025

Summary:
The original PR was in inductor: pytorch/pytorch#143464

As suggested in the comments there, we moved the patterns to the torchao repo.
Note that this needs the pytorch repo to contain pytorch/pytorch#139321 in order to work.

Test Plan:
Ran the float8 training script. Amax and cast are fused for delayed scaling; dynamic scaling is not affected. Also tested with preceding ops (layer_norm, sigmoid).

The delayed scaling kernel also looks reasonable to me: https://fburl.com/phabricator/iqmlollk

A simple performance test: D67517255
(It needs some temporary changes to force recomputation of the weight and activation in the backward pass.)

  • I used the "recompute weight + activation in backward" case for testing, as it will likely become the default choice.
  • "recompute weight" is enabled by "force_recompute_fp8_weight_in_bwd=enable_activation_checkpointing". It might later be changed to other options, so I made the changes in a local testing script for now.
TORCH_LOGS="fusion" TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1  buck run @mode/opt scripts/shuqiyang/test_inductor:test_float8 --  ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed --enable_activation_checkpointing True  2>&1 | tee ~/test_compile_b.txt

Linear - 4096x4096x4096

  • Before: 0.7000332674418591 ms
  • After: 0.6707110309278329 ms
  • Delta: 4.4%

With layer_norm:

  • Before: 0.753488252631584 ms
  • After: 0.7374823118279574 ms
  • Delta: 2.2%

With sigmoid:

  • Before: 0.7153208260869579 ms
  • After: 0.6845765714285723 ms
  • Delta: 4.5%

Differential Revision: D67758184

@pytorch-bot commented Jan 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1469


✅ No Failures

As of commit a3b1245 with merge base f7f20e9:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot added the CLA Signed label Jan 2, 2025
@facebook-github-bot commented Jan 2, 2025

This pull request was exported from Phabricator. Differential Revision: D67758184

@y-sq (Contributor, Author) commented Jan 2, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot added the topic: not user facing label Jan 2, 2025
@y-sq y-sq requested review from eellison and drisspg January 2, 2025 22:13
@y-sq y-sq requested a review from vkuzo January 2, 2025 22:26
@y-sq (Contributor, Author) commented Jan 2, 2025

The failed tests in Run Regression Tests / test-nightly (CUDA Nightly, linux.g5.12xlarge.nvidia.gpu, --pre torch) are from non-float8 folders, so they seem unrelated to this PR?

[Screenshot of the failing nightly tests]

@@ -36,6 +36,9 @@
WeightWithDynamicFloat8CastTensor,
WeightWithStaticFloat8CastTensor,
)
from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns

register_fp8_delayed_scaling_patterns()
Contributor

I would prefer that we have an explicit API to enable this instead of unexpectedly registering a call on module import. How about something like below?

def _prototype_register_float8_delayed_scaling_inductor_passes():
    """
    Note: this is a prototype API and is subject to change.
    TODO writeme
    """
    ...

Then, we can mention this in the delayed scaling section of README.md
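
To illustrate, a minimal usage sketch of such an explicit opt-in API; the import path here and the convention of calling it before torch.compile are assumptions about the eventual design, not the final API:

import torch
import torch.nn as nn

# Hypothetical import path for the prototype API suggested above.
from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes

# Opt in explicitly, once, before compiling the model.
_prototype_register_float8_delayed_scaling_inductor_passes()

model = torch.compile(nn.Linear(4096, 4096).cuda())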

def register_fp8_delayed_scaling_patterns() -> bool:
    # To make the fp8 delayed scaling pattern work, we need a fix PR from inductor: https://github.com/pytorch/pytorch/pull/139321
    # Added the try-except block to ignore the failed pattern if the current torch version doesn't include the fix PR.
    try:
Contributor

can this check for something specific instead of a blanket try-catch?

@y-sq (Contributor, Author) Jan 3, 2025

I modified it to explicitly log the error and ask the user to update their pytorch version if the duplicate-pattern error is caught. Calling register_fp8_delayed_scaling_patterns without the inductor fix PR will now produce:

Caught duplicated patterns in register_fp8_delayed_scaling_patterns: Traceback (most recent call last):
  File "/data/users/shuqiyang/fbsource/buck-out/v2/gen/fbcode/d22fb8d26889267e/scripts/shuqiyang/test_inductor/__test_float8__/test_float8#link-tree/torchao/float8/inductor_utils.py", line 100, in register_fp8_delayed_scaling_patterns
    register_fp8_delayed_scaling_patterns_inner()
  File "/data/users/shuqiyang/fbsource/buck-out/v2/gen/fbcode/d22fb8d26889267e/scripts/shuqiyang/test_inductor/__test_float8__/test_float8#link-tree/torchao/float8/inductor_utils.py", line 79, in register_fp8_delayed_scaling_patterns_inner
    register_replacement(
  File "/data/users/shuqiyang/fbsource/buck-out/v2/gen/fbcode/d22fb8d26889267e/scripts/shuqiyang/test_inductor/__test_float8__/test_float8#link-tree/torch/_inductor/pattern_matcher.py", line 1398, in register_replacement
    assert pattern_repr not in _seen_patterns
AssertionError
 
Please update your pytorch dependency to the latest main branch to fix it.

I didn't find a good way to check whether a specific commit is in the pytorch package. If the commit is later released with a torch version, I can add the torch version check here instead of checking the exception message.
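
As a rough sketch of the check described above (the inner helper name comes from the traceback; the log message and control flow are assumptions, not the exact code in this PR):

import logging
import traceback

logger = logging.getLogger(__name__)


def register_fp8_delayed_scaling_patterns_inner() -> None:
    # Stub for the inner helper shown in the traceback; the real code calls
    # torch._inductor.pattern_matcher.register_replacement for each pattern.
    ...


def register_fp8_delayed_scaling_patterns() -> bool:
    try:
        register_fp8_delayed_scaling_patterns_inner()
        return True
    except AssertionError:
        # Without pytorch/pytorch#139321, register_replacement asserts on
        # duplicated patterns; surface a hint instead of failing opaquely.
        logger.warning(
            "Caught duplicated patterns in register_fp8_delayed_scaling_patterns: %s\n"
            "Please update your pytorch dependency to the latest main branch to fix it.",
            traceback.format_exc(),
        )
        return False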

)


def run_once(f):
Contributor

can delete this if you change the API to a function the user has to call

import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
Contributor

can we also handle torch.float8_e4m3fnuz and torch.float8_e5m2fnuz
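
For reference, a sketch of extending the constants above to the fnuz variants (assuming the installed torch build exposes these dtypes):

import torch

# Max representable positive value for each float8 dtype.
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max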



@run_once
def register_fp8_delayed_scaling_patterns() -> bool:
Contributor

can we add a docblock describing that this replaces max(x) with max(max(x, dim=-1)) for cases relevant to delayed scaling
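
For intuition, a minimal sketch of the kind of search/replacement pair such a pass registers (illustrative only; the real patterns in this PR also cover the cast and clamp that follow the amax in delayed scaling):

import torch


def amax_pattern(x: torch.Tensor) -> torch.Tensor:
    # Full reduction over the whole tensor in one step; inductor has
    # trouble fusing this with the float8 cast kernel.
    return torch.max(torch.abs(x))


def amax_replacement(x: torch.Tensor) -> torch.Tensor:
    # Two-stage reduction: max over the last dim, then max over the
    # resulting row maxima, which fuses better with the surrounding
    # delayed-scaling ops.
    return torch.max(torch.max(torch.abs(x), dim=-1).values)

# These would be wired up via torch._inductor.pattern_matcher.register_replacement,
# as this PR does.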

@y-sq (Contributor, Author) commented Jan 3, 2025

@vkuzo thanks for your comments. I updated the PR to use an explicit API; added the README section and docblock; added the two additional fp8 dtypes; and added an explicit check with a hint to update the pytorch version for the inductor fix PR.

@y-sq (Contributor, Author) commented Jan 4, 2025

With the pattern replacement:

TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 TORCH_LOGS="fusion"  python benchmarks/float8/profile_linear_float8.py ~/tmp/log_delayed_scaling_3 --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed --enable_activation_checkpointing True --add_inductor_metadata_to_trace False --enable_float8_delayed_scaling_inductor_passes True
Summary of time (ms) by kernel category

 experiment     1_float8
category               
0_gemm            0.353
1_f8_overhead     0.161
2_other           0.028
All               0.542

Baseline (same command with --enable_float8_delayed_scaling_inductor_passes False):
Summary of time (ms) by kernel category

 experiment     1_float8
category               
0_gemm            0.349
1_f8_overhead     0.198
2_other           0.022
All               0.570

We get a ~5% speedup from --enable_float8_delayed_scaling_inductor_passes True (0.570 ms → 0.542 ms total).

@y-sq y-sq requested a review from vkuzo January 4, 2025 07:32
Labels: CLA Signed, fb-exported, topic: not user facing