Skip to content

Commit

Permalink
comments about shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
ranggihwang committed Jun 4, 2024
1 parent 8e5bb3d commit 3459299
Showing 1 changed file with 2 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,8 @@ def forward(self, hidden_states):
next_states = hidden_states.clone()

router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1, 2) # Batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2) # 1 * experts * (batch * tokens)
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
batch_size, seq_len, num_experts= router_mask.shape
idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" expert / value: index
Expand Down

0 comments on commit 3459299

Please sign in to comment.