From 52b6c57954e5d1986030b52474e23a08a76e2805 Mon Sep 17 00:00:00 2001
From: ranggihwang
Date: Sat, 1 Jun 2024 07:25:57 +0000
Subject: [PATCH] SwitchTransformer MoE layer performance improvement

---
 .../modeling_switch_transformers.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 3701b30a227f0b..4256791998fd24 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -294,9 +294,15 @@ def forward(self, hidden_states):
         # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.
 
         next_states = hidden_states.clone()
-        for idx, expert in enumerate(self.experts.values()):
-            token_indices = router_mask[:, :, idx].bool()
-            next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
+
+        router_mask = router_mask.bool()
+        idx_mask = router_mask.transpose(1, 2)  # batch * experts * tokens
+        idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)  # 1 * experts * (batch * tokens)
+        idx_mask = idx_mask.sum(dim=2)
+        idx_mask = idx_mask.squeeze()  # length: number of experts / value: number of tokens routed to each expert
+        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()  # length: number of "activated" experts / value: expert index
+        for idx in idx_mask:
+            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(hidden_states[router_mask[:, :, idx]])
 
         hidden_states = router_probs * next_states
         return hidden_states, (router_logits, expert_index)
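
Below is a minimal, standalone PyTorch sketch (not the HuggingFace code itself) of the idea this patch applies: instead of running every expert unconditionally, count how many tokens the router assigned to each expert and dispatch only to the experts that actually received at least one token. All names here (ToyExpert, sparse_moe_forward, the toy shapes) are illustrative, and the per-expert token count is computed with a simple reshape rather than the transpose/cat sequence used in the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyExpert(nn.Module):
    """A stand-in for one expert feed-forward block."""

    def __init__(self, d_model):
        super().__init__()
        self.ff = nn.Linear(d_model, d_model)

    def forward(self, x):
        return torch.relu(self.ff(x))


def sparse_moe_forward(hidden_states, experts, router_mask):
    # hidden_states: (batch, seq_len, d_model)
    # router_mask:   (batch, seq_len, num_experts), one-hot per token (top-1 routing)
    next_states = hidden_states.clone()
    router_mask = router_mask.bool()

    # Count how many tokens were routed to each expert, then keep only the
    # experts that received at least one token ("activated" experts).
    tokens_per_expert = router_mask.reshape(-1, router_mask.shape[-1]).sum(dim=0)
    active_experts = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()

    # Run only the activated experts; experts with no routed tokens are skipped.
    for idx in active_experts:
        token_indices = router_mask[:, :, idx]  # (batch, seq_len) bool mask
        next_states[token_indices] = experts[idx](hidden_states[token_indices])
    return next_states


if __name__ == "__main__":
    batch, seq_len, d_model, num_experts = 2, 8, 16, 4
    experts = nn.ModuleList(ToyExpert(d_model) for _ in range(num_experts))
    hidden_states = torch.randn(batch, seq_len, d_model)
    # Fake a top-1 router that only ever picks experts 0 and 1, so experts 2
    # and 3 receive no tokens and are never executed.
    expert_index = torch.randint(0, 2, (batch, seq_len))
    router_mask = F.one_hot(expert_index, num_experts)
    out = sparse_moe_forward(hidden_states, experts, router_mask)
    print(out.shape)  # torch.Size([2, 8, 16])

With top-1 routing and many experts, a given batch typically leaves most experts with no tokens at all, so skipping them avoids running expert feed-forward passes on empty selections; that is the source of the speedup targeted by this patch.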