FlashAttention2 issue with Mistral/Mixtral related to max length and RotaryEmbedding #31228

Closed
psinger opened this issue Jun 4, 2024 · 3 comments

Comments

psinger commented Jun 4, 2024

I have been seeing very weird behavior when training and running Mistral or Mixtral with samples that are exactly the length of max_position_embeddings. The strange behavior manifested itself in completely broken outputs that, interestingly, resolved themselves after reloading the model and running shorter samples through it.

So the following combination always broke: a model with max_position_embeddings=8192, using FA2, and some samples with size max_length=8192.
It was resolved by either disabling FA2 or actually using samples with max_length=8191.

After a lot of debugging, I figured out that this issue only happens with Flash Attention 2 and not with SDPA or vanilla attention.
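
For reference, a rough sketch of the kind of setup that triggers it for me (the checkpoint path is a placeholder for any Mistral/Mixtral checkpoint whose config has max_position_embeddings=8192, and the random input ids are only there to hit the full-length code path):

import torch
from transformers import AutoModelForCausalLM

# Placeholder: any Mistral/Mixtral checkpoint with max_position_embeddings=8192
model_id = "path/to/mistral-or-mixtral-8192"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # issue only appears with FA2, not with "sdpa" or "eager"
).to("cuda")

# A batch whose length is exactly max_position_embeddings (8192).
input_ids = torch.randint(0, model.config.vocab_size, (1, 8192), device="cuda")
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

# With 8191 tokens instead of 8192, or with attn_implementation="sdpa", the outputs look normal again.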

I suspect that this issue stems from the following lines:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L447

rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

If we have a batch with a sequence length of, say, 8192, which can be the same as max_position_embeddings, then kv_seq_len will be 8192 and wins the max here. We then add 1, which leads to 8193, and call rotary_emb with it.

There, we then call:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L214
and thus re-initialize the cache with a longer-than-supported max sequence length.
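
To spell out the arithmetic with the numbers from above (variable names here are just for illustration):

max_position_embeddings = 8192
kv_seq_len = 8192                # the batch is exactly max_position_embeddings long
last_position_id = 8191          # position ids are 0-indexed, so the last one is 8191

# current code: the +1 is applied after taking the max
rotary_seq_len = max(kv_seq_len, last_position_id) + 1
print(rotary_seq_len)            # 8193 -> rotary_emb re-inits its cos/sin cache past max_position_embeddings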

I think it can already be solved by changing it to:
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
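
With the same numbers, the +1 then only converts the 0-indexed last position into a length, so the result stays within bounds:

kv_seq_len = 8192
last_position_id = 8191

# proposed change: the +1 moves inside the max()
rotary_seq_len = max(kv_seq_len, last_position_id + 1)
print(rotary_seq_len)            # 8192 -> no re-init beyond max_position_embeddings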

I actually noticed that this code was changed very recently for Mistral so that it no longer takes the max length and resets the cache:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L108

This was done in PR #30642

I think this might have just been a side effect, and it does not fix the Mixtral behavior.

amyeroberts (Collaborator) commented

cc @ArthurZucker @younesbelkada

ArthurZucker (Collaborator) commented Jun 6, 2024

I think this is heavily related to #29496, and should simply need an update to the RoPE precision! Would you like to open a PR to update this and see if that solves your issues?

github-actions bot commented Jul 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
