
Precision issues in Mistral rotary embeddings #29496

Closed
avnermay opened this issue Mar 6, 2024 · 10 comments

@avnermay

avnermay commented Mar 6, 2024

```python
if seq_len > self.max_seq_len_cached:
    self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
```

```python
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
```

If you pass an input whose length is equal to (or greater than) the model's maximum sequence length during mixed precision training of a Mistral model (e.g., bf16 with the HF Trainer), new sin_cached and cos_cached tensors are generated, and they are incorrect due to precision issues. Equality is already enough to trigger this, because rotary_seq_len adds 1 to the larger of kv_seq_len and the last position id, so it exceeds max_seq_len_cached. In particular, the inv_freq tensor is in bf16, and that is what causes the problem. The result is a large drop in model quality.
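To give a sense of scale, here is a minimal pure-Python sketch (not code from transformers) that emulates bfloat16 rounding of inv_freq with the struct module and measures the resulting error in the rotation angle position * inv_freq. The head dimension of 128 and base of 10000 are illustrative assumptions, and only inv_freq is rounded here (the product is kept in double precision), so this understates the error if the positions or the product are also held in bf16.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even).

    bfloat16 is the top 16 bits of an IEEE-754 float32: 1 sign bit,
    8 exponent bits and 7 explicit mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round on the 16 low bits that bfloat16 drops, breaking ties to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rope_inv_freq(dim: int, base: float = 10000.0) -> list[float]:
    # Standard RoPE frequencies: 1 / base^(k / dim) for k = 0, 2, ..., dim - 2.
    return [1.0 / (base ** (k / dim)) for k in range(0, dim, 2)]


def max_angle_error(position: int, dim: int = 128) -> float:
    """Largest error (in radians) of position * inv_freq when inv_freq is stored in bf16."""
    return max(abs(position * f - position * to_bf16(f)) for f in rope_inv_freq(dim))


for pos in (16, 4096, 32768):
    print(pos, max_angle_error(pos))
```

For short inputs (position 16) the error is at most about 0.03 rad, but at position 4096 it is already several radians, which effectively scrambles the rotation for late positions.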

Other models and code bases deal with this by forcing the inv_freq tensor to be float32, which I believe is what should be done here as well. It would also be a good idea to double-check other models to make sure this precision problem does not occur elsewhere.

```python
import torch


# Excerpt from the Llama/Gemma rotary embedding forward (other arguments omitted)
def forward(self, x, position_ids):
    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # Force float32 since bfloat16 loses precision on long contexts
    # See https://github.com/huggingface/transformers/pull/29285
    device_type = x.device.type
    device_type = device_type if isinstance(device_type, str) else "cpu"
    # Disable autocast so the matmul below is not downcast to bf16
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
    # Only the final cos/sin values are cast back to the activation dtype
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```

https://github.com/Dao-AILab/flash-attention/blob/6c9e60de566800538fedad2ad5e6b7b55ca7f0c5/flash_attn/layers/rotary.py#L383-L392
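A minimal sketch of how the float32 treatment could look for a Mistral-style cache. The class name Float32RotaryCache is hypothetical, and the structure only loosely follows the _set_cos_sin_cache / forward pattern quoted at the top of this issue. Two details matter: positions are generated in float32 rather than in the dtype of inv_freq, and inv_freq is recomputed in float32 whenever the buffer has been downcast, because calling .float() on a bf16 buffer restores the dtype but not the bits that were rounded away.

```python
import torch
from torch import nn


class Float32RotaryCache(nn.Module):
    """Hypothetical Mistral-style rotary cache that always does its math in float32."""

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)
        self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device, torch.get_default_dtype())

    def _compute_inv_freq(self, device=None) -> torch.Tensor:
        # Same formula as the Llama/Gemma snippets, computed in float32.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim
        return 1.0 / (self.base ** exponents)

    def _set_cos_sin_cache(self, seq_len: int, device, dtype) -> None:
        self.max_seq_len_cached = seq_len
        if self.inv_freq.dtype == torch.float32:
            inv_freq = self.inv_freq.to(device)
        else:
            # The buffer was downcast (e.g. by model.to(torch.bfloat16)). Upcasting it with
            # .float() would keep the bf16 rounding error, so recompute it from scratch.
            inv_freq = self._compute_inv_freq(device=device)
        # Positions stay in float32: bf16 cannot represent every integer above 256.
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Only the final cos/sin values are cast to the requested dtype.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        # Same rebuild trigger as the Mistral code quoted at the top of this issue.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
```

For example, after m = Float32RotaryCache(dim=128, max_position_embeddings=16).to(torch.bfloat16), calling m(x, seq_len=4096) with a bf16 x rebuilds the cache, and the returned cos differs from a float64 reference only by the bf16 rounding of the final values.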

@ArthurZucker
Collaborator

Do you want to open a PR to propagate the changes we made to Llama and Gemma?

@ArthurZucker
Collaborator

cc @gante

@danielhanchen
Contributor

@avnermay I'm not too certain, but I think inv_freq will always be calculated in float32. For example, Gemma:

```python
self.inv_freq = 1.0 / (
    self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
```

And for Llama:

```python
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
```

The downcast only applies to matrix multiplications and to explicit downcasts, like the one I found in Keras.

I haven't run the code to confirm, but it would be great if you could print the dtype during a fine-tuning run to check whether inv_freq is actually bfloat16.
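A quick way to check the autocast part of this on CPU, as a minimal sketch assuming a PyTorch build with CPU bf16 autocast (present in recent releases): autocast runs the matmul in bf16 but leaves an elementwise multiply and the input tensors' own dtypes alone, and locally disabling autocast, as the Llama/Gemma snippet does, keeps the matmul in float32.

```python
import torch

a = torch.randn(4, 4)  # float32 inputs, e.g. inv_freq and positions
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    matmul_out = a @ b  # matmul is on autocast's lower-precision list
    mul_out = a * b     # elementwise multiply is not autocast
    # Disabling autocast locally keeps the matmul in float32.
    with torch.autocast(device_type="cpu", enabled=False):
        fp32_matmul_out = a @ b

print(matmul_out.dtype, mul_out.dtype, fp32_matmul_out.dtype, a.dtype)
# torch.bfloat16 torch.float32 torch.float32 torch.float32
```

Autocast never changes a tensor's stored dtype, so this alone cannot turn a float32 inv_freq buffer into bf16; an explicit cast of the buffer would be needed for that.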

@gante
Member

gante commented Mar 11, 2024

@danielhanchen the inv_freq buffer is cast along with the rest of the model by .to(), e.g.:

```python
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
# .to(dtype=...) also casts floating-point buffers such as inv_freq
model = model.to(device="cuda", dtype=torch.bfloat16)
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)
```

On Llama and Gemma that's no problem, since we recently updated the code to cast inv_freq to float() before it is used to compute sin and cos (e.g. here). However, other RoPE models like Mistral have yet to receive the same treatment.
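The buffer cast can be reproduced without downloading a checkpoint. In this minimal sketch, TinyRotary is a hypothetical stand-in with an inv_freq buffer built like the Llama one; a model-level .to(dtype=torch.bfloat16) casts that buffer, and a later .float() brings back the float32 dtype but not the original float32 values.

```python
import torch
from torch import nn


class TinyRotary(nn.Module):
    """Hypothetical stand-in for a rotary embedding module with an inv_freq buffer."""

    def __init__(self, dim: int = 128, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


reference = TinyRotary().inv_freq.clone()       # original float32 values
module = TinyRotary().to(dtype=torch.bfloat16)  # model-level cast also converts buffers

print(module.inv_freq.dtype)                    # torch.bfloat16
upcast = module.inv_freq.float()                # dtype is restored...
print(upcast.dtype, torch.equal(upcast, reference))  # torch.float32 False (...values are not)
```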

We'll gladly take PRs to fix it ;) We will be touching the other RoPE models soon anyway, to migrate them to a Llama-like structure (which, unlike the other models, is compatible with torch.compile).

@danielhanchen
Contributor

@gante Whoops, sorry, just saw this. Apologies!

Oh, fair point! Hmm, is there some sort of lock-in mechanism to prevent the conversion? Maybe some sort of override mechanism, i.e. overwriting tensor.to itself?

@avnermay
Author

Why not use the approach taken by the other models, which force inv_freq to be float32? The key is avoiding cases where cos and sin are recomputed from a low-precision inv_freq tensor. This happens, for example, during mixed precision training, because in that case inv_freq has been automatically downcast to bfloat16.

@gante
Member

gante commented Mar 19, 2024

@danielhanchen the only solution is to explicitly upcast 😬 Some frameworks, like DeepSpeed, can hijack tensor creation and force tensors to be initialized in a certain dtype (which has also caused issues with RoPE).

@avnermay that is the solution. The change is simple, but we are working on other overlapping problems, so bear with us 🤗

@avnermay
Author

Just commenting on this so that it is not marked as stale. Thanks!

@ArthurZucker
Collaborator

#30642 will fix this! 🤗


github-actions bot commented Jun 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


4 participants