Adding support for alibi when using flash attention #820
Conversation
Pulling the latest commits from main fork
Pulling from the main repo
Pulling from mosaicml/llm-foundry main
Merging from mosaic main
Pulling from mosaic main
Pulling from mosaic main.
FYI, I added a BERT FA2 PR to the examples repo: mosaicml/examples#440
Thanks for doing this! Added some questions and comments.
This looks good to me! Could we potentially include an example yaml in scripts/train/pretrain with flash attention 2 as well?
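As a rough illustration of what such a yaml's model section might look like (a sketch only, patterned after the existing mpt-125m.yaml pretrain config; the file name and the exact `attn_config` keys below are assumptions, not the yaml actually added by this PR):

```yaml
# Hypothetical scripts/train/yamls/pretrain/mpt-125m-flash-alibi.yaml (sketch)
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 768
  n_heads: 12
  n_layers: 12
  expansion_ratio: 4
  max_seq_len: ${max_seq_len}
  vocab_size: 50368
  attn_config:
    attn_impl: flash   # use FlashAttention-2 kernels instead of triton
    alibi: true        # pass ALiBi slopes through to flash attention
```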
if you install llm-foundry using
lgtm, will leave approval for @irenedea
Co-authored-by: Irene Dea <[email protected]>
ALiBi support was recently added to FlashAttention-2 (PR #540) and is stable as of flash-attention v2.4.2. This PR allows ALiBi slopes to be passed to flash attention, enabling ALiBi with Flash Attention.

Experiments on 125M and 1B models show (nearly) identical training curves for ALiBi with Triton attention (currently the default attention implementation when using ALiBi) and ALiBi with Flash Attention. Furthermore, using Flash Attention results in higher MFU. In the plots below, 'treat' refers to flash attention and 'control' refers to triton attention.
Wandb experiment link
Note: Flash Attention seems to use more memory than Triton.
These changes would effectively also enable the use of sliding window attention and memory-efficient sequence ID masking with ALiBi when using Flash Attention.
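At the flash-attn API level, this amounts to forwarding per-head slopes via the `alibi_slopes` argument that FlashAttention-2 added around v2.4. A minimal sketch, not the llm-foundry code itself: the slope helper below is written purely for illustration (llm-foundry generates slopes its own way), and `window_size` is shown only to indicate the optional sliding-window knob.

```python
import torch
from flash_attn import flash_attn_func  # requires flash-attn >= 2.4.2


def gen_alibi_slopes(n_heads: int, device: str = 'cuda') -> torch.Tensor:
    # Standard ALiBi schedule for a power-of-two head count: slope_i = 2^(-8 * i / n_heads).
    # (Illustrative helper only.)
    exponents = torch.arange(1, n_heads + 1, device=device, dtype=torch.float32)
    return 2.0 ** (-8.0 * exponents / n_heads)


batch, seqlen, n_heads, head_dim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, n_heads, head_dim, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(
    q, k, v,
    causal=True,
    alibi_slopes=gen_alibi_slopes(n_heads),  # fp32 tensor of shape (n_heads,)
    window_size=(256, 0),                    # optional sliding window (left, right); (-1, -1) = full attention
)
```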
As an aside, @Skylion007 has been able to successfully update his own version of MosaicBERT with FA2, and has found it to indeed be faster than FA1: w&b link