[WIP][FIX] Fix Mixtral model #30658

Closed · wants to merge 6 commits

Conversation

casper-hansen

This PR is a WIP based on @kalomaze's implementation, which fixes the Mixtral model. It has been known for a while that Mixtral is hard to train due to some bug in the code. Please note this PR is meant for testing, based on the provided code.

More context:

PhilipMay (Contributor) commented May 5, 2024

kalomaze commented May 5, 2024

This PR is a WIP based on @kalomaze's implementation, which fixes the Mixtral model. It has been known for a while that Mixtral is hard to train due to some bug in the code. Please note this PR is meant for testing, based on the provided code.

More context:

Thanks for showing interest and putting up a PR with my changes, I've been a bit busy.

What I specifically did to the code:

  • Replaced the Attention, RoPE, etc. functions (which should be functionally equivalent to the Llama architecture) with their equivalents from modeling_llama.py.
  • MLP evaluation / MoE-related components, load balancing, etc. are left unchanged, as they were not the cause of the high loss.

Then I confirmed training was working as intended by testing topk=1 full finetuning with frozen routing on a "fake MoE" made with mergekit (in which all MLPs are identical).

[image: training loss plot from the topk=1 fine-tuning test]

I'm told what specifically fixed the training on my end was replacing the RoPE / Rotary embeddings functions with the Llama implementation equivalents (according to @2wlearning on Twitter):
[image: tweet screenshot]

If there is a need to keep only the necessary changes rather than fully standardizing the implementation of the non-MoE components (as they appear in modeling_llama.py), the Transformers maintainers may want to look into adopting the RoPE-related fixes.

From my perspective though, the architectures are essentially equivalent outside of the MoE components; not to mention, the SWA component has been abandoned by Mistral for Mixtral 8x7b/8x22b/7b-v0.2, in favor of higher native context lengths. So I don't see a downside to standardizing these functionally equivalent functions.
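For reference, the Llama-style rotary application in question looks roughly like this (a minimal sketch in the spirit of modeling_llama.py, not the exact code in this PR):

import torch

def rotate_half(x):
    # Split the head dimension in half and swap the halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin are (batch, seq_len, head_dim); unsqueeze so they broadcast over heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed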

Opinions, @ArthurZucker?

PhilipMay mentioned this pull request May 6, 2024
PhilipMay (Contributor) commented May 6, 2024

@kalomaze and @casper-hansen when I look at this plot:

my conclusion would be that training with the "oldcode" works well (loss goes down), but there is a problem with the new code (loss goes up and then stops moving). Can you please clarify?

ArthurZucker (Collaborator) left a comment

Thanks, could you isolate the changes (I believe it's mostly propagating the RoPE changes from Gemma and Llama to Mistral) that improved the training?

Also, a similar PR was opened in #30642 to support torch.compile; we can disentangle both, imo.

src/transformers/models/mixtral/modeling_mixtral.py (outdated)
Comment on lines 434 to 438
if self.config.pretraining_tp > 1:
    attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
    o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
    attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
ArthurZucker (Collaborator)

Same comment here: I don't think this is what allowed for better training, right?

casper-hansen (Author)

Thanks, could you isolate the changes (I believe it's mostly propagating the RoPE changes from Gemma and Llama to Mistral) that improved the training?

Also, a similar PR was opened in #30642 to support torch.compile; we can disentangle both, imo.

Yes, based on the comments left here, the isolated changes should be the RoPE ones. I think the next step will be to start over from the current modeling code and copy in the RoPE code from Llama.

claralp (Contributor) commented May 6, 2024

What I specifically did to the code:

  • Replaced the Attention, RoPE, etc. functions (which should be functionally equivalent to the Llama architecture) with their equivalents from modeling_llama.py.
  • MLP evaluation / MoE-related components, load balancing, etc. are left unchanged, as they were not the cause of the high loss.

Then I confirmed training was working as intended by testing topk=1 full finetuning with frozen routing on a "fake MoE" made with mergekit (in which all MLPs are identical).

[image: training loss plot from the topk=1 fine-tuning test]

Can you show the model config you used for these experiments? E.g., did you use RoPE scaling here?

kalomaze commented May 6, 2024

there is a problem with the new code (loss goes up and then stops moving). Can you please clarify?

I think you are extrapolating a bit too much here. This test is not being done with instruction finetuning data; it uses general web-scrape data from FineWeb. Due to how training works, I would intuitively expect this to be more divergent on average than a consistently formatted Instruct dataset, as the patterns are less consistent.

The important part (imo) is that the average loss is still lower by a non-negligible margin.

kalomaze commented May 6, 2024

E.g., did you use RoPE scaling here?

I used the RoPE config as specified in Llama 3 8B's config.json:

  "rope_scaling": null,
  "rope_theta": 500000.0,

In the Mixtral 8x7B & 8x22B config.json files, you can find:

  "rope_theta": 1000000.0,

And there is no mention of a "rope_scaling" type, so I presumed that it similarly registers as null in actual training use.

Perhaps the Llama code has some sort of fallback that was not getting applied in the current Mixtral modeling code before I swapped in the modeling_llama equivalent functions? I haven't had the time to look too closely and narrow it down.
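If it helps, the RoPE-related settings of the two configs can be compared directly (a quick sketch; the model ids are assumptions, and the Llama 3 repo is gated, so this needs access or cached files):

from transformers import AutoConfig

# Print rope_theta and rope_scaling for both configs; MixtralConfig may not define
# rope_scaling at all, hence the getattr fallback.
for name in ("meta-llama/Meta-Llama-3-8B", "mistralai/Mixtral-8x7B-v0.1"):
    cfg = AutoConfig.from_pretrained(name)
    print(name, cfg.rope_theta, getattr(cfg, "rope_scaling", None))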

casper-hansen (Author)

I have now updated the code to focus mostly on RoPE and based it on the current Mixtral model on the main branch. @kalomaze can you confirm whether this is the intended fix?

DreamGenX

The original tweet demonstrates a difference on a 4x8B franken-MoE of Llama 3, not on Mixtral 8x7B or Mistral 7B. If the claim is that the modeling of Mistral's RoPE is wrong, it would make sense to demonstrate the change on Mistral models.

From the tweet we can see that the change is apparent at step 1 of training; therefore, we should be able to see the same effect just by doing inference (measuring perplexity on some dataset with and without the change).

casper-hansen (Author)

I have validated that perplexity is lower than on the main branch using the gist that was posted (with a small fix). However, when trying to extract only the essential RoPE scaling changes, I must have taken out the part that lowered the perplexity. I am looking to create a working commit. @kalomaze @DreamGenX

Perplexity:

  • Main: 3.843
  • PR: 2.619

I will be back with a commit where this can be fully verified!

casper-hansen (Author)

It seems there is no measurable difference with eager/flash attention. @kalomaze this suggests you trained your model using SDPA, and the reason you got a better loss is that is_causal was set to False.

Perplexity (PR):

  • attn_implementation="sdpa": 2.619
  • attn_implementation="eager": 3.843
  • attn_implementation="flash_attention_2": 3.843

Now, what makes SDPA give a much lower perplexity? I investigated and found that is_causal on SDPA defaults to False, which causes the low perplexity. On the main branch, however, it is determined by the following. @ArthurZucker please advise if you have any idea what is happening here.

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
    is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
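For illustration (not from the PR), here is a minimal check of why dropping is_causal changes the result when no attention mask is passed:

import torch
import torch.nn.functional as F

# Sketch: with no attention mask, is_causal is the only thing stopping a query from
# attending to future positions. Dropping it lets every token see the full sequence,
# which is why perplexity looks artificially low.
torch.manual_seed(0)
q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
non_causal = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(torch.allclose(causal, non_causal))  # False: the outputs differ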

ppl.py script
import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def evaluate_perplexity(model, tokenizer):
    def _perplexity(nlls, n_samples, seqlen):
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))

    # load and prepare dataset
    data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
    data = data.input_ids.to(model.device)

    seqlen = 2048
    model = model.eval()
    n_samples = data.numel() // seqlen

    nlls = []

    with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
        for i in progress_bar:
            start_index = i * seqlen
            end_index = (i + 1) * seqlen
            batch = data[:, start_index:end_index].to(model.device)
            with torch.no_grad():
                logits = model(batch).logits
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = data[:, start_index:end_index][:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)

    return ppl.item()

if __name__ == "__main__":
    model_path = 'mistralai/Mixtral-8x7B-v0.1'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    ppl = evaluate_perplexity(model, tokenizer)
    print(f"Perplexity: {ppl:.3f}")
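To reproduce the eager/SDPA/FA2 comparison above, the __main__ block could be extended to loop over backends (a sketch; reloading Mixtral for each backend needs enough GPU memory, and flash_attention_2 requires flash-attn to be installed):

    # Sketch: rerun the evaluation once per attention backend and print each score.
    for impl in ("sdpa", "eager", "flash_attention_2"):
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation=impl,
        )
        print(impl, evaluate_perplexity(model, tokenizer))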

casper-hansen (Author)

Closing this PR for now as it seems to be a bust. is_causal is there to make sure that SDPA correctly masks the inputs, so setting it to False will obviously give lower perplexity.

DreamGenX

@casper-hansen Thank you so much for your diligent investigation.

ArthurZucker (Collaborator)

The idea is to dispatch to the memory-efficient attention kernel of SDPA. When the attention mask is None, you are using causal attention by default if you are in the prefill phase.
However, if you are decoding and the attention_mask is still None (no padding), then the attention becomes non-causal: you pay attention to everything that happened before -> thus, if q_len == 1 and attention_mask is None, you should not be causal.
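In code, the rule described above amounts to roughly this (a sketch mirroring the condition in the snippet quoted earlier, not new library behaviour):

def sdpa_is_causal(attention_mask, q_len, module_is_causal=True):
    # Ask SDPA for a causal mask only during prefill (q_len > 1) with no padding mask;
    # when decoding a single token against the cache, the flag is left off.
    return module_is_causal and attention_mask is None and q_len > 1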

Results should, however, be the same. I am not the author of this commit, cc @fxmarty, as this sounds interesting to investigate.
