rescuing flash attention #386

Merged: 6 commits merged into leejet:master on Nov 23, 2024

Conversation

Green-Sky
Contributor

@Green-Sky Green-Sky commented Sep 1, 2024

I ported over the flash attention code changes from #221.
This does not yet fix the existing code behind the define, which broke when gg removed the code from ggml, assuming no one was using it :)
The old broken code behind the define also only gets invoked by the VAE.
The new code is used by UNET, DiT and clip/t5. However, I have not come across any clip or t5 that works without extra effort, but there are sizable memory reductions for the sd1, sd2, sdxl and flux compute buffers, as well as a speed increase on cuda.
I switched the VAE over to the new attention code, but flash attention makes no difference there, so it is disabled for the VAE.
There is a lot more left on the table if we employ padding, so flash attention can be applied to more ops.
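
As a rough back-of-envelope for where the compute buffer savings come from (my own estimate, not something measured in this PR): regular attention materializes the full n_head x L_q x L_k score matrix in fp32, while GGML_OP_FLASH_ATTN_EXT processes it block-wise and never stores it. The attention_ext log line quoted further down in the thread shows L_q = L_k = 2304 with n_head = 24, so a single fp32 score tensor for that one op is roughly

$$24 \cdot 2304 \cdot 2304 \cdot 4\,\mathrm{B} \approx 486\,\mathrm{MiB},$$

which is the right order of magnitude for the several-hundred-MB compute buffer reductions reported below (the actual graph allocation is more involved, so treat this only as an approximation).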

Flash attention for diffusion models is available via a runtime flag --diffusion-fa.
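
A minimal usage sketch (the model path and prompt below are placeholders; --diffusion-fa and -v are the flags confirmed in this thread, the rest are the usual sd.cpp CLI options):

```sh
# enable flash attention for the diffusion model and check the debug log
# for "using flash attention" to confirm it actually kicked in
./sd -m <model.gguf> -p "a photo of a cat" --diffusion-fa -v
```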

flux1-schnell performance numbers

all tests were done on a debug build, so the real numbers are likely a bit better.
4 steps, cfg-scale of 1

CPU

512x512:

compute buffer size: 397MB -> 247MB
speed q3_k: 80.00s/it -> 83.38s/it

768x768:

compute buffer size: 1103MB -> 503MB
speed q3_k: 197.11s/it -> 209.16s/it

CUDA

512x512:

compute buffer size: 398MB -> 248MB
speed q3_k: 1.98s/it -> 1.69s/it
speed q4_k: 1.84s/it -> 1.54s/it

768x768:

compute buffer size: 1105MB -> 505MB
speed q3_k: 4.58s/it -> 3.58s/it
speed q4_k: OOM -> 3.29s/it

direct comparison of the cuda images: [nofa | fa images omitted]

difference, darker/colorful is worse: [error | error^2 images omitted]

SD2 turbo

8 steps, cfg-scale of 1

CPU

512x512:

compute buffer size: 367MB -> 130MB
speed q8_0: 6.40s/it -> 6.65s/it

768x768:

compute buffer size: 1718MB -> 294MB
speed q8_0: 20.59s/it -> 22.41s/it

CUDA

512x512:

compute buffer size: 367MB -> 130MB
speed q8_0: 6.24it/s -> 8.17it/s

768x768:

compute buffer size: 1718MB -> 295MB
speed q8_0: 1.84it/s -> 3.17it/s

direct comparison of the cuda images: [nofa | fa images omitted]

difference, darker/colorful is worse: [error | error^2 images omitted]

SDXL realvisxl_v50Lightning

6 steps, cfg-scale of 1.8
dpm++2mv2 karras

CPU

512x512:

compute buffer size: 131MB -> 131MB
speed q8_0: 15.04s/it -> 15.21s/it

768x768:

compute buffer size: 330MB -> 280MB
speed q8_0: 37.36s/it -> 39.30s/it

CUDA

512x512:

compute buffer size: 132MB -> 132MB
speed q8_0: 2.01it/s -> 2.36it/s

768x768:

compute buffer size: 331MB -> 280MB
speed q8_0: 1.23s/it -> 1.15s/it

direct comparison of the cuda images: [nofa | fa images omitted]

difference, darker/colorful is worse: [error | error^2 images omitted]

There still is the opportunity to pad some tensors to make them fit.

TODO

  • remove the define
  • add a runtime switch
    • diffusion model
    • vae
  • more exhaustive testing with supported models
  • add more than just flux numbers to the OP
  • add image comparisons
  • update docs

Please test this code.

props to @FSSRepo for having the code lying around

fixes: #297

images:
sd2_turbo.zip
flux1-schnell-q4_k.zip
flux1-schnell-q3_k.zip

update: added sd2 and sdxl numbers
update2: see the rocm tests down in the thread. it works, but behaves similarly to cpu
update3: added the flash_attn param and exposed it via the --diffusion-fa runtime flag for supported models.

@Green-Sky Green-Sky changed the title rescue flash attention rescueing flash attention Sep 1, 2024
@Green-Sky Green-Sky changed the title rescueing flash attention rescuing flash attention Sep 1, 2024
@theaerotoad

theaerotoad commented Sep 3, 2024

@Green-Sky -- would this enable flash attention for Vulkan builds as well?

@Green-Sky
Contributor Author

Green-Sky commented Sep 4, 2024

@Green-Sky -- would this enable flash attention for Vulkan builds as well?

No, sadly not. Vulkan does not implement GGML_OP_FLASH_ATTN_EXT.
However, it looks like cuda, cuda built as rocm (and musa?) and metal all support it.

Would be cool if someone could try rocm and metal builds.

@Green-Sky
Contributor Author

I just tested SD3 (2B), and it appears the KV dimensions are not multiples of 256, so flash attention won't work on the MMDiT without padding.
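
For reference, a hedged sketch of the size constraints discussed in this thread (the names are mine, not the actual ggml_extend.hpp code; see also the kv-padding patch further down):

```cpp
#include <cstdint>

// Rough eligibility check as described in this PR: the ggml CUDA
// FLASH_ATTN_EXT path wants the KV length to be a multiple of 256 and
// d_head to be a multiple of 64, with d_head <= 256 on CUDA (CPU seems
// to accept d_head == 512, but flash attention brings no benefit there).
static bool fa_eligible(int64_t L_k, int64_t d_head, bool is_cuda) {
    if (L_k % 256 != 0)   return false;  // SD3's MMDiT fails this one
    if (d_head % 64 != 0) return false;
    if (is_cuda && d_head > 256) return false;  // VAE needs 512
    return true;
}
```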

@MineGame159

I wanted to test the PR with the ROCM backend but I couldn't get it to build. Did you somehow add back the missing ggml_flash_attn function that is used in ggml_extend.hpp? I am not familiar with ggml so it would be nice if you could explain how you got the CUDA backend working.

@Green-Sky
Contributor Author

Green-Sky commented Sep 6, 2024

I wanted to test the PR with the ROCM backend but I couldn't get it to build. Did you somehow add back the missing ggml_flash_attn function that is used in ggml_extend.hpp? I am not familiar with ggml so it would be nice if you could explain how you got the CUDA backend working.

Oh no, sorry for the confusion: don't use the old define or cmake option, I did not add that back in.
Just build it as-is and it has flash attention (for diffusion models only) enabled.
The old code that gets enabled with the define is what would be used by the VAE, but I have not touched that part yet.

edit: I might push a change soon so you can enable flash attention for diffusion models via a command line option.

@MineGame159

Oh interesting. I already tried a comparison between the current master and your PR without the cmake flag, and the results didn't look like it was enabled. Here are the results ("no Flash Attention" marks runs where I didn't enable the flag). I used the flux1-schnell model with q8_0.

Master - no Flash Attention

total params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB)
sampling completed, taking 149.66s - 18.69s/it

PR - no Flash Attention

total params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB)
sampling completed, taking 173.13s - 21.62s/it

@Green-Sky
Contributor Author

@MineGame159 I see, so speed went down on rocm...
Also, run it with -v and look for

[DEBUG] ggml_extend.hpp:739  - attention_ext L_q:2304 L_k:2304 n_head:24 C:3072 d_head:128 N:1
[DEBUG] ggml_extend.hpp:763  - using flash attention

If an attention_ext line with tensor sizes is followed by using flash attention, then that specific attention was converted to flash attention.

Also look for flux compute buffer size: 456.75 MB(VRAM) (in the spam with -v) to see how the compute buffer is affected. Total params size does not change with this patch. 😃

@MineGame159

Oh sorry 😅, I am pretty new to sd.cpp and AI as a whole. And yes, I can see that it is using flash attention.

[DEBUG] ggml_extend.hpp:738  - attention_ext L_k:1280 n_head:24 C:3072 d_head:128
[DEBUG] ggml_extend.hpp:755  - using flash attention

And the compute buffer size:
Master - flux compute buffer size: 398.50 MB(VRAM)
PR - flux compute buffer size: 248.50 MB(VRAM)

@Green-Sky
Contributor Author

And the resulting images look close to identical?
The numbers look good, the compute buffer size reduction is the same as on cuda. Try larger images :)

@MineGame159

Tested with a 768x768 image and I can't see any difference. The size of the compute buffer is basically the same as what you got on CUDA. The only thing is the decreased speed.

1105.07 MB(VRAM) -> 505.07 MB(VRAM)
41.66s/it -> 57.59s/it

[master | pr images omitted]

@Green-Sky
Contributor Author

I diffed your images.

[error | error^2 difference images omitted]

The error is even less than what I have with cuda.

Tested with a 768x768 image and I can't see any difference. The size of the compute buffer is basically the same as on CUDA for you. Only thing is the decreased speed.

Yea, good to know, thanks for testing.
In summary: use it if you use cuda and/or need the extra vram savings.

@Green-Sky
Contributor Author

Green-Sky commented Sep 7, 2024

edit: I might push a change soon, where you can enable flash attention for diffusion models via a command line option.

I did.
I am now looking into vae.

update: flash attention does not seem to improve the VAE in any way, plus it does not work on cuda, since cuda supports d_head up to 256 but the VAE needs 512. So I think I will not enable VAE flash attention, but still remove the old crusty code.
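
For context on that 512 (my reading of the SD VAE, not something stated in this thread): its mid-block attention is effectively single-head over the 512 feature channels and runs over every latent pixel, so for a 512x512 image

$$d_\mathrm{head} = C = 512,\qquad L = \tfrac{512}{8}\cdot\tfrac{512}{8} = 4096,$$

which puts it past the CUDA d_head limit regardless of any padding.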

@Green-Sky Green-Sky marked this pull request as ready for review September 7, 2024 11:28
@Green-Sky
Contributor Author

Would be nice if someone could test metal.

@Green-Sky
Contributor Author

  • Accelerated memory-efficient CPU inference
    • Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB.

I'm going to leave it to @leejet to update this example; not sure where those numbers came from.

@Green-Sky
Contributor Author

If someone wants to play with kv-padding, to enable more cases where flash attention can be used, here is a patch:

diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 8452a0b..518eb6f 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -710,6 +710,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

     float scale = (1.0f / sqrt((float)d_head));

+    int kv_pad = 0;
     //if (flash_attn) {
     //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
     //}
@@ -717,7 +718,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));

     bool can_use_flash_attn = true;
-    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        if (L_k == 77) {
+            kv_pad = 256 - (L_k % 256);
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+    //can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
     can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check

     // cuda max d_head seems to be 256, cpu does seem to work with 512
@@ -734,11 +742,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     //GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        //LOG_DEBUG("using flash attention");
+        LOG_DEBUG("using flash attention");
+        if (kv_pad != 0) {
+            LOG_DEBUG("padding kv by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);

         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);

         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);

I won't be committing this here, since the memory savings are minuscule, or it even uses more memory because of the padding.
Plus, the images look way off, so maybe there is a mistake somewhere, or it's just the f16 cast in the wrong situation...
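
One possible explanation for the off-looking images (just my guess, not verified in this thread): ggml_pad fills the new K/V rows with zeros, and the patch does not extend the mask over them, so the padded keys still receive softmax weight. With the padded positions scoring $s_j = 0$,

$$\mathrm{softmax}(s)_j = \frac{e^{s_j}}{\sum_k e^{s_k}},\qquad e^{0} = 1,$$

so every padded key adds a constant term to the denominator (and pulls in zero-padded V rows), which would skew the output unless those positions are masked to $-\infty$.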

@Green-Sky
Contributor Author

Rebased on master. I did not yet look at sd3.5; it is disabled, same as sd3, for now.
I still consider this ready to merge as-is, and I will probably look at sd3.5 in the coming weeks.

Green-Sky and others added 3 commits November 12, 2024 13:03
this does not fix the currently broken fa behind the define, which is only used by VAE

Co-authored-by: FSSRepo <[email protected]>
@leejet
Owner

leejet commented Nov 23, 2024

Thank you for your contribution

@leejet leejet merged commit 1c168d9 into leejet:master Nov 23, 2024
9 checks passed
@Green-Sky Green-Sky deleted the rescue_flash_attn branch November 23, 2024 12:46
Development

Successfully merging this pull request may close these issues.

stable-diffusion.cpp doesn't compile with flash attention enabled
4 participants