From 08e46d6f1427a6b4c213a66c91a05f52b1219b40 Mon Sep 17 00:00:00 2001
From: umiswing
Date: Wed, 9 Aug 2023 14:20:55 +0800
Subject: [PATCH] Fix select sdp for FA-2 (#56045)

---
 python/paddle/nn/functional/flash_attention.py |  2 +-
 test/legacy_test/test_flash_attention.py       | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index b9077896da6c5..ae2ddc2a921a2 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -81,7 +81,7 @@ def _math_attention(
 
 
 def _select_sdp_cuda(head_dim):
-    if head_dim <= 128:
+    if head_dim <= 256:
         return "flash_attn"
     else:
         return "mem_efficient"
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 6490ce428381b..4088f60570f50 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -410,6 +410,17 @@ def setUp(self):
         self.use_sdp_kernel = False
 
 
+class TestFlashAttentionAPITest5(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 256)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
 class TestMathAttentionAPITest(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
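
Note (not part of the patch): FlashAttention-2 supports head dimensions up to 256, so the kernel selector's threshold moves from 128 to 256. A minimal usage sketch of the effect, assuming a CUDA build with FA-2 support; the tensor shape mirrors the new test case, [batch, seqlen, num_heads, head_dim]:

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# With this patch, head_dim == 256 selects the "flash_attn" kernel
# instead of falling back to "mem_efficient"; head_dim > 256 still
# falls back. Shape is [batch, seqlen, num_heads, head_dim].
q = paddle.randn([8, 1024, 16, 256], dtype="float16")
k = paddle.randn([8, 1024, 16, 256], dtype="float16")
v = paddle.randn([8, 1024, 16, 256], dtype="float16")

# flash_attention returns (output, softmax); softmax is None unless
# return_softmax=True is passed.
out, _ = flash_attention(q, k, v, dropout=0.0, causal=False)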