I have been seeing very weird behavior when training and running Mistral or Mixtral with samples that are exactly the length of `max_position_embeddings`. The strange behavior manifested itself in completely broken outputs which, interestingly, resolved themselves after reloading the model and running shorter samples through it.

The following combination always broke: a model with `max_position_embeddings=8192`, FA2 enabled, and some samples of size `max_length=8192`.

It was resolved either by disabling FA2 or by using samples with `max_length=8191`.
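For completeness, here is roughly how I trigger it; a minimal sketch, assuming flash-attn is installed (the checkpoint id and prompt are placeholders):

```python
# Minimal reproduction sketch (placeholders: checkpoint id, prompt).
# Pads a sample to exactly max_position_embeddings (8192) and runs it
# through a model loaded with Flash Attention 2.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # disabling FA2 avoids the issue
    device_map="auto",
)

max_len = model.config.max_position_embeddings  # 8192
inputs = tokenizer(
    "some long prompt ...",                     # placeholder text
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=max_len,                         # broken; max_len - 1 is fine
).to(model.device)

with torch.no_grad():
    outputs = model(**inputs)  # outputs are broken at exactly max_len with FA2
```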
After a lot of debugging, I figured out that this issue only happens with Flash Attention 2 and not with SDPA or vanilla attention.

I suspect that the issue stems from the following line:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L447

If we have a batch with a sequence length of, say, 8192, which can be the same as `max_position_embeddings`, then `kv_seq_len` will be 8192, which is already the maximum here; but we then add 1, which leads to 8193, and `rotary_emb` is called with that value. There, we then call:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L214

and thus re-init the cache with a max sequence length longer than supported.

I think it can already be solved by changing it to:

rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)

I also noticed that this code was changed very recently for Mistral so that it no longer takes the max length and resets the cache:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L108

This was done in PR #30642, but I think that was just a side effect and it does not fix the Mixtral behavior.
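To make the off-by-one concrete, here is a small, self-contained sketch of the pattern described above (simplified, not the actual modeling code; the helper names are mine):

```python
# Simplified illustration of the suspected off-by-one: with a batch at
# exactly max_position_embeddings, adding 1 after taking the max pushes
# the rotary sequence length past the supported maximum and forces the
# cos/sin cache to be rebuilt for an unsupported length.

def current_rotary_seq_len(kv_seq_len: int, last_position_id: int) -> int:
    # pattern described above: the +1 is applied after the max
    return max(kv_seq_len, last_position_id) + 1

def proposed_rotary_seq_len(kv_seq_len: int, last_position_id: int) -> int:
    # proposed change: the +1 only applies to the position id
    return max(kv_seq_len, last_position_id + 1)

max_position_embeddings = 8192
kv_seq_len = 8192                  # batch exactly at the maximum length
last_position_id = kv_seq_len - 1  # positions 0 .. 8191

print(current_rotary_seq_len(kv_seq_len, last_position_id))   # 8193 -> cache reset beyond the max
print(proposed_rotary_seq_len(kv_seq_len, last_position_id))  # 8192 -> stays within the max
```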
I think this is heavily related to #29496 and should simply need an update in the RoPE precision! Would you like to open a PR to update this and see if that solves your issue?
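Concretely, a minimal sketch of the kind of change I mean would be to do the rotary angle math in float32 and only cast the cos/sin tables back to the model dtype at the end (the helper below is just an illustration, not the actual patch):

```python
import torch

def build_rope_cache(seq_len, dim, base=10000.0, device=None, dtype=torch.bfloat16):
    # Illustrative helper (not the real modeling code): compute inverse
    # frequencies and angles in float32 so large position indices do not
    # lose precision, then cast the resulting tables back down.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)   # (seq_len, dim // 2)
    emb = torch.cat((angles, angles), dim=-1)   # (seq_len, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)
```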
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.