diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c1aa04e4c517fe..c5797d4573b781 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -296,7 +296,7 @@ def forward(self, hidden_states): next_states = hidden_states.clone() router_mask = router_mask.bool() - batch_size, seq_len, num_experts= router_mask.shape + batch_size, seq_len, num_experts = router_mask.shape idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0) idx_mask = torch.nonzero(idx_mask, as_tuple=True)[ 0