Flex attention - gaps in profiler #76
can you try
The problem is that create_block_mask is not being compiled
https://gist.github.com/drisspg/4cc2996497e34ce19d3a445dca2147e4

I think one problem with your setup, though, is that you are creating the block mask on every call to flex attention. It should typically be created once and passed in, so that the cost of creating it is amortized over the N transformer blocks in a typical transformer.
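To see why building the block mask on every call adds up, here is a pure-Python illustration (not PyTorch's implementation) of what a block mask summarizes: the mask_mod is evaluated for every (query, key) pair, and each block_size x block_size tile is labeled empty (skipped), partial (mask applied inside the tile) or full (no masking needed). The traceback later in this thread shows the real code producing partial_block_mask and full_block_mask in the same spirit. document_causal_mask_mod and classify_blocks are hypothetical names, and the "same document and causal" rule is an assumption modeled on the document_causal_mask discussed here.

```python
from typing import Callable, List, Sequence

# mask_mod signature used by flex attention: (batch, head, q_idx, kv_idx) -> bool
MaskMod = Callable[[int, int, int, int], bool]


def document_causal_mask_mod(document_id: Sequence[int]) -> MaskMod:
    """Hypothetical pure-Python mask: attend only within the same document, causally."""
    def mask_mod(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        same_doc = document_id[q_idx] == document_id[kv_idx]
        return same_doc and q_idx >= kv_idx
    return mask_mod


def classify_blocks(mask_mod: MaskMod, q_len: int, kv_len: int, block_size: int) -> List[List[str]]:
    """Label each (q_block, kv_block) tile as 'empty', 'partial' or 'full'.

    This costs q_len * kv_len mask_mod evaluations, which is why the result
    is worth computing once and sharing across layers instead of per call.
    """
    grid: List[List[str]] = []
    for q0 in range(0, q_len, block_size):
        row: List[str] = []
        for k0 in range(0, kv_len, block_size):
            allowed = [
                mask_mod(0, 0, q, k)
                for q in range(q0, min(q0 + block_size, q_len))
                for k in range(k0, min(k0 + block_size, kv_len))
            ]
            if all(allowed):
                row.append("full")
            elif any(allowed):
                row.append("partial")
            else:
                row.append("empty")
        grid.append(row)
    return grid


if __name__ == "__main__":
    # Two packed documents of 4 tokens each, split into 2x2 tiles.
    mask = document_causal_mask_mod([0, 0, 0, 0, 1, 1, 1, 1])
    for row in classify_blocks(mask, 8, 8, block_size=2):
        print(row)
```

Real flex attention uses much larger tiles (the default block size is 128) and evaluates mask_mod through vmap, as the traceback below shows, but the work per call still grows with Q_LEN x KV_LEN.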
(I deleted the comment above because it was wrong.) Regarding your comment, since for every call we do different
Correct, you will need to regenerate block_mask every time the document_id tensor changes. But in most transformers there is more than one call to the attention function, and document_id is the same for each of these calls. All I meant was that block_mask should be reused as much as possible.
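A minimal sketch of that advice, assuming a model whose layers each take the block mask as an argument: build the mask once per forward pass from document_id and hand the same object to every layer, so a new mask is only built when a new batch (and therefore a new document_id) arrives. forward, build_block_mask and the toy layer are hypothetical stand-ins; in real code build_block_mask would wrap create_block_mask and each layer would pass block_mask=... to flex_attention.

```python
from typing import Callable, List, Sequence, TypeVar

X = TypeVar("X")  # activations
M = TypeVar("M")  # block mask object


def forward(
    x: X,
    document_id: Sequence[int],
    layers: List[Callable[[X, M], X]],
    build_block_mask: Callable[[Sequence[int]], M],
) -> X:
    # document_id is identical for every layer in this forward pass,
    # so the (expensive) mask is built exactly once here...
    block_mask = build_block_mask(document_id)
    for layer in layers:
        # ...and reused by each of the N attention layers.
        x = layer(x, block_mask)
    return x


if __name__ == "__main__":
    builds = []

    def build(doc_ids):
        builds.append(tuple(doc_ids))
        return tuple(doc_ids)  # placeholder for a real BlockMask

    def toy_layer(x, mask):
        return x + 1

    layers = [toy_layer] * 32
    out = forward(0, [0, 0, 1, 1], layers, build)
    out = forward(out, [0, 1, 1, 1], layers, build)  # new batch -> new mask
    print(out, len(builds))  # 64 2
```

Two forward passes through 32 layers build the mask twice, not 64 times.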
@drisspg One more question: do you have any idea why

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# document_causal_mask, total_tokens and device are defined elsewhere in my code
cbm = torch.compile(create_block_mask, dynamic=False, fullgraph=True)
block_mask = cbm(
    document_causal_mask,
    1,
    1,
    total_tokens,
    total_tokens,
    device=device,
)
```

compiles fine on the PyTorch nightly '2.6.0.dev20241015' but fails to compile on the stable release?

```
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(8192,), dtype=torch.int32), BatchedTensor(lvl=3, bdim=0, value=
    FakeTensor(..., device='cuda:0', size=(8192,), dtype=torch.int64)
)), **{}):
vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

from user code:
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 850, in create_block_mask
    partial_block_mask, full_block_mask = inner_func(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 775, in _create_block_mask_inner
    mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 755, in create_mask
    mask = mask_mod(b, h, m, n)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/piotr_mazurek/miniconda3/envs/luminous-inference-test/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/piotr_mazurek/luminous-inference/src/luminous_inference/kernels/flex_attention_varlen_bias.py", line 33, in document_causal_mask
    document_mask = document_id[q_idx] == document_id[kv_idx]

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] Function Runtimes (s)
I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] ---------------------- --------------
I1114 15:55:28.448000 2178885 site-packages/torch/_dynamo/utils.py:399] _compile.compile_inner
```
Oh, sorry: we recently made updates (not included in 2.5.1) that enable compiling create_block_mask without needing to pass the private _compile=True argument.
Repost from the PyTorch forum
I have recently been playing with Flex attention, trying to replace some of my custom Triton kernels. Unfortunately, the Flex attention kernels were substantially slower (around 5x) than the custom one I tried. I use something similar to Tri Dao's Flash Attention kernel, but with bias. I started looking at the torch profile for a single forward pass of the compiled version after model warmup, and it looks very strange.
The profile: [profiler trace screenshot not preserved]
There are a lot of gaps; it looks as if flex attention was not properly compiled. Do you have any ideas how to reduce them? I tried several different compilation modes, unfortunately to no avail.
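To put numbers on the gaps instead of eyeballing the timeline, one option is to export the trace with torch.profiler (prof.export_chrome_trace(path)) and measure the idle time between consecutive GPU kernels. The sketch below assumes the Chrome-trace layout PyTorch's profiler typically writes: a traceEvents list in which GPU kernels are complete events with ph == "X", cat == "kernel", and ts/dur in microseconds. Check these field names against your own trace. Large gaps between attention kernels point at CPU-side work between launches, which in this thread turned out to be create_block_mask running uncompiled. load_trace, kernel_intervals and gpu_idle_gaps are hypothetical helper names.

```python
import json
from typing import Any, Dict, List, Tuple

Interval = Tuple[float, float]


def load_trace(path: str) -> Dict[str, Any]:
    with open(path) as f:
        return json.load(f)


def kernel_intervals(trace: Dict[str, Any]) -> List[Interval]:
    """(start, end) in microseconds for every GPU kernel event, sorted by start."""
    spans = []
    for event in trace.get("traceEvents", []):
        if event.get("ph") == "X" and event.get("cat") == "kernel":
            start = float(event["ts"])
            spans.append((start, start + float(event.get("dur", 0))))
    return sorted(spans)


def gpu_idle_gaps(trace: Dict[str, Any], min_gap_us: float = 1.0) -> List[Interval]:
    """Periods of at least min_gap_us where no kernel ran on any stream."""
    gaps: List[Interval] = []
    busy_until = None
    for start, end in kernel_intervals(trace):
        if busy_until is not None and start - busy_until >= min_gap_us:
            gaps.append((busy_until, start))
        # max() handles kernels that overlap on different streams
        busy_until = end if busy_until is None else max(busy_until, end)
    return gaps


if __name__ == "__main__":
    import sys

    gaps = gpu_idle_gaps(load_trace(sys.argv[1]), min_gap_us=5.0)
    print(f"{len(gaps)} gaps, {sum(b - a for a, b in gaps):.1f} us idle")
```

Comparing the total idle time before and after compiling create_block_mask (or hoisting it out of the per-layer call) is a quick way to confirm whether the gaps came from mask creation.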