add fused transpose and non-transpose kernel and use it for grad output #1497

Merged
danielvegamyhre merged 33 commits into main on Jan 8, 2025

Conversation

@danielvegamyhre (Contributor) commented on Jan 3, 2025

Stack Summary
The following stack of PRs completes a float8 training prototype with performance that slightly beats the production Float8Linear + torch.compile approach.

Changes in this PR
This PR implements a new kernel which reads in a row-major high-precision tensor and writes two outputs: an fp8 tensor in row-major format, and the fp8 transpose of that tensor, also in row-major format. This is useful for the backward pass, where grad_output must be converted to both formats.
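As a rough illustration only (not the actual Triton kernel added in this PR), the eager-mode sketch below shows the semantics the fused kernel is expected to implement. The function name, the choice of torch.float8_e5m2, and the tensorwise max-based scaling recipe are assumptions for illustration:

```python
import torch

def to_fp8_row_major_t_and_non_t(hp_tensor: torch.Tensor,
                                 float8_dtype=torch.float8_e5m2):
    """Reference (eager) version of the fused kernel's two outputs.

    The fused Triton kernel produces both outputs in a single pass over the
    input; this sketch uses separate ops purely to show the expected results.
    """
    # Tensorwise scale derived from the global amax (the kernel computes the
    # amax with either an atomic-max or a reduction based algorithm).
    amax = hp_tensor.abs().max().to(torch.float64)
    fp8_max = torch.finfo(float8_dtype).max
    scale = (fp8_max / torch.clamp(amax, min=1e-12)).to(torch.float32)

    # Scale into the representable fp8 range before casting.
    scaled = (hp_tensor.to(torch.float32) * scale).clamp(-fp8_max, fp8_max)

    fp8_row_major = scaled.to(float8_dtype)                     # row-major fp8
    fp8_row_major_t = scaled.t().contiguous().to(float8_dtype)  # row-major fp8 transpose
    return fp8_row_major, fp8_row_major_t
```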

Next steps:

  1. Add support for activation checkpointing
  2. Add test verifying this prototype is compatible with FSDP
  3. Integrate this prototype into torchtitan and benchmark the results

Test Plan

  1. pytest kernels/ - kernel-specific unit tests are passing (a sketch of one such test is shown below)
  2. pytest test/ - the end-to-end training test is passing
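A sketch of what one kernel-level correctness check could look like, assuming the illustrative to_fp8_row_major_t_and_non_t helper above (the actual tests under kernels/ may be organized differently):

```python
import pytest
import torch

@pytest.mark.parametrize("shape", [(16, 4096), (256, 4096), (4096, 4096)])
def test_fused_t_and_non_t_outputs_match(shape):
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA")
    x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")

    # Swap in the Triton kernel entry point here; the reference helper is used
    # only so this sketch has something concrete to call.
    fp8, fp8_t = to_fp8_row_major_t_and_non_t(x)

    # The transposed output should be the row-major transpose of the
    # non-transposed output, with identical fp8 values.
    assert fp8_t.shape == (shape[1], shape[0])
    assert fp8_t.is_contiguous()
    torch.testing.assert_close(
        fp8_t.to(torch.float32), fp8.to(torch.float32).t(), atol=0, rtol=0
    )
```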

Performance Benchmarking

Performance benchmarks show this implementation beating torch.compile by 1.72-4.45%, depending on the input tensor size, when using the ATOMIC_MAX kernel algorithm; the REDUCTION variant is slower than torch.compile. A sketch of a benchmarking harness is included after the table.

input_shape    kernel_algo                 high_precision_dtype      eager_time    compiled_time    float8nocompile
-------------  --------------------------  ----------------------  ------------  ---------------  -----------------
(16, 4096)     KernelAlgorithm.ATOMIC_MAX  torch.bfloat16               649.218          394.725            386.469
(256, 4096)    KernelAlgorithm.ATOMIC_MAX  torch.bfloat16               685.783          420.743            408.137
(4096, 4096)   KernelAlgorithm.ATOMIC_MAX  torch.bfloat16              1829.13          1053.64             977.858
(65536, 4096)  KernelAlgorithm.ATOMIC_MAX  torch.bfloat16             21554.2          12369.7            10813.3
(16, 4096)     KernelAlgorithm.REDUCTION   torch.bfloat16               650.026          394.951            696.221
(256, 4096)    KernelAlgorithm.REDUCTION   torch.bfloat16               684.865          421.144            729.459
(4096, 4096)   KernelAlgorithm.REDUCTION   torch.bfloat16              1826.42          1050.85            1596.12
(65536, 4096)  KernelAlgorithm.REDUCTION   torch.bfloat16             21584.7          12347.2            17290
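For context, a minimal sketch of how per-shape timings like those above could be collected (the PR's actual benchmark harness, kernel entry points, and time units may differ; the helper below assumes the illustrative conversion function sketched earlier):

```python
import torch
import triton

def bench_ms(fn, x):
    # triton.testing.do_bench handles warmup and repetition and returns the
    # runtime in milliseconds. Requires a CUDA device.
    return triton.testing.do_bench(lambda: fn(x))

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

eager_ms = bench_ms(to_fp8_row_major_t_and_non_t, x)                    # eager baseline
compiled_ms = bench_ms(torch.compile(to_fp8_row_major_t_and_non_t), x)  # torch.compile baseline
print(f"eager: {eager_ms:.3f} ms  compiled: {compiled_ms:.3f} ms")
```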


pytorch-bot (bot) commented on Jan 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1497

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 01eedbf with merge base eb49333:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Jan 3, 2025
ghstack-source-id: b8ffafdd2ff8428643045b9e7fb9046a0eab22c7
ghstack-comment-id: 2569853196
Pull Request resolved: #1497
@danielvegamyhre added the topic: not user facing label on Jan 3, 2025
@facebook-github-bot added the CLA Signed label on Jan 3, 2025
@danielvegamyhre requested a review from vkuzo on January 3, 2025
danielvegamyhre added a commit that referenced this pull request Jan 3, 2025
ghstack-source-id: be4465cbef3e93fa415d1acf65d9a889043ead0d
ghstack-comment-id: 2569853196
Pull Request resolved: #1497
@danielvegamyhre danielvegamyhre changed the base branch from gh/danielvegamyhre/14/head to main January 8, 2025 03:48
@danielvegamyhre danielvegamyhre merged commit 070345d into main Jan 8, 2025
44 checks passed