How to avoid re-compute mask #34
Comments
Whenever the Q_LEN or KV_LEN changes you will need to recompute the mask. It is possible to create a new, smaller block mask from the larger version, but you would need to do this manually from the internal tensor components of the BlockMask.
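For concreteness, a rough sketch of what that manual construction could look like is below. It assumes the current internal BlockMask layout (kv_num_blocks of shape [B, H, num_q_blocks] and kv_indices of shape [B, H, num_q_blocks, max_kv_blocks]) and the BlockMask.from_kv_blocks constructor; these are internals and may change between PyTorch versions.

```python
# Sketch only: slice a BlockMask built for a long sequence down to a shorter
# Q_LEN by hand, using its internal block tensors (layout assumptions above).
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

def causal(b, h, q_idx, kv_idx):  # illustrative mask_mod
    return q_idx >= kv_idx

big = create_block_mask(causal, B=None, H=None, Q_LEN=8192, KV_LEN=8192)

new_q_len = 1024
q_blocks = new_q_len // big.BLOCK_SIZE[0]  # BLOCK_SIZE is (Q_BLOCK, KV_BLOCK)

small = BlockMask.from_kv_blocks(
    big.kv_num_blocks[:, :, :q_blocks],
    big.kv_indices[:, :, :q_blocks],
    big.full_kv_num_blocks[:, :, :q_blocks] if big.full_kv_num_blocks is not None else None,
    big.full_kv_indices[:, :, :q_blocks] if big.full_kv_indices is not None else None,
    BLOCK_SIZE=big.BLOCK_SIZE,
    mask_mod=causal,  # keep the mask_mod consistent with the sliced region
)
```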
Thanks for your reply!
We recently added indexing to the block_mask; the context is decoding, where you can create a larger block_mask and then slice into it and add the correct mask_mod: pytorch/pytorch@09a339f#diff-fdd6d17efe145eae3f8090031505ec062fc47ede339275a73c5e9e52c702dc91
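As a rough illustration of that decoding pattern (not taken verbatim from the PR; the exact indexing semantics live in the linked commit and may differ across versions, and the offset-aware mask_mod here is my own assumption):

```python
# Sketch: build one BlockMask for the maximum context, then, while decoding the
# token at position `pos`, slice out the q-block row that contains it and attach
# a mask_mod shifted by the query offset.
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

big = create_block_mask(causal, B=None, H=None, Q_LEN=8192, KV_LEN=8192)
q_block_size = big.BLOCK_SIZE[0]

pos = 1000                                      # illustrative decode position
q_block = pos // q_block_size
decode_mask = big[:, :, q_block : q_block + 1]  # keep only that q-block row

def causal_with_offset(b, h, q_idx, kv_idx):
    # q_idx is relative to the sliced block, so add the block offset back in.
    return q_idx + q_block * q_block_size >= kv_idx

decode_mask.mask_mod = causal_with_offset
```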
@NonvolatileMemory You can also just create a larger block mask to start with and then reuse that mask - we support passing in a BlockMask that was defined for a larger sequence than the one you're currently calling it with.
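Under that reading, the reuse path would look roughly like the minimal sketch below (shapes, lengths, and the causal mask are illustrative, not from the thread):

```python
# Sketch: build the mask once for the longest sequence you expect, then pass it
# unchanged when calling flex_attention on shorter inputs.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

big = create_block_mask(causal, B=None, H=None, Q_LEN=8192, KV_LEN=8192)

q = torch.randn(1, 8, 1024, 64, device="cuda")
k = torch.randn(1, 8, 1024, 64, device="cuda")
v = torch.randn(1, 8, 1024, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=big)  # 1024-token inputs, 8192-token mask
```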
@Chillee |
Could you clarify which config would trigger the bug? Is it a BlockMask defined with KV_LEN=8192 and Q_LEN=8192, where the query passed in has a length of 1024 and k/v have a length of 8192?
Hi, here is my source code:

```python
import torch
import torch.nn.functional as F
from functools import lru_cache, partial
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)

# Block-causal mask_mod: a query may only attend to kv positions in strictly
# earlier blocks of 4 tokens.
def block_mask(b, h, q_idx, kv_idx):
    q_block = q_idx // 4
    kv_block = kv_idx // 4
    return q_block > kv_block

# Build the block mask once for 4096 tokens (this rebinds the name block_mask
# from the mask_mod function to the resulting BlockMask object).
block_mask = create_block_mask(block_mask, B=None, H=None, Q_LEN=4096, KV_LEN=4096, _compile=True)
flex_attn = torch.compile(partial(flex_attention, block_mask=block_mask, enable_gqa=True))

def torch_mask(q_idx, kv_idx, block_size=4):
    return q_idx // block_size > kv_idx // block_size

def diff(bsz=4, seq_len=128 * 20, d_head=128, num_heads=8, block_size=4):
    # Reference attention in plain PyTorch (torch_attn).
    Q = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    K = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    V = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (Q.size(-1) ** 0.5)
    q_idx = torch.arange(seq_len).view(-1, 1)
    kv_idx = torch.arange(seq_len).view(1, -1)
    mask = torch_mask(q_idx, kv_idx, block_size)[None, None, :, :].cuda()
    scores = scores.masked_fill(~mask, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    torch_out = torch.matmul(attn_weights, V)
    # FlexAttention with the 4096-token block mask, but seq_len = 2560 inputs.
    flex_out = flex_attn(Q, K, V)
    # Skip the first rows, whose reference rows are fully masked (NaN after softmax).
    return (flex_out[:, :, 16:] - torch_out[:, :, 16:]).max()

a = diff()
print(a)
# tensor(1.2792, device='cuda:0')
```
Hi FlexAttention Team,
Thanks for your code.
I use flex attention to implement a fast, IO-aware streaming attention using this mask:
If I compile it as follows:
It gives the correct answer when qlen, klen, vlen = 8192 or 1024, but a wrong answer when qlen = 1024.
Do I need to rebuild the block mask? Or can I reuse the one built for 8192 for every input shorter than 8192?
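If rebuilding turns out to be necessary, one workaround (my own sketch, not from the maintainers' replies; get_block_mask and streaming_attention are made-up names) is to cache one block mask per (q_len, kv_len) with functools.lru_cache, so each shape only pays the construction cost once:

```python
# Sketch: rebuild the block mask per distinct (q_len, kv_len), but cache the result.
import torch
from functools import lru_cache
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def stream_mask(b, h, q_idx, kv_idx):
    # Same block-causal rule as in the reproducer above (blocks of 4 tokens).
    return q_idx // 4 > kv_idx // 4

@lru_cache(maxsize=8)
def get_block_mask(q_len, kv_len):
    return create_block_mask(stream_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)

def streaming_attention(q, k, v):
    # q/k/v are [batch, heads, seq, head_dim]; look the mask up by sequence length.
    bm = get_block_mask(q.shape[-2], k.shape[-2])
    return flex_attention(q, k, v, block_mask=bm)
```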