[float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes #1325
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1325
Note: Links to docs will display an error until the doc builds have completed.
✅ You can merge normally! (1 unrelated failure)
As of commit 6f4615b with merge base 1a0dbf1:
BROKEN TRUNK - The following job failed but was already failing on the merge base.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    and b_scale.shape == (1, b_data.shape[1])
    and not use_fast_accum
):
    # The rowwise CUTLASS-based kernel is so slow without fast-accum that
just curious, do we have any OSS shareable evidence (perf/accuracy) on doing this versus rowwise with fast-accum off that we can add here?
I ran a quick benchmark on my H100 with a recent-ish version of PyTorch (nightly from Nov 12). I sampled all MxNxK matmul shapes where each of M, N and K is a power of two between 512 and 16384. Here I'm plotting the slowdowns observed when activating slow-accum for the rowwise (CUTLASS-based) and tensorwise (cuBLAS-based) modes.
In summary: with tensorwise scaling the max slowdown is 50% (usually much less), while with rowwise scaling we are typically 2x slower, with peaks of 4.5x slower than fast-accum.
(I suspect that for very small shapes the benchmark was CPU-bound, hence slow-accum looks as fast as fast-accum, but that's probably misleading.)
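A minimal sketch of how such a comparison could be reproduced, assuming a recent PyTorch nightly where `torch._scaled_mm` accepts rowwise `(M, 1)`/`(1, N)` float32 scales; the shapes, timing helper, and dtypes below are illustrative and not the exact script behind the numbers above.

```python
# Hypothetical benchmark sketch (not the original script). Requires a CUDA GPU
# with fp8 support and a recent PyTorch nightly.
import torch

def bench(fn, iters=50):
    # Simple CUDA-event timer; very small shapes may end up CPU-bound here.
    for _ in range(5):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

M = N = K = 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major, as _scaled_mm requires

# Tensorwise (cuBLAS-backed): scalar scales. Rowwise (CUTLASS-backed): per-row/per-column scales.
configs = {
    "tensorwise": (torch.ones((), device="cuda"), torch.ones((), device="cuda")),
    "rowwise": (torch.ones(M, 1, device="cuda"), torch.ones(1, N, device="cuda")),
}

for name, (sa, sb) in configs.items():
    fast = bench(lambda: torch._scaled_mm(a, b, scale_a=sa, scale_b=sb,
                                          out_dtype=torch.bfloat16, use_fast_accum=True))
    slow = bench(lambda: torch._scaled_mm(a, b, scale_a=sa, scale_b=sb,
                                          out_dtype=torch.bfloat16, use_fast_accum=False))
    print(f"{name}: slow-accum / fast-accum = {slow / fast:.2f}x")
```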
Landing since Ruff is already broken on main
Superseded by #1377
Stack from ghstack (oldest at bottom):
This circumvents the issue with the slow CUTLASS kernel by using the cuBLAS kernel plus manual scaling.
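A minimal sketch of what "cuBLAS kernel + manual scaling" could look like, assuming the dispatch condition shown in the diff above: when rowwise scales are detected and fast-accum is off, call the tensorwise (cuBLAS) kernel with unit scales and apply the row/column scales to the high-precision output afterwards. The function and variable names are illustrative and do not mirror torchao's actual internals.

```python
# Illustrative sketch only; not the actual torchao implementation.
import torch

def scaled_mm_rowwise_slow_accum(a_data, b_data, a_scale, b_scale, out_dtype=torch.bfloat16):
    """Rowwise-scaled fp8 matmul without fast-accum.

    Instead of the slow-accum rowwise CUTLASS kernel, run the tensorwise
    cuBLAS kernel with unit scales and apply the (M, 1) and (1, N) dequant
    scales manually:
        out[m, n] = a_scale[m] * b_scale[n] * sum_k a[m, k] * b[k, n]
    """
    assert a_scale.shape == (a_data.shape[0], 1)
    assert b_scale.shape == (1, b_data.shape[1])
    one = torch.ones((), device=a_data.device, dtype=torch.float32)
    out = torch._scaled_mm(
        a_data, b_data,
        scale_a=one, scale_b=one,   # unit scalar scales -> tensorwise cuBLAS path
        out_dtype=out_dtype,
        use_fast_accum=False,       # slow-accum is cheap on the cuBLAS path
    )
    # Apply the per-row / per-column scales in float32, then cast back.
    return (out.to(torch.float32) * a_scale * b_scale).to(out_dtype)
```

Compared with forcing fast-accum, this keeps the slower but more accurate accumulation while avoiding the 2x-4.5x slowdown observed for the rowwise CUTLASS kernel in the benchmark above.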