SwitchTransformer MoE layer performance improvement
ranggihwang committed Jun 1, 2024
1 parent 96eb062 commit 52b6c57
Showing 1 changed file with 9 additions and 3 deletions.
@@ -294,9 +294,15 @@ def forward(self, hidden_states):
         # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
 
         next_states = hidden_states.clone()
-        for idx, expert in enumerate(self.experts.values()):
-            token_indices = router_mask[:, :, idx].bool()
-            next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
+
+        router_mask = router_mask.bool()
+        idx_mask = router_mask.transpose(1, 2)  # Batch * experts * tokens
+        idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)  # 1 * experts * (batch * tokens)
+        idx_mask = idx_mask.sum(dim=2)
+        idx_mask = idx_mask.squeeze()  # length: number of experts / value: number of tokens routed to each expert
+        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()  # length: number of "activated" experts / value: expert index
+        for idx in idx_mask:
+            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(hidden_states[router_mask[:, :, idx]])
 
         hidden_states = router_probs * next_states
         return hidden_states, (router_logits, expert_index)
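
The change replaces the unconditional loop over every expert with a loop over only the experts that actually received tokens: the boolean router mask is reduced per expert to a token count, and `torch.nonzero` selects the indices of the "activated" experts. Below is a minimal standalone sketch of that dispatch pattern, outside the SwitchTransformers class. The shapes, the toy `experts` ModuleDict, and the variable names are illustrative assumptions for this sketch, not the library's API.

import torch
import torch.nn as nn

batch, seq_len, d_model, num_experts = 2, 4, 8, 4

# Toy experts: one small layer per expert, named like in the modeling code.
experts = nn.ModuleDict(
    {f"expert_{i}": nn.Linear(d_model, d_model) for i in range(num_experts)}
)

hidden_states = torch.randn(batch, seq_len, d_model)

# One-hot routing decision per token: (batch, seq_len, num_experts).
expert_index = torch.randint(0, num_experts, (batch, seq_len))
router_mask = nn.functional.one_hot(expert_index, num_experts).bool()

next_states = hidden_states.clone()

# Count tokens routed to each expert and keep only experts with at least one token.
tokens_per_expert = router_mask.reshape(-1, num_experts).sum(dim=0)
active_experts = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()

# Run only the activated experts; inactive experts are skipped entirely,
# which is where the performance improvement comes from.
for idx in active_experts:
    token_mask = router_mask[:, :, idx]  # (batch, seq_len) boolean mask
    next_states[token_mask] = experts[f"expert_{idx}"](hidden_states[token_mask])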
