-
Notifications
You must be signed in to change notification settings - Fork 305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rescuing flash attention #386
Conversation
2634bea
to
f88143f
Compare
@Green-Sky -- would this enable flash attention for Vulkan builds as well? |
No, sadly not. Vulkan does not implement Would be cool if someone could try rocm and metal builds. |
I just tested SD3 (2B), and it appears the kv dimensions are not multiples of 256, so flash attention wont work on the mmdit without padding. |
I wanted to test the PR with the ROCM backend but I couldn't get it to build. Did you somehow add back the missing |
Oh no, sorry for the confusion, dont use the old define or cmake option, I did not add that back in. edit: I might push a change soon, where you can enable flash attention for diffusion models via a command line option. |
Oh interesting. Because I already tried a comparison between the current master and your PR applied without the cmake flag. And the results didn't look like it was enabled. Here are the results (no Flash Attention is how I marked that I didn't enable the flag). I used the flux1-schnell model using q8_0. Master - no Flash Attentiontotal params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB) PR - no Flash Attentiontotal params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB) |
@MineGame159 I see, so speed went down on rocm...
If an Also look for |
Oh sorry 😅, I am pretty new to sd.cpp and AI as a whole. And yes, I can see that it is using flash attention.
And the compute buffer size: |
And the resulting image look close to identical? |
I did. update: flash attention does not seem to improve VAE in anyway + it does not work on cuda, since cuda supports |
b0aecf4
to
e904b86
Compare
Would be nice if someone could test metal. |
f69fbfd
to
cc7efa2
Compare
imma leave @leejet to update this example, not sure where those numbers came from. |
c943406
to
90d420a
Compare
If someone wants to play with kv-padding, to enable more cases where flash attention can be used, here is a patch: diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 8452a0b..518eb6f 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -710,6 +710,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
float scale = (1.0f / sqrt((float)d_head));
+ int kv_pad = 0;
//if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
//}
@@ -717,7 +718,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
bool can_use_flash_attn = true;
- can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
+ if (can_use_flash_attn && L_k % 256 != 0) {
+ if (L_k == 77) {
+ kv_pad = 256 - (L_k % 256);
+ } else {
+ can_use_flash_attn = false;
+ }
+ }
+ //can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check
// cuda max d_head seems to be 256, cpu does seem to work with 512
@@ -734,11 +742,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
ggml_tensor* kqv = nullptr;
//GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
if (can_use_flash_attn && flash_attn) {
- //LOG_DEBUG("using flash attention");
+ LOG_DEBUG("using flash attention");
+ if (kv_pad != 0) {
+ LOG_DEBUG("padding kv by %d", kv_pad);
+ k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+ }
k = ggml_cast(ctx, k, GGML_TYPE_F16);
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
+ if (kv_pad != 0) {
+ v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+ }
v = ggml_cast(ctx, v, GGML_TYPE_F16);
kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); I wont be committing this here since the memory saving are minuscule, or it uses even more memory because of the padding. |
90d420a
to
947340b
Compare
rebased on master, did not yet look at sd3.5, it is disabled, same as sd3 for now. |
947340b
to
a06ec06
Compare
this does not fix the currently broken fa behind the define, which is only used by VAE Co-authored-by: FSSRepo <[email protected]>
no support for sd3 or video
a06ec06
to
cbf0489
Compare
Thank you for your contribution |
I ported over the flash attention code changes form #221 .
This does not yet fix the existing code behind the define, which broke when gg removed the code from ggml assuming noone was using it :)The old broken code behind the define also only gets invoked by the VAE.The new code gets used by UNET, DiT and clip/t5. However I have not come across any clip or t5 that just works without extra work, but there are sizable memory reductions for
sd1,sd2, sdxl and flux compute buffers, as well as a speed increase on cuda.I switched VAE over to using the new attention code, but flash attention makes no difference, so it is disabled there.
There is alot more left on the table if we employ padding, so flash attention can be applied to more ops.
Flash attention for diffusion models is available via a runtime flag
--diffusion-fa
.flux1-schnell performance numbers
all tests done on a debug build, so the real number are likely a bit better.
4 steps, cfg-scale of 1
CPU
512x512:
compute buffer size: 397MB -> 247MB
speed q3_k: 80.00s/it -> 83.38s/it
768x768:
compute buffer size: 1103MB -> 503MB
speed q3_k: 197.11s/it -> 209.16s/it
CUDA
512x512:
compute buffer size: 398MB -> 248MB
speed q3_k: 1.98s/it -> 1.69s/it
speed q4_k: 1.84s/it -> 1.54s/it
768x768:
compute buffer size: 1105MB -> 505MB
speed q3_k: 4.58s/it -> 3.58s/it
speed q4_k: OOM -> 3.29s/it
direct comparison of the cuda images:
difference, darker/colorful is worse:
SD2 turbo
8 steps, cfg-scale of 1
CPU
512x512:
compute buffer size: 367MB -> 130MB
speed q8_0: 6.40s/it -> 6.65s/it
768x768:
compute buffer size: 1718MB -> 294MB
speed q8_0: 20.59s/it -> 22.41s/it
CUDA
512x512:
compute buffer size: 367MB -> 130MB
speed q8_0: 6.24it/s -> 8.17it/s
768x768:
compute buffer size: 1718MB -> 295MB
speed q8_0: 1.84it/s -> 3.17it/s
direct comparison of the cuda images:
difference, darker/colorful is worse:
SDXL realvisxl_v50Lightning
6 steps, cfg-scale of 1.8
dpm++2mv2 karras
CPU
512x512:
compute buffer size: 131MB -> 131MB
speed q8_0: 15.04s/it -> 15.21s/it
768x768:
compute buffer size: 330MB -> 280MB
speed q8_0: 37.36s/it -> 39.30s/it
CUDA
512x512:
compute buffer size: 132MB -> 132MB
speed q8_0: 2.01it/s -> 2.36it/s
768x768:
compute buffer size: 331MB -> 280MB
speed q8_0: 1.23s/it -> 1.15s/it
direct comparison of the cuda images:
difference, darker/colorful is worse:
There still is the opportunity to pad some tensors to make them fit.
TODO
vaePlease test this code.
props to @FSSRepo for having the code laying around
fixes: #297
images:
sd2_turbo.zip
flux1-schnell-q4_k.zip
flux1-schnell-q3_k.zip
update: added sd2 and sdxl numbers
update2: see rocm tests down in the thread. it works but behaves similar to cpu
udpate3: added flash_attn param and expose via
--diffusion-fa
runtime flag for supported models.