[WIP][FIX] Fix Mixtral model #30658
Conversation
See discussion with @kalomaze here on Discord: https://discord.com/channels/1053877538025386074/1132352574750728192/1236763822081966100
Thanks for showing interest and putting up a PR with my changes; I've been a bit busy. What I specifically did to the code:
Then I confirmed training was working as intended by testing topk=1 full finetuning with frozen routing on a "fake MoE" made with mergekit (in which all MLPs are identical). I'm told what specifically fixed the training on my end was replacing the RoPE / rotary embedding functions with the Llama implementation equivalents (according to @2wlearning on Twitter).

If there is a need to keep only the necessary changes rather than fully standardizing the implementation of the non-MoE components (as they appear in modeling_llama.py), the Transformers maintainers may want to look into adopting just the RoPE-related fixes. From my perspective, though, the architectures are essentially equivalent outside of the MoE components; not to mention, the SWA component has been abandoned by Mistral for Mixtral 8x7b/8x22b/7b-v0.2 in favor of higher native context lengths. So I don't see a downside to standardizing these functionally equivalent functions. Opinions @ArthurZucker?
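For context, here is a minimal sketch of the Llama-style rotary helpers being referred to, simplified from the general shape of `rotate_half` / `apply_rotary_pos_emb` in modeling_llama.py (it is not a copy of this PR's actual diff):

```python
import torch


def rotate_half(x):
    # Rotate half the hidden dims of the input: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin come from the rotary-embedding module for the current positions;
    # unsqueeze so they broadcast across the attention heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```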
@kalomaze and @casper-hansen when I look at this plot, my conclusion would be that the training with "oldcode" works well (loss goes down) but there is a problem with the new code (loss goes up and then does not move anymore). Can you please clarify?
Thanks, could you isolate the changes (I believe it's mostly propagating the RoPE changes from Gemma and Llama to Mistral) that improved the training?
Also, a similar PR, #30642, adds support for torch.compile; we can disentangle both, imo.
```python
if self.config.pretraining_tp > 1:
    attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
    o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
    attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
```
Same comment here: I don't think this is what allowed for better training, right?
Yes, the isolated changes should be RoPE-related, based on the comments left here. I think the next step will be to start over with the modeling here and copy in the RoPE code from Llama.
Can you show the model config you used for these experiments? E.g., did you use RoPE scaling here?
I think you are extrapolating a bit too much here. This is not a test being done with instruction finetuning data; instead, it is using general webscrape data from FineWeb. I would intuitively expect (due to how training works) this to be more divergent on average compared to a consistently formatted Instruct dataset, as the patterns are less consistent. The important part (imo) is that average loss is still lower by a non-negligible extent.
I used the RoPE config as specified in Llama3 8b's config.json:
In the Mixtral 8x7b & 8x22b config.json files, you can find:
And there is no mention of a `rope_scaling` entry. Perhaps the Llama code has some sort of fallback that was not getting applied in the current Mixtral modeling code before I swapped in the modeling_llama equivalent functions? I haven't had the time to look too closely and narrow it down.
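If anyone wants to double-check what the published configs actually contain, a small sketch using the public `AutoConfig` API (repo ids are given as examples; the printed values are simply whatever the hub configs specify):

```python
from transformers import AutoConfig

# Inspect the RoPE-related fields of the published configs.
# Note: the Llama 3 repo is gated and may require accepting its license first.
for name in ["mistralai/Mixtral-8x7B-v0.1", "meta-llama/Meta-Llama-3-8B"]:
    cfg = AutoConfig.from_pretrained(name)
    print(name)
    print("  rope_theta  :", getattr(cfg, "rope_theta", None))
    print("  rope_scaling:", getattr(cfg, "rope_scaling", None))
```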
I have now updated the code to focus mostly on RoPE and based it on the current Mixtral model on the main branch. @kalomaze can you confirm whether this is the intended fix?
The original tweet demonstrates a difference on a 4x8B franken-MoE of Llama 3, not on Mistral 8x7b or 7b. If the claim is that the modeling for Mistral's RoPE is wrong, it would make sense to demonstrate the change on Mistral models. From the tweet we can see that the change is apparent at step 1 of the training; therefore, we should be able to see the same effect by just doing inference (measuring perplexity on some dataset with and without the change).
I have validated that perplexity is lower than on the main branch using the gist that was posted (with a small fix). However, when trying to extract the most essential changes to RoPE scaling, I must have taken out the part that lowered the perplexity score. I am looking to create a working commit. @kalomaze @DreamGenX

Perplexity:
I will be back with a commit where this can be fully verified!
It seems there is no measurable difference on eager/flash attention. @kalomaze this suggests you trained your model using SDPA, and the reason you got a better loss was most likely SDPA itself rather than the RoPE changes in this PR.

Perplexity (PR):
Now, what makes SDPA have a much lower perplexity? I investigated and found the relevant code in src/transformers/models/mixtral/modeling_mixtral.py, lines 765 to 773 (commit 75bbfd5).
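For readers without the permalink handy, the sketch below paraphrases the shape of the SDPA dispatch in those lines (it is not a verbatim copy of the file): when no attention mask is supplied and the query length is greater than 1 (prefill), the call relies on `is_causal=True`, which lets PyTorch pick the memory-efficient/flash kernels instead of applying an explicit mask.

```python
import torch
import torch.nn.functional as F


def sdpa_attention(query, key, value, attention_mask=None, dropout_p=0.0):
    # Paraphrased sketch of the dispatch referenced above (not verbatim):
    # with attention_mask=None and q_len > 1 (prefill), causal masking is
    # delegated to the kernel via is_causal=True.
    q_len = query.shape[-2]
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=attention_mask is None and q_len > 1,
    )


# Toy usage: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 16, 64)
out = sdpa_attention(q, k, v)  # no explicit mask -> is_causal fast path
```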
ppl.py script:

```python
import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def evaluate_perplexity(model, tokenizer):
    def _perplexity(nlls, n_samples, seqlen):
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))

    # load and prepare dataset
    data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
    data = data.input_ids.to(model.device)

    seqlen = 2048
    model = model.eval()
    n_samples = data.numel() // seqlen

    nlls = []

    with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
        for i in progress_bar:
            start_index = i * seqlen
            end_index = (i + 1) * seqlen
            batch = data[:, start_index:end_index].to(model.device)
            with torch.no_grad():
                logits = model(batch).logits
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = data[:, start_index:end_index][:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)
    return ppl.item()


if __name__ == "__main__":
    model_path = 'mistralai/Mixtral-8x7B-v0.1'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    evaluate_perplexity(model, tokenizer)
```
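For completeness, a sketch of how the eager / SDPA / flash-attention numbers discussed above could be reproduced side by side, assuming the `evaluate_perplexity` function from the ppl.py script above is available in scope:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Evaluate the same checkpoint under each attention backend so the
# perplexities are directly comparable.
for impl in ["eager", "sdpa", "flash_attention_2"]:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation=impl,
    )
    ppl = evaluate_perplexity(model, tokenizer)
    print(f"{impl}: perplexity = {ppl:.3f}")
    del model
    torch.cuda.empty_cache()  # free GPU memory before loading the next backend
```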
Closing this PR for now as it seems to be a bust.
@casper-hansen Thank you so much for your diligent investigation.
The idea is to dispatch to the memory-efficient attention kernel of SDPA. When the attention mask is None, you are using causal attention by default if you are in the prefill phase. Results should, however, be the same. I am not at the origin of this commit; cc @fxmarty, as this sounds interesting to investigate.
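As a sanity check on the claim that the results should match, here is a small self-contained snippet (my own illustration, not from the PR) comparing an explicit lower-triangular mask against the `is_causal=True` fast path of `torch.nn.functional.scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
seq_len = q.shape[-2]

# Explicit boolean causal mask (True = position may be attended to).
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Expected: True (up to the numerical tolerance of the selected kernels).
print(torch.allclose(out_masked, out_causal, atol=1e-5))
```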
This PR is a WIP based on @kalomaze's implementation, which aims to fix the Mixtral model. It has been known for a while that Mixtral has been hard to train due to some bug in the code. Please note this is meant for testing, based on the provided code.
More context: