Skip to content

Commit

Permalink
make fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
ranggihwang committed Jun 4, 2024
1 parent 3459299 commit 590cd13
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def forward(self, hidden_states):
next_states = hidden_states.clone()

router_mask = router_mask.bool()
batch_size, seq_len, num_experts= router_mask.shape
batch_size, seq_len, num_experts = router_mask.shape
idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
Expand Down

0 comments on commit 590cd13

Please sign in to comment.