From 590cd13c21790be90fa86efd2915e84255fb9630 Mon Sep 17 00:00:00 2001 From: ranggihwang Date: Tue, 4 Jun 2024 04:02:46 +0000 Subject: [PATCH] make fixup --- .../models/switch_transformers/modeling_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c1aa04e4c517fe..c5797d4573b781 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -296,7 +296,7 @@ def forward(self, hidden_states): next_states = hidden_states.clone() router_mask = router_mask.bool() - batch_size, seq_len, num_experts= router_mask.shape + batch_size, seq_len, num_experts = router_mask.shape idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0) idx_mask = torch.nonzero(idx_mask, as_tuple=True)[ 0