From 3459299ebcc5d61ab661d7db1060de0acd894a80 Mon Sep 17 00:00:00 2001 From: ranggihwang Date: Tue, 4 Jun 2024 04:00:09 +0000 Subject: [PATCH] comments about shapes --- .../switch_transformers/modeling_switch_transformers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d4280a626bd098..c1aa04e4c517fe 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -296,10 +296,8 @@ def forward(self, hidden_states): next_states = hidden_states.clone() router_mask = router_mask.bool() - idx_mask = router_mask.transpose(1, 2) # Batch * experts * tokens - idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2) # 1 * experts * (batch * tokens) - idx_mask = idx_mask.sum(dim=2) - idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens + batch_size, seq_len, num_experts = router_mask.shape + idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0) idx_mask = torch.nonzero(idx_mask, as_tuple=True)[ 0 ].tolist() # length: number of "activated" expert / value: index