[Issue]: Backward performance #44

Open
netw0rkf10w opened this issue Feb 19, 2024 · 1 comment

Comments

@netw0rkf10w

Problem Description

I recently shared some benchmarking results in the ROCm Triton repo comparing different implementations of attention on an MI250X.

It is worth pointing out that the implementation in this repo is faster (on average) than its Triton counterpart in the forward pass, but substantially slower in the backward pass. The detailed results are below, FYI (a rough timing sketch follows the tables).

Forward pass:

[-------------------------------------- FlashAttention -------------------------------------]
                                           |    math   |   torch   |   flash   |  triton-rocm
1 threads: ----------------------------------------------------------------------------------
      [torch.float16, 12, 64, 256, 64]     |    873.2  |    740.7  |    231.9  |      507.0  
      [torch.float16, 12, 64, 256, 128]    |   1644.3  |   1461.8  |    457.6  |      503.1  
      [torch.float16, 16, 64, 256, 64]     |   1105.1  |    984.8  |    308.3  |      354.6  
      [torch.float16, 16, 64, 256, 128]    |   2190.2  |   1947.9  |    612.8  |      674.2  
      [torch.float16, 16, 64, 576, 64]     |   6015.5  |   5042.3  |   1824.6  |     1769.8  
      [torch.float16, 16, 64, 576, 128]    |  11998.8  |  10062.8  |   3624.3  |     3530.4  
      [torch.float16, 16, 128, 256, 64]    |   2245.4  |   2250.3  |    701.8  |      620.3  
      [torch.float16, 16, 128, 256, 128]   |   4473.6  |   4477.9  |   1375.1  |     1158.3  
      [torch.float16, 16, 128, 784, 64]    |  18274.6  |  16723.8  |   6627.9  |     5680.3  
      [torch.float16, 16, 128, 784, 128]   |  36563.9  |  33427.6  |  13137.0  |    11261.6  
      [torch.bfloat16, 12, 64, 256, 64]    |   1022.3  |    883.0  |    221.9  |      506.1  
      [torch.bfloat16, 12, 64, 256, 128]   |   1933.3  |   1743.4  |    433.8  |     1008.2  
      [torch.bfloat16, 16, 64, 256, 64]    |   1296.0  |   1170.2  |    295.3  |      692.7  
      [torch.bfloat16, 16, 64, 256, 128]   |   2572.6  |   2319.1  |    575.3  |     1353.5  
      [torch.bfloat16, 16, 64, 576, 64]    |   6024.4  |   4971.1  |   1740.2  |     2964.6  
      [torch.bfloat16, 16, 64, 576, 128]   |  12066.4  |  10047.1  |   3426.3  |     5931.7  
      [torch.bfloat16, 16, 128, 256, 64]   |   2196.4  |   2199.6  |    684.5  |      983.3  
      [torch.bfloat16, 16, 128, 256, 128]  |   4379.6  |   4384.1  |   1338.3  |     1886.1  
      [torch.bfloat16, 16, 128, 784, 64]   |  16330.0  |  14716.1  |   6473.6  |    10093.0  
      [torch.bfloat16, 16, 128, 784, 128]  |  32643.9  |  29421.3  |  12839.6  |    19931.2  

Times are in microseconds (us).

Forward pass followed by backward:

[----------------------------------- FlashAttention ----------------------------------]
                                           |   math  |  torch  |  flash  |  triton-rocm
1 threads: ----------------------------------------------------------------------------
      [torch.float16, 12, 64, 256, 64]     |    3.0  |    2.8  |    1.3  |       1.1   
      [torch.float16, 12, 64, 256, 128]    |    5.9  |    5.5  |    2.6  |       1.3   
      [torch.float16, 16, 64, 256, 64]     |    4.0  |    3.7  |    1.8  |       1.2   
      [torch.float16, 16, 64, 256, 128]    |    7.9  |    7.4  |    3.5  |       1.7   
      [torch.float16, 16, 64, 576, 64]     |   16.7  |   14.8  |    8.2  |       3.6   
      [torch.float16, 16, 64, 576, 128]    |   33.4  |   29.4  |   16.2  |       7.0   
      [torch.float16, 16, 128, 256, 64]    |    7.5  |    7.5  |    3.5  |       2.5   
      [torch.float16, 16, 128, 256, 128]   |   14.9  |   14.9  |    6.9  |       3.7   
      [torch.float16, 16, 128, 784, 64]    |   54.6  |   51.5  |   26.7  |      14.6   
      [torch.float16, 16, 128, 784, 128]   |  109.1  |  102.9  |   52.9  |      27.2   
      [torch.bfloat16, 12, 64, 256, 64]    |    3.4  |    3.2  |    1.6  |       1.3   
      [torch.bfloat16, 12, 64, 256, 128]   |    6.7  |    6.3  |    3.1  |       1.9   
      [torch.bfloat16, 16, 64, 256, 64]    |    4.5  |    4.2  |    2.1  |       1.6   
      [torch.bfloat16, 16, 64, 256, 128]   |    8.9  |    8.4  |    4.1  |       2.5   
      [torch.bfloat16, 16, 64, 576, 64]    |   16.6  |   14.6  |   10.2  |       5.0   
      [torch.bfloat16, 16, 64, 576, 128]   |   33.3  |   29.1  |   20.2  |       9.8   
      [torch.bfloat16, 16, 128, 256, 64]   |    7.5  |    7.5  |    4.3  |       3.0   
      [torch.bfloat16, 16, 128, 256, 128]  |   15.0  |   15.0  |    8.5  |       4.7   
      [torch.bfloat16, 16, 128, 784, 64]   |   57.5  |   54.2  |   34.6  |      17.9   
      [torch.bfloat16, 16, 128, 784, 128]  |  114.7  |  108.2  |   68.7  |      33.6   

Times are in milliseconds (ms).
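
For context, a minimal sketch of how forward vs. forward+backward timings like the above could be collected with torch.utils.benchmark. The shape convention, dtype, and the use of torch.nn.functional.scaled_dot_product_attention as a stand-in for the flash-attention call are assumptions here, not the actual benchmark script:

import torch
from torch.utils import benchmark

# Assumed shape convention (batch, heads, seq_len, head_dim) and dtype;
# the real benchmark script and shape ordering may differ.
batch, heads, seq_len, head_dim = 12, 64, 256, 64
dtype = torch.float16

q = torch.randn(batch, heads, seq_len, head_dim,
                device="cuda", dtype=dtype, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# scaled_dot_product_attention is used as a stand-in for the attention
# kernel being benchmarked (math / torch / flash / triton-rocm).
def fwd():
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

def fwd_bwd():
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    out.sum().backward()

# Timer synchronizes CUDA around each measurement, so the reported times
# include kernel execution, not just launch overhead.
for label, fn in [("forward", fwd), ("forward + backward", fwd_bwd)]:
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    print(label, timer.timeit(100))

Timing the forward pass and the forward+backward pass separately (rather than the backward in isolation) mirrors the tables above; the backward cost can be estimated as the difference between the two.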

Operating System

Red Hat Enterprise Linux 8.8

CPU

AMD EPYC 7A53 64-Core Processor

GPU

AMD Instinct MI250X

ROCm Version

ROCm 5.7.1

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

@carlushuang
Collaborator

There is a WIP refactor of this repo using the newly developed ck_tile from composable_kernel, which will bring speedups on both MI200 and MI300.
Stay tuned for a later update.
