Skip to content

Commit

Permalink
Fix select sdp for FA-2 (#56045)
Browse files Browse the repository at this point in the history
  • Loading branch information
umiswing authored Aug 9, 2023
1 parent 8d181e3 commit 08e46d6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _math_attention(


def _select_sdp_cuda(head_dim):
if head_dim <= 128:
if head_dim <= 256:
return "flash_attn"
else:
return "mem_efficient"
Expand Down
11 changes: 11 additions & 0 deletions test/legacy_test/test_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,17 @@ def setUp(self):
self.use_sdp_kernel = False


class TestFlashAttentionAPITest5(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 256)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False


class TestMathAttentionAPITest(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
Expand Down

0 comments on commit 08e46d6

Please sign in to comment.