Precision issues in Mistral rotary embeddings #29496
Comments
Do you want to open a PR to propagate the changes we made to Llama and Gemma?
cc @gante
@avnermay I'm not too certain, but I think: [code snippet]
And for Llama: [code snippet]
The downcast only applies to matrix multiplications and explicit downcasts, like what I found they did in Keras. I haven't run the code to confirm, but it would be great if you could print the dtype during a finetuning run to confirm.
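A rough sketch, not from the thread, of the kind of dtype check suggested above: attach a forward hook to one rotary embedding module and run a single training-style step under bf16 autocast. The checkpoint name and attribute paths are assumptions based on the Mistral layout in transformers at the time, and the hook assumes the rotary module returns a (cos, sin) tuple.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint; any Mistral model should do
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda", dtype=torch.bfloat16)

rope = model.model.layers[0].self_attn.rotary_emb  # assumed attribute path

def report_dtypes(module, args, output):
    # print what the rotary embedding actually works with during the step
    print("inv_freq dtype:", module.inv_freq.dtype)
    print("cos dtype:", output[0].dtype, "| sin dtype:", output[1].dtype)

rope.register_forward_hook(report_dtypes)

batch = tok("Hello world", return_tensors="pt").to("cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()
```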
@danielhanchen the dtype of the `inv_freq` buffer after casting the model can be checked with:

```python
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = model.to(device="cuda", dtype=torch.bfloat16)
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)
```

On Llama and Gemma that's no problem, since we've recently updated the code to cast `inv_freq` to float32 before computing sin and cos. We'll gladly take PRs to fix it ;) We will be touching the other RoPE models soon anyway, to migrate them to a Llama-like structure (which, contrary to other models, is compatible with `torch.compile`).
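For reference, a simplified sketch of that pattern (mine, loosely based on the Llama code referenced in the issue body below, so names and details may differ from the actual implementation): the buffer may still end up in bf16 after `model.to(dtype=...)`, but the sin/cos math is always done in float32 and only cast back at the end.

```python
import torch
from torch import nn

class RotaryEmbeddingSketch(nn.Module):
    """Illustrative float32-safe RoPE computation; not the exact HF code."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # non-persistent buffer: still gets downcast by model.to(dtype=...)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, seq_len: int):
        # do the angle math in float32 no matter what dtype the buffer ended up in
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)          # (seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
        # only cast back to the activation dtype at the very end
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

rope = RotaryEmbeddingSketch(dim=128).to(torch.bfloat16)  # inv_freq becomes bf16 here
cos, sin = rope(torch.randn(2, 8192, 128, dtype=torch.bfloat16), seq_len=8192)
print(rope.inv_freq.dtype, cos.dtype)  # bf16 buffer, bf16 outputs, float32 math inside
```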
@gante Whoops, sorry, just saw this, apologies! Fair points! Hmm, is there some sort of lock-in mechanism to stop the conversion from occurring? Maybe some sort of overriding mechanism, i.e. writing over the cast?
Why not use the approach taken by the other models, which force `inv_freq` to be float32? The key is avoiding cases where cos and sin are recomputed using a low-precision `inv_freq`.
@danielhanchen the only solution is to explicitly upcast 😬 Some frameworks like DeepSpeed can hijack tensor creation and force tensors to be initialized in a certain dtype (which has also caused issues with RoPE). @avnermay that is the solution. The change is simple, but we are working on other overlapping problems -- bear with us 🤗
Just commenting on this so that it is not marked as stale. Thanks!
#30642 will fix this! 🤗
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Original issue description
Referenced code:
- transformers/src/transformers/models/mistral/modeling_mistral.py, lines 120 to 121 in 965cf67
- transformers/src/transformers/models/mistral/modeling_mistral.py, line 377 in 965cf67
If during mixed precision training (e.g., bf16 with the HF Trainer) of a Mistral model you pass an input equal to (or greater than) the model's maximum sequence length, it will generate new `sin_cached` and `cos_cached` tensors, which will be incorrect due to precision issues. In particular, the `inv_freq` tensor will be in bf16, which is what causes the problem and leads to large model quality issues. Other models and code bases deal with this by forcing the `inv_freq` tensor to be float32, which I believe is what should be done here as well. It would also be a good idea to double-check other models to make sure this precision problem does not affect them. For example:
- transformers/src/transformers/models/llama/modeling_llama.py, lines 136 to 147 in 965cf67
- https://github.com/Dao-AILab/flash-attention/blob/6c9e60de566800538fedad2ad5e6b7b55ca7f0c5/flash_attn/layers/rotary.py#L383-L392
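To make the failure mode concrete, here is a small standalone sketch (not from the issue) comparing the cos values obtained from a float32 `inv_freq` against ones computed after the same tensor has been rounded to bf16, at a position beyond the cached length; the dimensions and the position are arbitrary.

```python
import torch

dim, base = 128, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
inv_freq_bf16 = inv_freq.to(torch.bfloat16)  # what model.to(dtype=torch.bfloat16) does to the buffer

pos = 32768.0  # a position past the cached length, forcing sin/cos to be recomputed
cos_fp32 = (pos * inv_freq).cos()
cos_bf16 = (pos * inv_freq_bf16.float()).cos()  # the bf16 rounding error is already baked in

# at large positions the angle error reaches tens of radians, so cos is essentially wrong
print((cos_fp32 - cos_bf16).abs().max())  # on the order of 1, not ~1e-3
```

Keeping `inv_freq` (and the positions) in float32 during the recomputation, as the referenced Llama and flash-attention code do, avoids this.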