Support varied input sequence lengths with a fixed block mask #31

Open
tilmto opened this issue Aug 27, 2024 · 5 comments
Labels: question (Further information is requested)

Comments

tilmto commented Aug 27, 2024

Thanks for the great repo!

When using a custom attention mask pattern (e.g., the A-shape mask in this work), I noticed that when the input length (e.g., 512) is shorter than the length (e.g., 1024) for which the block mask was predefined via mask_mod, the generation results may be incorrect, even though the former attention pattern is just a truncated version of the latter.

Therefore, I wonder whether FlexAttention generally supports varying input sequence lengths under a fixed block mask, and if so, how it handles this situation.
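
For concreteness, here is a minimal sketch of the kind of setup being described. The A-shape-style mask_mod, the SINK/WINDOW values, and the shapes below are illustrative placeholders, not taken from the linked work:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative A-shape-style pattern: causal, with a few global "sink" tokens
# plus a local sliding window (SINK/WINDOW values are made up for this sketch).
SINK, WINDOW = 4, 256

def a_shape_mask(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    sink = kv_idx < SINK
    local = (q_idx - kv_idx) < WINDOW
    return causal & (sink | local)

B, H, D = 1, 8, 64
MASK_LEN, SEQ_LEN = 1024, 512  # mask predefined for 1024, input only 512 long

# Block mask built once for the longer, fixed length.
block_mask = create_block_mask(a_shape_mask, B, H, MASK_LEN, MASK_LEN, device="cuda")

q = torch.randn(B, H, SEQ_LEN, D, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Reusing the 1024-length block mask with 512-length inputs is the mismatch
# described above, where the outputs may no longer be correct.
out = flex_attention(q, k, v, block_mask=block_mask)
```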

drisspg (Contributor) commented Aug 27, 2024

In general, no; the current blessed solution is to call create_block_mask with the new shapes. Slicing the inner tensors yourself is possible today; the description of that structure can be found here: https://github.com/pytorch/pytorch/blob/44dadf25065c73bd1370258e7fb1b421cee4283a/torch/nn/attention/flex_attention.py#L192
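
For reference, a small sketch of what that structure exposes, continuing from the example above (attribute names per the linked docstring; the shapes in the comments are approximate):

```python
# A BlockMask stores, per query-block row, how many KV blocks are active and
# which ones they are (see the linked docstring for the full description).
print(block_mask.BLOCK_SIZE)           # block sizes, e.g. (128, 128)
print(block_mask.kv_num_blocks.shape)  # roughly [B, H, Q_LEN // Q_BLOCK]
print(block_mask.kv_indices.shape)     # roughly [B, H, Q_LEN // Q_BLOCK, KV_LEN // KV_BLOCK]

# Slicing these tensors down to a shorter length is possible in principle,
# but the recommended ("blessed") path is simply to rebuild the mask:
block_mask_512 = create_block_mask(a_shape_mask, B, H, 512, 512, device="cuda")
```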

drisspg added the question label Aug 27, 2024
tilmto (Author) commented Aug 27, 2024

Thanks for the prompt response! So, is my understanding correct: if we need to run evaluations on common LM benchmarks, which often contain questions of varying lengths, we have to create the block mask on the fly for each input (ideally with _compile=True to speed up this process)?

drisspg (Contributor) commented Aug 28, 2024

Yup, that's the best approach. With _compile=True the cost should be relatively low compared to the actual compute, and it gets amortized over all the attention calls throughout the model.
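
A rough sketch of that pattern (the helper names and the simple causal mask_mod here are placeholders for whatever the model actually uses):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # Stand-in for the model's real mask_mod.
    return q_idx >= kv_idx

def build_block_mask(seq_len, device="cuda"):
    # Rebuilt once per input length; _compile=True keeps this step cheap.
    return create_block_mask(causal_mask, B=None, H=None,
                             Q_LEN=seq_len, KV_LEN=seq_len,
                             device=device, _compile=True)

def attention_stack(num_layers, q, k, v):
    # One mask per input, reused by every attention call in the model,
    # so the creation cost is amortized across all layers.
    block_mask = build_block_mask(q.shape[-2])
    out = q
    for _ in range(num_layers):
        out = flex_attention(out, k, v, block_mask=block_mask)
    return out
```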

tilmto (Author) commented Aug 28, 2024

Got it! One last question: I find that setting _compile=True sometimes leads to errors about insufficient cache sizes. This often happens with models that have many full-attention layers, but when I replace them with sliding-window attention, everything works fine. Are there any workarounds for this?

drisspg (Contributor) commented Aug 28, 2024

Hmm, this is likely a dynamic-shapes thing. @Chillee
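
A possible workaround sketch, assuming the errors come from Dynamo recompiling once per distinct sequence length and hitting its recompile cache limit (this is a guess, not something confirmed in the thread):

```python
import torch

# Assumed workaround: raise Dynamo's recompile cache limit so that many
# distinct sequence lengths don't exhaust it while the dynamic-shapes
# behavior is being investigated.
torch._dynamo.config.cache_size_limit = 64  # raise from the (small) default
```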
