diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 3701b30a227f0b..c5797d4573b781 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -294,9 +294,17 @@ def forward(self, hidden_states):
         # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.
 
         next_states = hidden_states.clone()
-        for idx, expert in enumerate(self.experts.values()):
-            token_indices = router_mask[:, :, idx].bool()
-            next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
+
+        router_mask = router_mask.bool()
+        batch_size, seq_len, num_experts = router_mask.shape
+        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
+        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
+            0
+        ].tolist()  # length: number of "activated" expert / value: index
+        for idx in idx_mask:
+            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
+                hidden_states[router_mask[:, :, idx]]
+            )
 
         hidden_states = router_probs * next_states
         return hidden_states, (router_logits, expert_index)
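
For reference, below is a minimal, self-contained sketch of the dispatch pattern this patch implements: build a boolean top-1 router mask, sum it over all tokens to find which experts actually received work, and run only those experts on their own tokens. `TinySparseMLP`, `d_model`, and the toy linear router/experts are illustrative stand-ins, not the real `SwitchTransformersSparseMLP` API; the sketch also flattens the mask with a plain `reshape(batch_size * seq_len, num_experts)` before summing, rather than the `transpose(1, 2)` used in the patch.

```python
import torch
from torch import nn


class TinySparseMLP(nn.Module):
    """Toy stand-in for a Switch Transformers sparse MLP layer.

    Reproduces the dispatch pattern of the patch above: build a boolean
    router mask, find which experts received at least one token, and run
    only those experts on their tokens.
    """

    def __init__(self, d_model: int = 8, num_experts: int = 4):
        super().__init__()
        self.num_experts = num_experts
        # One tiny feed-forward "expert" per index, registered as expert_{idx}
        # so it can be fetched with getattr, mirroring the patch.
        self.experts = nn.ModuleDict(
            {f"expert_{idx}": nn.Linear(d_model, d_model) for idx in range(num_experts)}
        )
        self.router = nn.Linear(d_model, num_experts)

    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.shape
        router_logits = self.router(hidden_states)          # (batch, seq, num_experts)
        router_probs = torch.softmax(router_logits, dim=-1)
        expert_index = torch.argmax(router_probs, dim=-1)   # top-1 expert per token
        router_mask = nn.functional.one_hot(expert_index, self.num_experts).bool()

        next_states = hidden_states.clone()

        # Sum the mask over all tokens: experts whose column is all zeros received
        # no tokens and are skipped entirely, which is the point of the change above.
        tokens_per_expert = router_mask.reshape(batch_size * seq_len, self.num_experts).sum(dim=0)
        active_experts = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()

        for idx in active_experts:
            token_mask = router_mask[:, :, idx]              # (batch, seq) bool
            expert = getattr(self.experts, f"expert_{idx}")
            next_states[token_mask] = expert(hidden_states[token_mask]).to(next_states.dtype)

        # Scale each token by the probability of its chosen expert, as in the model.
        top1_probs = router_probs.max(dim=-1, keepdim=True).values
        return top1_probs * next_states, (router_logits, expert_index)


if __name__ == "__main__":
    mlp = TinySparseMLP()
    out, (logits, indices) = mlp(torch.randn(2, 5, 8))
    print(out.shape, indices.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5])
```

Only experts with a nonzero mask column are invoked, so the per-layer cost scales with the number of activated experts instead of always iterating over `enumerate(self.experts.values())` as in the removed loop.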