
_prepare_4d_attention_mask_for_sdpa is not for causal attention but claims... #30095

Closed
minostauros opened this issue Apr 7, 2024 · 5 comments · Fixed by #30138

Comments


minostauros commented Apr 7, 2024

... SDPA causal mask generation may be wrong for the mask generation.

if torch.all(mask == 1):
    if is_tracing:
        pass
    elif tgt_len == 1:
        # For query_length == 1, causal attention and bi-directional attention are the same.
        return None
    elif key_value_length == tgt_len:
        return None
    else:
        # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore
        # the attention mask, as SDPA causal mask generation may be wrong. We will set is_causal=False in SDPA
        # and rely on Transformers attention_mask instead, hence not setting it to None here.
        # Reference: https://github.com/pytorch/pytorch/issues/108108
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)

Will it be safe to just return None for the else: case?

For causal attention, we can just use _prepare_4d_causal_attention_mask_for_sdpa instead.
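
As a minimal sanity check (illustrative only, not library code), for bi-directional attention an all-ones padding mask expands to an all-zeros additive mask, so SDPA with attn_mask=None gives the same result, even when query_length != key_value_length:

import torch
import torch.nn.functional as F

batch, heads, tgt_len, kv_len, dim = 2, 4, 3, 7, 16  # query_length != key_value_length
q = torch.randn(batch, heads, tgt_len, dim)
k = torch.randn(batch, heads, kv_len, dim)
v = torch.randn(batch, heads, kv_len, dim)

# An all-ones 2D padding mask expands to an all-zeros additive 4D mask, i.e. nothing is masked.
additive_mask = torch.zeros(batch, 1, tgt_len, kv_len)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask, is_causal=False)
out_unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
torch.testing.assert_close(out_masked, out_unmasked)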

Related issues:
pytorch/pytorch#108108
Dao-AILab/flash-attention@9e5e8bc
#28802

amyeroberts (Collaborator)

cc @fxmarty

fxmarty (Contributor) commented Apr 9, 2024

Hi @minostauros, thank you for the report.

... SDPA causal mask generation may be wrong for the mask generation.

_prepare_4d_attention_mask_for_sdpa does not handle causal masks. However,

Will it be safe to just return None for the else: case?

Yes, good catch, I'll fix that! This is a somewhat unlikely case though, where one would use past key values with what are typically encoder-type models. How did you run into this case?
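
For reference, here is a minimal sketch (assumed signature and simplified tracing check; not the actual patch merged in #30138) of how the helper could look once the mask is dropped, given that this function only serves bi-directional attention:

import torch

def _prepare_4d_attention_mask_for_sdpa_sketch(mask, dtype, tgt_len=None):
    # mask: [batch_size, key_value_length] padding mask with 1 = attend, 0 = masked.
    tgt_len = tgt_len if tgt_len is not None else mask.shape[-1]

    # Hypothetical simplification: outside of tracing, an all-ones mask masks nothing,
    # and attention here is bi-directional, so SDPA can simply run with attn_mask=None.
    if not torch.jit.is_tracing() and torch.all(mask == 1):
        return None

    # Otherwise expand the 2D padding mask to the additive 4D form SDPA expects.
    expanded = mask[:, None, None, :].expand(mask.shape[0], 1, tgt_len, mask.shape[-1]).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min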

minostauros (Author)

This is a somewhat unlikely case though, where one would use past key values with what are typically encoder-type models. How did you run into this case?

I didn't run into that specific case; I was reviewing #28802 and trying to add flash-attention-2 to BERT (the BLIP-2 variant of BERT, to be exact).
Thanks for the confirmation!


github-actions bot commented May 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

minostauros (Author)

This issue will be closed by #30138

huggingface deleted a comment from github-actions bot on Jun 3, 2024