Commit 9b85e40

[SwitchTransformer] Significant performance improvement on MoE blocks (#31173)

* SwitchTransformer MoE layer performance improvement

* make fixup

* comments about shapes

* make fixup
ranggihwang authored Jun 6, 2024
1 parent 8177aa0 commit 9b85e40
Showing 1 changed file with 11 additions and 3 deletions.
@@ -294,9 +294,17 @@ def forward(self, hidden_states):
         # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
 
         next_states = hidden_states.clone()
-        for idx, expert in enumerate(self.experts.values()):
-            token_indices = router_mask[:, :, idx].bool()
-            next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
+
+        router_mask = router_mask.bool()
+        batch_size, seq_len, num_experts = router_mask.shape
+        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
+        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
+            0
+        ].tolist()  # length: number of "activated" expert / value: index
+        for idx in idx_mask:
+            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
+                hidden_states[router_mask[:, :, idx]]
+            )
 
         hidden_states = router_probs * next_states
         return hidden_states, (router_logits, expert_index)
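
To make the intent of the diff concrete, below is a minimal, self-contained sketch of the new dispatch pattern. It is not the real SwitchTransformersSparseMLP forward pass: the toy tensor sizes, the nn.ModuleDict of Linear "experts", and the one-hot top-1 router mask are illustrative assumptions, and the per-expert token count uses a plain reshape of the router mask rather than the transpose-and-reshape in the committed lines.

import torch
import torch.nn as nn

# Illustrative sizes and modules only; the real layer builds its experts and router
# from the SwitchTransformers config.
batch_size, seq_len, d_model, num_experts = 2, 4, 8, 32
experts = nn.ModuleDict({f"expert_{i}": nn.Linear(d_model, d_model) for i in range(num_experts)})

hidden_states = torch.randn(batch_size, seq_len, d_model)
expert_index = torch.randint(num_experts, (batch_size, seq_len))       # stand-in top-1 routing decision
router_mask = nn.functional.one_hot(expert_index, num_experts).bool()  # (batch, seq, num_experts)

next_states = hidden_states.clone()

# Previous behaviour: one pass per expert, including experts that received no tokens.
# With 8 tokens and 32 experts, at most 8 of the 32 iterations do any work.
#
#     for idx in range(num_experts):
#         token_indices = router_mask[:, :, idx]
#         next_states[token_indices] = experts[f"expert_{idx}"](hidden_states[token_indices])

# New behaviour: count tokens per expert, keep only the experts with a non-zero count,
# and run the expert MLPs for those alone.
tokens_per_expert = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
active_experts = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()
for idx in active_experts:
    token_indices = router_mask[:, :, idx]  # (batch, seq) boolean mask of tokens routed to this expert
    next_states[token_indices] = experts[f"expert_{idx}"](hidden_states[token_indices])

print(f"ran {len(active_experts)} of {num_experts} experts")

The loop now scales with the number of activated experts (bounded by batch_size * seq_len) rather than with num_experts, which is where the speed-up comes from on checkpoints with many more experts than tokens in a batch.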
