diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index b9077896da6c5..ae2ddc2a921a2 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -81,7 +81,7 @@ def _math_attention(
 
 
 def _select_sdp_cuda(head_dim):
-    if head_dim <= 128:
+    if head_dim <= 256:
         return "flash_attn"
     else:
         return "mem_efficient"
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 6490ce428381b..4088f60570f50 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -410,6 +410,17 @@ def setUp(self):
         self.use_sdp_kernel = False
 
 
+class TestFlashAttentionAPITest5(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 256)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
 class TestMathAttentionAPITest(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
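
For reference, a minimal sketch of the kernel-selection behavior after this patch. The function body is copied from the diff above; the assertions are illustrative only and are not part of the change.

# Sketch: with the threshold raised from 128 to 256, head dims up to 256
# are dispatched to the flash attention kernel instead of mem_efficient.
def _select_sdp_cuda(head_dim):
    if head_dim <= 256:
        return "flash_attn"
    else:
        return "mem_efficient"

assert _select_sdp_cuda(128) == "flash_attn"
assert _select_sdp_cuda(256) == "flash_attn"      # newly covered by this change
assert _select_sdp_cuda(257) == "mem_efficient"   # larger head dims still fall back

The new TestFlashAttentionAPITest5 case exercises this path with shape (8, 1024, 16, 256), i.e. a head dim of 256.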