Skip to content

Commit

Permalink
make fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
ranggihwang committed Jun 3, 2024
1 parent 4965da6 commit 8e5bb3d
Showing 1 changed file with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,19 @@ def forward(self, hidden_states):
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

next_states = hidden_states.clone()

router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1,2) # Batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2) # 1 * experts * (batch * tokens)
idx_mask = router_mask.transpose(1, 2) # Batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2) # 1 * experts * (batch * tokens)
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist() # length: number of "activated" expert / value: index
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" expert / value: index
for idx in idx_mask:
next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(hidden_states[router_mask[:, :, idx]])
next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
hidden_states[router_mask[:, :, idx]]
)

hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
Expand Down

0 comments on commit 8e5bb3d

Please sign in to comment.