Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for alibi when using flash attention #820

Merged

Conversation

ShashankMosaicML
Copy link
Contributor

@ShashankMosaicML ShashankMosaicML commented Dec 23, 2023

ALiBi support has recently been added to FlashAttention2 (PR#540) and is stable with flash-attention=v2.4.2. This PR allows for ALiBi slopes to be passed to flash attention, hence enabling ALiBi for Flash Attention.

Experiments on 125m and 1b models show (nearly) identical training curves for ALiBi with Triton attention (which is the current default attention implementation when using ALiBi) and ALiBi with Flash Attention. Further, using Flash Attention results in higher MFU numbers. In the plots below, 'treat' refers to flash attention, and 'control' refers to triton attention.

Wandb experiment link
Screenshot 2024-01-01 at 2 04 21 PM

Screenshot 2024-01-01 at 2 04 57 PM

Note: Flash Attention seems to use more memory than Triton:

Screenshot 2024-01-01 at 2 03 04 PM

These changes would effectively also enable the use of sliding window attention and memory efficient sequence id masking with ALiBi when using Flash Attention.


As an aside, @Skylion007 has been able to successfully update his own version of MosaicBERT with FA2, and has found it to indeed be faster than FA1: w&b link

@Skylion007
Copy link
Contributor

FYI, I added a BERT FA2 PR to the examples repo: mosaicml/examples#440

@ShashankMosaicML ShashankMosaicML changed the title Adding support for alibi in flash attention Adding support for alibi when using flash attention Jan 2, 2024
Copy link
Contributor

@irenedea irenedea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this! Added some questions and comments.

tests/models/layers/test_flash_attn.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Show resolved Hide resolved
llmfoundry/models/layers/attention.py Outdated Show resolved Hide resolved
tests/models/layers/test_flash_attn.py Outdated Show resolved Hide resolved
tests/models/test_model.py Outdated Show resolved Hide resolved
Copy link
Contributor

@jacobfulano jacobfulano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me! Could we potentially include an example yaml in scripts/train/pretrain with flash attention 2 as well?

@vchiley
Copy link
Contributor

vchiley commented Jan 4, 2024

This looks good to me! Could we potentially include an example yaml in scripts/train/pretrain with flash attention 2 as well?

if you install llm-foundry using pip install -e . [gpu-flash2] (instead of pip install -e . [gpu]) and set `attn_impl: flash, MPT will uses FA2.

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, will leave approval for @irenedea

@ShashankMosaicML ShashankMosaicML merged commit d991f37 into mosaicml:main Jan 5, 2024
10 checks passed
@ShashankMosaicML ShashankMosaicML deleted the shashank/alibi_flash_attn branch January 5, 2024 21:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants