
[float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes #1325

Merged · 3 commits into gh/lw/1/base · Dec 4, 2024

Conversation

@lw (Contributor) commented Nov 22, 2024

Stack from ghstack (oldest at bottom):

This circumvents the issue with the slow CUTLASS kernel by using the cuBLAS kernel plus manual scaling.
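As a rough illustration of the "cuBLAS kernel + manual scaling" idea, here is a minimal sketch under assumed shapes (`rowwise_mm_via_cublas` is a hypothetical helper, not the PR's actual code): the matmul is run through `torch._scaled_mm` with trivial tensorwise scales, which dispatches to the cuBLAS kernel, and the axis-wise scales are applied to the output afterwards.

```python
import torch

def rowwise_mm_via_cublas(a_data, b_data, a_scale, b_scale, out_dtype=torch.bfloat16):
    # a_data: (M, K) float8, with per-row scales a_scale of shape (M, 1)
    # b_data: (K, N) float8 (column-major, as _scaled_mm requires),
    #         with per-column scales b_scale of shape (1, N)
    one = torch.ones(1, device=a_data.device)  # trivial tensorwise scale
    # Tensorwise scaling dispatches to the cuBLAS kernel, where slow-accum
    # is far cheaper than in the rowwise CUTLASS kernel.
    out = torch._scaled_mm(
        a_data,
        b_data,
        scale_a=one,
        scale_b=one,
        out_dtype=torch.float32,  # keep full precision for the manual scaling
        use_fast_accum=False,
    )
    # Apply the row-/column-wise scales manually, then cast down.
    return (out * a_scale * b_scale).to(out_dtype)
```

The trade-off is one extra high-precision elementwise multiply over the output in exchange for avoiding the slow CUTLASS accumulation path.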

@pytorch-bot (bot) commented Nov 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1325

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6f4615b with merge base 1a0dbf1:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the `CLA Signed` label Nov 22, 2024
@lw added the `topic: performance` label Nov 22, 2024
@lw requested a review from vkuzo on November 26, 2024
The review thread below refers to this excerpt from the diff:

```python
    and b_scale.shape == (1, b_data.shape[1])
    and not use_fast_accum
):
    # The rowwise CUTLASS-based kernel is so slow without fast-accum that
    # it's better to use the tensorwise cuBLAS-based kernel and apply the
    # scales manually afterwards.
```
@vkuzo (Contributor) commented on this excerpt:

Just curious: do we have any OSS-shareable evidence (perf/accuracy) on doing this versus rowwise with fast-accum off that we can add here?

@lw (Contributor Author) replied:

[Screenshot: benchmark plot of the slowdown from enabling slow-accum, for the rowwise (CUTLASS-based) vs tensorwise (cuBLAS-based) modes across matmul shapes]

I ran a quick benchmark on my H100 with a recent-ish version of PyTorch (nightly from Nov 12). I sampled all MxNxK matmul shapes where each of M, N, and K is a power of two between 512 and 16384. Here I'm plotting the slowdowns observed when enabling slow-accum for the rowwise (CUTLASS-based) and tensorwise (cuBLAS-based) modes.

In summary: with tensorwise we see a maximum slowdown of 50% (usually much less); with rowwise we are typically 2x as slow, with peaks of 4.5x slower than fast-accum.

(I suspect that for very small shapes the benchmark was CPU-bound, hence slow-accum looks as fast as fast-accum there, but that's probably misleading.)
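For reference, a shape sweep like the one described could look roughly like the sketch below (assumed methodology and a hypothetical script, not the author's actual benchmark); it times the rowwise mode with fast-accum on and off, and the tensorwise variant would swap in 1-element scales:

```python
import itertools
import torch
from torch.utils import benchmark

sizes = [2**i for i in range(9, 15)]  # 512 ... 16384
for M, N, K in itertools.product(sizes, repeat=3):
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    # _scaled_mm expects the second operand to be column-major
    b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn).t().contiguous().t()
    scale_a = torch.ones(M, 1, device="cuda")  # rowwise scales
    scale_b = torch.ones(1, N, device="cuda")
    times = {}
    for fast in (True, False):
        timer = benchmark.Timer(
            stmt="torch._scaled_mm(a, b, scale_a=sa, scale_b=sb, "
                 "out_dtype=torch.bfloat16, use_fast_accum=fast)",
            globals={"torch": torch, "a": a, "b": b,
                     "sa": scale_a, "sb": scale_b, "fast": fast},
        )
        times[fast] = timer.timeit(20).mean
    print(f"{M}x{N}x{K}: slow-accum is {times[False] / times[True]:.2f}x slower")
```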

@lw marked this pull request as ready for review on December 4, 2024
@lw (Contributor Author) commented Dec 4, 2024:

Landing since Ruff is already broken on main

@lw merged commit 6f4615b into gh/lw/1/base on Dec 4, 2024 · 29 of 31 checks passed
@lw (Contributor Author) commented Dec 4, 2024:

Superseded by #1377
