-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torch compile for mixtral #30793
Add torch compile for mixtral #30793
Conversation
fix style
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
# the `top_x` tensor here. this will give `skipping cudagraphs due to index put with accumulate` | ||
# in compile | ||
# final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||
|
||
# still suffers from `skipping cudagraphs due to ['incompatible ops']` | ||
final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am kind of stuck on this, here it seems to give cudagraph skipped warnings
no matter what equivalent form I put, for now it seems that cudagraphs can only be applied partially because of this, I have tried the following forms:
final_hidden_states.index_add_
this will giveskipping cudagraphs due to index put with accumulate
final_hidden_states[top_x] += ...
this will giveskipping cudagraphs due to ['incompatible ops']
final_hidden_states.scatter_add_...
this will disable fullgraph tracing because data dependent ops ontop_x
I think the root cause still comes from the dynamic nature of moe where different experts compute different sets of tokens, and it seems that we can not circumvent index put if we do not want every expert to do a full
forward with all tokens @ArthurZucker @gante do you have any thoughts or suggestions on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is expected pretty much yes!
If we use the megablock like implementation (with sparse topology and matrix reprensentation) like it was done in JetMoE we might be able to get over this, but not sure we can go further with the current version!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is expected pretty much yes! If we use the megablock like implementation (with sparse topology and matrix reprensentation) like it was done in JetMoE we might be able to get over this, but not sure we can go further with the current version!
Yes, the root cause is top_x
here we use is unbacked free symbols
in torch.compile and is data dependent beucase of torch.where
, this will cause skipped cudagraphs, but we will still benefit from partial cudagraphs if we are not rewriting it into sparse forms
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unfortunately, currently torch.compile
produces wrong results when setting fullgraph=True
, I believe it has something to do with torch.where
used here(when I try ignore expert mask and compute the whole token set for every expert results can align with eager forward), the traced fx graph is not correct, I think if we want to support torch.compile
in fullgraph
mode we have to rewrite moe layer in a whole different way, maybe compute experts for tokens rather than compute tokens for experts @ArthurZucker
class MixtralBlockTop2MLP(nn.Module): | ||
def __init__(self, config: MixtralConfig): | ||
super().__init__() | ||
self.num_experts = config.num_local_experts | ||
self.ffn_dim = config.intermediate_size | ||
self.hidden_dim = config.hidden_size | ||
|
||
self.w1 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim)) | ||
self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.ffn_dim)) | ||
self.w3 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim)) | ||
|
||
self.act_fn = ACT2FN[config.hidden_act] | ||
|
||
def forward( | ||
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor | ||
) -> torch.Tensor: | ||
"""_summary_ | ||
|
||
Args: | ||
hidden_states (torch.Tensor): (batch_size * token_num, hidden_dim) | ||
selected_experts (torch.Tensor): (batch_size * token_num, top_k) | ||
routing_weights (torch.Tensor): (batch_size * token_num, top_k) | ||
|
||
Returns: | ||
torch.Tensor: _description_ | ||
""" | ||
|
||
ts, tk = hidden_states.size(0), selected_experts.size(-1) | ||
|
||
w1 = self.w1[selected_experts] # (batch_size * token_num, top_k, ffn_dim, hidden_dim) | ||
w2 = self.w2[selected_experts] # (batch_size * token_num, top_k, hidden_dim, ffn_dim) | ||
w3 = self.w3[selected_experts] # (batch_size * token_num, ffn_dim, hidden_dim) | ||
|
||
x1 = torch.matmul(w1, hidden_states[:, None, :, None]) | ||
x3 = torch.matmul(w3, hidden_states[:, None, :, None]) | ||
x1 = self.act_fn(x1) | ||
final_hidden_states = torch.matmul(w2, x1 * x3).reshape(ts, tk, self.hidden_dim) | ||
final_hidden_states = final_hidden_states * routing_weights[:, :, None] | ||
final_hidden_states = final_hidden_states.sum(dim=1) | ||
return final_hidden_states | ||
|
||
|
||
class MixtralMoeBlock(nn.Module): | ||
def __init__(self, config) -> None: | ||
super().__init__() | ||
self.hidden_dim = config.hidden_size | ||
self.ffn_dim = config.intermediate_size | ||
self.num_experts = config.num_local_experts | ||
self.top_k = config.num_experts_per_tok | ||
|
||
# gating | ||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) | ||
self.experts = MixtralBlockTop2MLP(config) | ||
# Jitter parameters | ||
self.jitter_noise = config.router_jitter_noise | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
batch_size, sequence_length, hidden_dim = hidden_states.shape | ||
if self.training and self.jitter_noise > 0: | ||
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) | ||
hidden_states = hidden_states.view(-1, hidden_dim) | ||
# router_logits: (batch * sequence_length, n_experts) | ||
router_logits = self.gate(hidden_states) | ||
|
||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) | ||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) | ||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||
# we cast back to the input dtype | ||
routing_weights = routing_weights.to(hidden_states.dtype) | ||
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) | ||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||
return final_hidden_states, router_logits | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this gathers experts for tokens, and it actually works for torch.compile
with fullgraph and cudagraphs support, and I think it works best when we are doing decoding phase where the batchsize is small, but it will uses more memory because we need to gather expert weights for every token, however it will require changes on model weights structure when loading (from expert-wise scattered MLPs to a centralized MLP)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we are supporting fast generation, then I think it's good to have this than the current version because we definitely will gain more speedups especially when decoding @ArthurZucker
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah GPTFast has similar changes! I think it's super interesting but too breaking as you mention
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks a lot for working on this. Seems like it would be too breaking to merge as is 😢
But I'll ping you if we have a new MoE model to use this as default...
A related PR / implementation is #31173! Does this version support compile
class MixtralBlockTop2MLP(nn.Module): | ||
def __init__(self, config: MixtralConfig): | ||
super().__init__() | ||
self.num_experts = config.num_local_experts | ||
self.ffn_dim = config.intermediate_size | ||
self.hidden_dim = config.hidden_size | ||
|
||
self.w1 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim)) | ||
self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.ffn_dim)) | ||
self.w3 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim)) | ||
|
||
self.act_fn = ACT2FN[config.hidden_act] | ||
|
||
def forward( | ||
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor | ||
) -> torch.Tensor: | ||
"""_summary_ | ||
|
||
Args: | ||
hidden_states (torch.Tensor): (batch_size * token_num, hidden_dim) | ||
selected_experts (torch.Tensor): (batch_size * token_num, top_k) | ||
routing_weights (torch.Tensor): (batch_size * token_num, top_k) | ||
|
||
Returns: | ||
torch.Tensor: _description_ | ||
""" | ||
|
||
ts, tk = hidden_states.size(0), selected_experts.size(-1) | ||
|
||
w1 = self.w1[selected_experts] # (batch_size * token_num, top_k, ffn_dim, hidden_dim) | ||
w2 = self.w2[selected_experts] # (batch_size * token_num, top_k, hidden_dim, ffn_dim) | ||
w3 = self.w3[selected_experts] # (batch_size * token_num, ffn_dim, hidden_dim) | ||
|
||
x1 = torch.matmul(w1, hidden_states[:, None, :, None]) | ||
x3 = torch.matmul(w3, hidden_states[:, None, :, None]) | ||
x1 = self.act_fn(x1) | ||
final_hidden_states = torch.matmul(w2, x1 * x3).reshape(ts, tk, self.hidden_dim) | ||
final_hidden_states = final_hidden_states * routing_weights[:, :, None] | ||
final_hidden_states = final_hidden_states.sum(dim=1) | ||
return final_hidden_states | ||
|
||
|
||
class MixtralMoeBlock(nn.Module): | ||
def __init__(self, config) -> None: | ||
super().__init__() | ||
self.hidden_dim = config.hidden_size | ||
self.ffn_dim = config.intermediate_size | ||
self.num_experts = config.num_local_experts | ||
self.top_k = config.num_experts_per_tok | ||
|
||
# gating | ||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) | ||
self.experts = MixtralBlockTop2MLP(config) | ||
# Jitter parameters | ||
self.jitter_noise = config.router_jitter_noise | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
batch_size, sequence_length, hidden_dim = hidden_states.shape | ||
if self.training and self.jitter_noise > 0: | ||
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) | ||
hidden_states = hidden_states.view(-1, hidden_dim) | ||
# router_logits: (batch * sequence_length, n_experts) | ||
router_logits = self.gate(hidden_states) | ||
|
||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) | ||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) | ||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||
# we cast back to the input dtype | ||
routing_weights = routing_weights.to(hidden_states.dtype) | ||
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) | ||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||
return final_hidden_states, router_logits | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah GPTFast has similar changes! I think it's super interesting but too breaking as you mention
Yes, it does, with fullgraph and cudagraphs enabled |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
This PR is working in progress and it tries to add torch compile support for Mixtral, it currently also contains changes from #30642 because there are some common ground shared between these two models, and there are several issues regarding Mixtral:
I believe it's inevitable because
MistralSparseMoeBlock
usestorch.where
to extract tokens that each expert cares about, and the number and indexes of tokens that each expert attends to are variable, even if we do make a static shape(which means we zero out the non-care tokens for each expert), we are adding extra computation cost because zero-out values still get to take participate in computation, and each expert will have to run full tokens in terms of computation, which makes the whole point of computation-saving of MOE invalid.