I'm testing out FlexAttention to use some custom attention masks. The masks I'm working with are causal, except there is usually a relatively small single rectangular area of 0s in the mask, which I added support for with a block mask.
Previously, using xformers (which maps to FlashAttention 2), I managed to train mistral-large at 28k sequence length with just TP on 8xH100. However, with FlexAttention, even 16k sequence length runs OOM. Is this expected?
I am compiling flex_attention after importing it, and also compiling the block mask when instantiating it.
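For reference, here is a minimal sketch of the setup I'm describing. The rectangle bounds (`ZERO_Q_START`, etc.) and sequence length are placeholder values, not my actual config:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Compile flex_attention once after importing it.
flex_attention = torch.compile(flex_attention)

# Hypothetical bounds for the rectangular region of zeros; the real values
# come from my data pipeline.
ZERO_Q_START, ZERO_Q_END = 1024, 2048   # query rows to mask out
ZERO_KV_START, ZERO_KV_END = 256, 512   # key/value columns to mask out

def mask_mod(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    in_rect = (
        (q_idx >= ZERO_Q_START) & (q_idx < ZERO_Q_END)
        & (kv_idx >= ZERO_KV_START) & (kv_idx < ZERO_KV_END)
    )
    # Causal attention everywhere except inside the rectangular block of zeros.
    return causal & ~in_rect

SEQ_LEN = 16 * 1024
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)

# q, k, v: (B, H, SEQ_LEN, head_dim) tensors on GPU
# out = flex_attention(q, k, v, block_mask=block_mask)
```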
@arilato I would not expect that. The extra memory overhead from FlexAttention should be S^2/(BLOCK_SIZE^2). With the default BLOCK_SIZE of 128, at 28k sequence length that works out to around 100 KB or so.
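As a rough back-of-the-envelope check of that estimate (the bytes-per-entry figure is an assumption on my part):

```python
# Block-mask overhead at 28k sequence length with the default block size.
S = 28 * 1024                   # 28672 tokens
BLOCK_SIZE = 128
num_blocks = S // BLOCK_SIZE    # 224 blocks per dimension
entries = num_blocks ** 2       # 224 * 224 = 50,176 block entries
# At a couple of bytes per entry this is on the order of ~100 KB,
# versus S * S ≈ 800M entries for a fully materialized 28k x 28k mask.
print(entries, S * S)
```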
@arilato you also need to ensure that you are compiling the create_block_mask function, since if it is called without compile it will realize the full 28k x 28k mask tensor.
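Something along these lines (reusing the `mask_mod` sketched above):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Wrapping create_block_mask in torch.compile builds the mask block-wise
# instead of first materializing a dense 28k x 28k boolean tensor.
create_block_mask_compiled = torch.compile(create_block_mask)

SEQ_LEN = 28 * 1024
block_mask = create_block_mask_compiled(
    mask_mod,            # same mask_mod as in the sketch above
    B=None, H=None,
    Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN,
)
```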