
How to avoid re-compute mask #34

Open · NonvolatileMemory opened this issue Sep 5, 2024 · 7 comments

NonvolatileMemory commented Sep 5, 2024

Hi FlexAttention Team,

Thanks for your code.

I'm using FlexAttention to implement a fast, IO-aware streaming attention with this mask:

def sliding_window_causal_with_stream(b, h, q_idx, kv_idx):
    # Causal mask ensures no future positions are attended
    causal_mask = q_idx >= kv_idx
    # Sliding window mask restricts the view within a window size
    window_mask = q_idx - kv_idx <= 256 
    # Stream mask ensures that for q_idx >= 4, kv_idx <= 4 is always visible
    stream_mask = (q_idx >= 4) & (kv_idx <= 4)
    # Combine all masks: sliding window and causal mask, or stream mask
    return (causal_mask & window_mask) | stream_mask

If I compile it as follows:

block_mask = create_block_mask(sliding_window_causal_with_stream, B=None, H=None, Q_LEN=8192, KV_LEN=8192,  _compile=True)
flex_attention = torch.compile(partial(flex_attention, block_mask=block_mask, enable_gqa=True))

It gives the correct answer when qlen, klen, vlen = 8192 or 1024, but a wrong answer when qlen = 1024.

Do I need to rebuild the block mask, or can I reuse the 8192 one for every input shorter than 8192?

drisspg (Contributor) commented Sep 5, 2024

Whenever Q_LEN or KV_LEN changes you will need to recompute the mask. It is possible to create a new, smaller block mask from the larger version, but you would need to do this manually from the internal tensor components of the BlockMask.
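
The straightforward recompute path, as a minimal sketch reusing the sliding_window_causal_with_stream mask_mod from the original post (build_block_mask is an illustrative helper name, not an existing API):

from torch.nn.attention.flex_attention import create_block_mask

def build_block_mask(q_len, kv_len):
    # Rebuild the BlockMask whenever Q_LEN or KV_LEN changes, instead of
    # reusing the one that was created for 8192 tokens.
    return create_block_mask(
        sliding_window_causal_with_stream,
        B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, _compile=True,
    )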

NonvolatileMemory (Author) commented

Thanks for your reply!
In my view, many masks are agnostic to qlen and kvlen, so it might be better to provide a regenerate_mask helper that takes the new_input_len and the max_len_block_mask.
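
FlexAttention does not ship such a helper, but a cheap stand-in is to cache one BlockMask per length so each length is only built once. The name regenerate_mask is taken from the comment above; this sketch rebuilds via create_block_mask on a cache miss rather than deriving from the max-length mask:

from functools import lru_cache
from torch.nn.attention.flex_attention import create_block_mask

@lru_cache(maxsize=16)
def regenerate_mask(q_len, kv_len):
    # One BlockMask per (q_len, kv_len) pair; built on the first call for a
    # given length, then reused from the cache.
    return create_block_mask(
        sliding_window_causal_with_stream,
        B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, _compile=True,
    )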

drisspg (Contributor) commented Sep 6, 2024

We recently added indexing to BlockMask. The context is decoding, where you can create a larger block_mask, slice into it, and attach the correct mask_mod: pytorch/pytorch@09a339f#diff-fdd6d17efe145eae3f8090031505ec062fc47ede339275a73c5e9e52c702dc91
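
A sketch of that slicing pattern under the default 128-token sparse block size. The query-block arithmetic and the explicit mask_mod reassignment are assumptions based on the description above, not verified against a specific release:

from torch.nn.attention.flex_attention import create_block_mask

BLOCK_SIZE = 128  # default sparse block size

# Build once for the full 8192-token extent
full_mask = create_block_mask(sliding_window_causal_with_stream, B=None, H=None,
                              Q_LEN=8192, KV_LEN=8192, _compile=True)

# Decoding-style reuse: keep the full KV extent but slice the query-block
# dimension down to the current query length (e.g. 1024 tokens), then
# re-attach the mask_mod so partial blocks are still masked correctly.
short_mask = full_mask[:, :, : 1024 // BLOCK_SIZE]
short_mask.mask_mod = sliding_window_causal_with_stream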

Chillee (Contributor) commented Sep 9, 2024

@NonvolatileMemory You can also just create a larger block mask to start with and then reuse it; we support passing in a BlockMask that was defined for a larger sequence than the one you're currently calling it with.
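
A sketch of that usage, with illustrative tensor shapes and the mask_mod from the original post:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# BlockMask built once for the maximum length ...
big_mask = create_block_mask(sliding_window_causal_with_stream, B=None, H=None,
                             Q_LEN=8192, KV_LEN=8192, _compile=True)

# ... then passed unchanged with shorter q/k/v (here 1024 tokens)
q = torch.randn(1, 8, 1024, 64, device="cuda")
k = torch.randn(1, 8, 1024, 64, device="cuda")
v = torch.randn(1, 8, 1024, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=big_mask)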

NonvolatileMemory (Author) commented

@Chillee I think that causes a bug: I can't pass an allclose check when the input size is 4096 but the mask was defined for 8192.

joydddd commented Sep 10, 2024

> It gives the correct answer when qlen, klen, vlen = 8192 or 1024, but a wrong answer when qlen = 1024.

Could you clarify which config would trigger the bug? Is it a block mask defined with Q_LEN=8192, KV_LEN=8192, a query of length 1024, and k/v of length 8192?

NonvolatileMemory (Author) commented

> Could you clarify which config would trigger the bug? Is it a block mask defined with Q_LEN=8192, KV_LEN=8192, a query of length 1024, and k/v of length 8192?

Hi, here is my source code:

from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)
import torch
from functools import lru_cache, partial

def block_mask_mod(b, h, q_idx, kv_idx):
    # Block-causal mask: a query attends only to kv positions in strictly earlier 4-token blocks
    q_block = q_idx // 4
    kv_block = kv_idx // 4
    return q_block > kv_block

# BlockMask built for a 4096-token sequence; flex_attn below always reuses it
block_mask = create_block_mask(block_mask_mod, B=None, H=None, Q_LEN=4096, KV_LEN=4096, _compile=True)
flex_attn = torch.compile(partial(flex_attention, block_mask=block_mask, enable_gqa=True))

import torch.nn.functional as F

def torch_mask(q_idx, kv_idx, block_size=4):
    return q_idx // block_size > kv_idx // block_size


def diff(bsz=4, seq_len=128 * 20, d_head=128, num_heads=8, block_size=4):
    # Dense reference attention with an explicit block-causal mask
    Q = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    K = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    V = torch.randn(bsz, num_heads, seq_len, d_head).cuda()

    scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (Q.size(-1) ** 0.5)

    q_idx = torch.arange(seq_len).view(-1, 1)
    kv_idx = torch.arange(seq_len).view(1, -1)
    mask = torch_mask(q_idx, kv_idx, block_size)[None, None, :, :].cuda()

    scores = scores.masked_fill(~mask, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    torch_out = torch.matmul(attn_weights, V)

    # FlexAttention call reuses the BlockMask built for Q_LEN=KV_LEN=4096,
    # even though the actual sequence length here is 128 * 20 = 2560
    flex_out = flex_attn(Q, K, V)

    # Skip the first 16 positions: the earliest rows are fully masked and
    # produce NaNs in the dense reference
    return (flex_out[:, :, 16:] - torch_out[:, :, 16:]).max()

a = diff()
print(a)
# tensor(1.2792, device='cuda:0')
