
Flex attention - gaps in profiler #76

Open · tugot17 opened this issue Nov 11, 2024 · 7 comments

tugot17 commented Nov 11, 2024

Repost from the PyTorch forum

I have recently been playing with Flex attention, trying to replace some of my custom Triton kernels. Unfortunately, the Flex attention version was substantially (around 5x) slower than my custom kernel, which is similar to the Tri Dao flash attention kernel but with a bias. I looked at the torch profile for a single forward pass of the compiled version after model warmup, and it looks very strange.

import torch
from typing import Optional
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
)
from torch.nn.attention.flex_attention import _mask_mod_signature
import random
from functools import partial
from triton.testing import do_bench
from torch.profiler import profile, ProfilerActivity, record_function

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

def flexattention_varlen_bias(
    query: torch.Tensor,  # shape (batch, n_heads, total_tokens, head_dim)
    key: torch.Tensor,    # shape (batch, n_heads, total_tokens, head_dim)
    value: torch.Tensor,  # shape (batch, n_heads, total_tokens, head_dim)
    *,
    bias: Optional[torch.Tensor] = None,  # shape (total_tokens,)
    document_id: torch.Tensor,  # shape (total_tokens,)
    use_causal_mask: bool = False,
) -> torch.Tensor:
    device = query.device
    _, _, total_tokens, _ = query.size()

    def additive_attention_fn(score, b, h, q_idx, kv_idx):
        return score + bias[kv_idx]

    def document_causal_mask(b, h, q_idx, kv_idx):
        causal_mask = q_idx >= kv_idx if use_causal_mask else True
        document_mask = document_id[q_idx] == document_id[kv_idx]
        return causal_mask & document_mask 

    block_mask = create_block_mask(
        document_causal_mask,
        1,
        1,
        total_tokens,
        total_tokens,
        device=device,
    )

    return flex_attention(
        query,
        key,
        value,
        block_mask=block_mask,
        score_mod=additive_attention_fn if bias is not None else None
    )


if __name__ == "__main__":
    random.seed(0)
    device = "cuda"
    dtype = torch.float16
    causal = True
    max_seq_len, doc_count = 2048, 8
    nheads = 32
    HEAD_DIM = 128

    def generate_random_lengths(total_length, num_documents):
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1
        return lengths

    lengths = generate_random_lengths(max_seq_len, doc_count)

    document_id = torch.zeros(max_seq_len, dtype=torch.int32, device=device)
    current_idx = 0
    for doc_id, length in enumerate(lengths):
        document_id[current_idx:current_idx+length] = doc_id
        current_idx += length

    query = torch.randn(1, nheads, max_seq_len, HEAD_DIM, device=device, dtype=dtype)  # (B, H, S, D) layout expected by flex_attention
    key = torch.randn_like(query)
    value = torch.randn_like(query)
    bias = torch.rand(max_seq_len, device=device, dtype=dtype)

    flex_attention_fn = partial(flexattention_varlen_bias,
        query,
        key,
        value,
        bias=bias,
        document_id=document_id,
        use_causal_mask=causal
    )

    benchmark_time = do_bench(flex_attention_fn, warmup=25, rep=100)
    print(f"Benchmark time: {benchmark_time:.4f} ms")

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("flexattention_varlen_bias"):
            out = flex_attention_fn()
    prof.export_chrome_trace("flexattention_trace.json")

    print("Output shape:", out.size())
    # print("Document ID tensor:", document_id)
    print("Lengths of each document:", lengths)

The profile:
[screenshot: profiler trace showing gaps]

There are a lot of gaps; it looks as if the flex attention was not properly compiled. Do you have any ideas how to reduce these? I tried a bunch of different compilation modes, but unfortunately to no avail.

drisspg (Contributor) commented Nov 11, 2024

Can you try flex_attention = torch.compile(flex_attention, dynamic=False, fullgraph=True)? That will raise warnings with the reasons why it can't be compiled.

tugot17 (Author) commented Nov 11, 2024

It doesn't raise any warnings, and the trace is still super scattered:

[screenshot: profiler trace]

drisspg (Contributor) commented Nov 13, 2024

[screenshot: profiler trace]

The problem is that create_block_mask is not being compiled:

https://gist.github.com/drisspg/4cc2996497e34ce19d3a445dca2147e4

I think one problem with your setup, though, is that you are creating the block mask on every call to flex attention. The block mask should typically be created once and passed in, so that you can amortize block mask creation over the N transformer blocks in a typical transformer.
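As a rough sketch (the build_document_block_mask helper name here is mine, not from your code, and it assumes a recent enough PyTorch where create_block_mask can be wrapped in torch.compile):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile both entry points once, e.g. at model-init time.
flex_attention = torch.compile(flex_attention, dynamic=False, fullgraph=True)
create_block_mask = torch.compile(create_block_mask, dynamic=False, fullgraph=True)

def build_document_block_mask(document_id, total_tokens, device, use_causal_mask=True):
    # Built once per batch; the resulting BlockMask can then be reused by
    # every flex_attention call that shares this document layout.
    def document_causal_mask(b, h, q_idx, kv_idx):
        # q_idx >= 0 is an always-true tensor when no causal masking is wanted
        causal_mask = q_idx >= kv_idx if use_causal_mask else q_idx >= 0
        document_mask = document_id[q_idx] == document_id[kv_idx]
        return causal_mask & document_mask

    return create_block_mask(
        document_causal_mask, 1, 1, total_tokens, total_tokens, device=device
    )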

tugot17 (Author) commented Nov 13, 2024

*I deleted the above comment because it was wrong.

Regarding your comment: since every call uses a different document_id tensor, it is kind of inevitable that we will have to recreate the block mask, right? I struggle to see an optimization we could do here.

drisspg (Contributor) commented Nov 13, 2024

Correct, you will need to regenerate the block_mask every time the document_id tensor changes. But in most transformers you have more than one call to the attention function, and the document_id is the same for each of those calls. All I meant was that the block_mask should be reused as much as possible.
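For example, something like this (an illustrative sketch of a generic forward pass using the build_document_block_mask helper from the earlier sketch, not your actual model):

def transformer_forward(layers, hidden_states, document_id):
    # Build the block mask once per batch / document layout ...
    total_tokens = hidden_states.size(-2)
    block_mask = build_document_block_mask(document_id, total_tokens, hidden_states.device)

    # ... and reuse it in every layer, so its creation cost is amortized over
    # the N attention calls instead of being paid inside each one.
    for layer in layers:
        hidden_states = layer(hidden_states, block_mask=block_mask)
    return hidden_states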

tugot17 (Author) commented Nov 14, 2024

@drisspg one more question: do you have any idea why

cbm = torch.compile(create_block_mask, dynamic=False, fullgraph=True)
block_mask = cbm(
    document_causal_mask,
    1,
    1,
    total_tokens,
    total_tokens,
    device=device,
)

compiles fine on the experimental PyTorch '2.6.0.dev20241015' but fails to compile on the stable '2.5.1+cu124' with the error below? For both, the PyTorch CUDA version is 12.4 (torch.version.cuda). Running on a 4090.

File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(8192,), dtype=torch.int32), BatchedTensor(lvl=3, bdim=0, value=
    FakeTensor(..., device='cuda:0', size=(8192,), dtype=torch.int64)
)), **{}):
vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

from user code:
   File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 850, in create_block_mask
    partial_block_mask, full_block_mask = inner_func(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 775, in _create_block_mask_inner
    mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 755, in create_mask
    mask = mask_mod(b, h, m, n)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/luminous-inference/src/luminous_inference/kernels/flex_attention_varlen_bias.py", line 33, in document_causal_mask
    document_mask = document_id[q_idx] == document_id[kv_idx]


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] Function                  Runtimes (s)
I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] ----------------------  --------------
I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] _compile.compile_inner         

drisspg (Contributor) commented Nov 15, 2024

Ohh sorry, we recently made updates (not yet landed in 2.5.1) that enable compile for create_block_mask without needing to call the private _compile=True path.
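If you need something on stable 2.5.1 in the meantime, one possible workaround (note that _compile is a private argument and may change between releases) is to skip torch.compile on create_block_mask and pass the private flag directly:

block_mask = create_block_mask(
    document_causal_mask,
    1,
    1,
    total_tokens,
    total_tokens,
    device=device,
    _compile=True,  # private flag; recent nightlies remove the need for this
)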
