
[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic_max-based algo instead of reduction-based algo #1455

Open · wants to merge 2 commits into main
Conversation

@danielvegamyhre (Contributor) commented on Dec 20, 2024

Summary

  • Add new Triton kernels that use a different approach for calculating the global amax:
  • Have a single shared global amax tensor of size 1.
  • Each block computes its local amax, then performs a thread-safe tl.atomic_max against the shared global amax tensor, so the true global max is available once the kernel completes (see the sketch below).
  • The advantage of this strategy is that we no longer need to allocate a shared buffer of size num_elements // BLOCK_SIZE outside the kernels; that buffer prevented autotuning the block size, since the block size had to be known ahead of time to allocate it.

Also made the algorithm used configurable and added unit tests for both.
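
For readers skimming the diff, here is a minimal sketch of the atomic_max strategy described above. It is not the PR's actual kernel; the names `_amax_atomic_kernel` and `global_amax`, the default block size, and the contiguity assumption are all illustrative.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _amax_atomic_kernel(
    input_ptr,   # contiguous high-precision input tensor
    amax_ptr,    # single-element fp32 tensor holding the shared global amax
    num_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Each block computes the amax of its own tile...
    vals = tl.load(input_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    local_amax = tl.max(tl.abs(vals), axis=0)

    # ...then folds it into the shared global amax with a thread-safe atomic max,
    # instead of writing into a num_elements // BLOCK_SIZE scratch buffer.
    tl.atomic_max(amax_ptr, local_amax)


def global_amax(hp_tensor: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    # Single shared global amax tensor of size 1; abs values are >= 0, so 0 is a safe init.
    amax = torch.zeros(1, dtype=torch.float32, device=hp_tensor.device)
    n = hp_tensor.numel()
    grid = (triton.cdiv(n, block_size),)  # block_size must be a power of two for tl.arange
    _amax_atomic_kernel[grid](hp_tensor, amax, n, BLOCK_SIZE=block_size)
    return amax
```

Because no scratch buffer has to be sized ahead of time, BLOCK_SIZE can be left to triton.autotune rather than fixed by the caller.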

Test plan

  • pytest test/test.py - Passing
  • python3 benchmark/benchmark.py yields a large performance improvement, as seen below. The kernel is now markedly better than the production path in eager mode, but not yet as good as the compiled path.

Performance of the reduction-based kernel with an un-tuned block size of 8:

  input_size  high_precision_dtype      eager_time    compiled_time    float8nocompile
------------  ----------------------  ------------  ---------------  -----------------
65500         torch.float32                599.299          298.101        94446
65500         torch.bfloat16               649.674          394.535        94386.3
    1.05e+06  torch.float32                640.5            332.171       104449
    1.05e+06  torch.bfloat16               685.822          421.365       104372
    1.68e+07  torch.float32               1963.09          1214.32        280825
    1.68e+07  torch.bfloat16              1828.16          1051.67        261710
    2.68e+08  torch.float32              24129.8          16287.2            3.39791e+06
    2.68e+08  torch.bfloat16             21603.2          12389.9            3.39515e+06

Performance of the atomic_max-based kernel with an auto-tuned block size:

  input_size  high_precision_dtype      eager_time    compiled_time    float8nocompile
------------  ----------------------  ------------  ---------------  -----------------
65500         torch.float32                599.055          298.543            372.018
65500         torch.bfloat16               649.523          394.457            413.137
    1.05e+06  torch.float32                640.497          332.419            413.503
    1.05e+06  torch.bfloat16               685.584          421.296            453.472
    1.68e+07  torch.float32               1963.72          1215.19            1415.5
    1.68e+07  torch.bfloat16              1829.28          1051.86            1298.55
    2.68e+08  torch.float32              24126.2          16294.3            19124.8
    2.68e+08  torch.bfloat16             21592.5          12390.6            16485.7

pytorch-bot (bot) commented on Dec 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1455

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1c06b47 with merge base 3bac905:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Dec 20, 2024
danielvegamyhre added the topic: not user facing label on Dec 20, 2024
@drisspg (Contributor) left a comment

Looks good. I'm not sure I would completely throw away the old one, though, since I think tuning the split reduction could get us a lot of performance, and it is deterministic, which is nice.

@danielvegamyhre (Contributor, Author)

> Looks good. I'm not sure I would completely throw away the old one, though, since I think tuning the split reduction could get us a lot of performance, and it is deterministic, which is nice.

Hmm, ok. I will make the strategy used configurable.

@danielvegamyhre changed the title from "[float8nocompile] Refactor Triton kernels so triton.autotune is easily usable, and autotune the block size" to "[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic max algo instead of reduction based algo" on Dec 20, 2024
@danielvegamyhre (Contributor, Author) commented on Dec 20, 2024

> Looks good. I'm not sure I would completely throw away the old one, though, since I think tuning the split reduction could get us a lot of performance, and it is deterministic, which is nice.

@drisspg I updated the PR to keep both implementations and make the kernel algorithm used configurable, and updated unit tests to exercise both paths.
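
For illustration only (the names below are hypothetical and may not match the PR), the configurable algorithm could be modeled as a simple enum that the conversion entry point switches on:

```python
from enum import Enum


class KernelAlgorithm(Enum):
    """Which Triton kernel strategy to use for computing the global amax."""

    # Original approach: each block writes its local amax into a pre-allocated
    # buffer of size num_elements // BLOCK_SIZE, which is then reduced.
    # Deterministic, but the buffer size fixes the block size ahead of time.
    REDUCTION = "reduction"

    # New approach: each block folds its local amax into a size-1 global amax
    # tensor via tl.atomic_max, so no buffer is needed and the block size can
    # be autotuned.
    ATOMIC_MAX = "atomic_max"
```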

One interesting thing I noticed, though, is that the atomic max strategy was failing non-deterministically; I had to add a call to torch.cuda.synchronize() after the global amax kernel to fix this.
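
Continuing the illustrative sketch from the summary (the helper names are still hypothetical), the fix amounts to synchronizing after the amax kernel launch and before the result is consumed:

```python
def global_amax_with_sync(hp_tensor: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    amax = torch.zeros(1, dtype=torch.float32, device=hp_tensor.device)
    n = hp_tensor.numel()
    grid = (triton.cdiv(n, block_size),)
    _amax_atomic_kernel[grid](hp_tensor, amax, n, BLOCK_SIZE=block_size)

    # The PR author reported non-deterministic failures without this call;
    # synchronizing ensures every block's tl.atomic_max has completed before
    # the amax is used downstream (e.g., to compute the FP8 scale).
    torch.cuda.synchronize()
    return amax
```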

@danielvegamyhre changed the title from "[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic max algo instead of reduction based algo" to "[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic_max-based algo instead of reduction-based algo" on Dec 21, 2024
Labels
CLA Signed · topic: not user facing