Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SwitchTransformer] Significant performance improvement on MoE blocks #31173

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,19 @@ def forward(self, hidden_states):
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.

next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1, 2) # Batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2) # 1 * experts * (batch * tokens)
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" expert / value: index
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1, 2) # Batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2) # 1 * experts * (batch * tokens)
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" expert / value: index
idx_mask = router_mask.reshape(batch*seq_len, num_experts).transpose(0,1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()
  • the comment about shapes! 🤗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The batch_size, seq_len, and num_experts are not defined in the funciton.
So, I've defined it with the router_mask and reflected your suggestions.

Thank you @ArthurZucker !

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

for idx in idx_mask:
next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
hidden_states[router_mask[:, :, idx]]
)

hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
Expand Down
Loading