Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to flash-attn2 #149

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

Ding3LI
Copy link

@Ding3LI Ding3LI commented Jan 24, 2024

Update Comments

Modified scgpt/model/model.py & scgpt/model/multiomic_model.py for flash-attn2. Updated README.md with flash-attn2 installation and notes.

Updated Dependencies

flash-attn 1.x --> flash-att 2.x

Is Code Tested?

Yes. The updated code was tested on A100 GPU (Linux), and it works fine doing fine-tuning integration with the latest Flash Attention 2.

Note:
In the current version of code, scGPT uses FlashMHA. And the updated flash-attn2 uses a module named mpa.py. The details of newer MHA implementation can be found HERE starting at line 354 flash_attn/modules/mha.py

…ash-attn2. Updated README.md with flash-attn2 installation and notes
@Ding3LI Ding3LI closed this Feb 2, 2024
@Ding3LI Ding3LI reopened this Feb 2, 2024
@NozomiMizore
Copy link

In ur latest changes, u have set the use_flash_attn=False in model.py. I have also modified the code to it. It seems like the model actually didn't call the flash-attn but call the standard mha. And the metric result of my "tutorial_annotation" is terrible. But if I don't set use_flash_attn to False as you did, the program would get some error about padding_mask. Do u have the same problem?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants