
What is the expected gpu memory performance drop wrt flash attention with block masks? #54

Open
arilato opened this issue Oct 19, 2024 · 2 comments

Comments

arilato commented Oct 19, 2024

I'm testing out flex attention to use some custom attention masks. The masks I'm working with are causal, except there is usually a relatively small rectangular region of zeros in the mask, which I added support for with a block mask.

Previously, using xformers (which maps to FlashAttention 2), I managed to train mistral-large at 28k sequence length with just tensor parallelism on 8xH100. However, with flex attention, even 16k sequence length runs OOM. Is this expected?

I am compiling flex_attention after importing it, and also compiling the block mask when instantiating it.
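For reference, a minimal sketch of this kind of setup (causal mask with a single rectangular zero region); the zero-region boundaries, shapes, and dtype below are placeholders, not the actual values:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Compile the flex_attention kernel itself.
flex_attention = torch.compile(flex_attention)

S = 16 * 1024  # placeholder sequence length

# Placeholder coordinates for the rectangular region of zeros in the mask.
ZERO_Q_START, ZERO_Q_END = 4096, 5120
ZERO_KV_START, ZERO_KV_END = 1024, 2048

def causal_with_hole(b, h, q_idx, kv_idx):
    # Standard causal mask ...
    causal = q_idx >= kv_idx
    # ... minus a single rectangular region that is masked out.
    in_hole = (
        (q_idx >= ZERO_Q_START) & (q_idx < ZERO_Q_END)
        & (kv_idx >= ZERO_KV_START) & (kv_idx < ZERO_KV_END)
    )
    return causal & ~in_hole

block_mask = create_block_mask(
    causal_with_hole, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda"
)

q = k = v = torch.randn(1, 8, S, 128, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)
```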

Chillee (Contributor) commented Oct 21, 2024

@arilato I would not expect that. The extra memory overhead from FlexAttention should be on the order of S^2 / BLOCK_SIZE^2 entries. With the default BLOCK_SIZE of 128, at 28k sequence length the extra memory overhead should be around 100 KB or so.
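Rough arithmetic behind that estimate (the exact bytes per block entry depend on the block-mask metadata layout, so treat this as an order-of-magnitude check):

```python
S = 28 * 1024           # ~28k sequence length
BLOCK_SIZE = 128        # FlexAttention default
num_block_entries = (S // BLOCK_SIZE) ** 2   # S^2 / BLOCK_SIZE^2 = 50_176
approx_kb = num_block_entries * 4 / 1024     # a few bytes of metadata per block
print(approx_kb)  # ~196 KB at 4 bytes/entry, ~50 KB at 1 byte/entry -- the ~100 KB ballpark
```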

drisspg (Contributor) commented Oct 21, 2024

@arilato you also need to ensure that you are compiling the create_block_mask function, since if it is called without compile it will realize the full 28k x 28k mask tensor.
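A minimal sketch of that suggestion (here `mask_mod` and `S` stand in for your mask function and sequence length):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compiling create_block_mask keeps the mask evaluation blockwise, so the dense
# S x S boolean mask (28k x 28k ~ 800M entries) is never materialized.
create_block_mask = torch.compile(create_block_mask)

block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
```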
