Commit: detach expert
Your Name committed Apr 13, 2024
1 parent 0c80857 commit a8237bd
Showing 1 changed file with 7 additions and 3 deletions.
src/transformers/models/dbrx/modeling_dbrx.py: 7 additions & 3 deletions
```diff
@@ -754,9 +754,13 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
         self.activation_fn = ACT2FN[act_fn_name]
 
     def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        # Detach experts to avoid backpropagating through the expert selection
+        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
+        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
+        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
+        expert_w1.requires_grad = self.w1.requires_grad
+        expert_v1.requires_grad = self.v1.requires_grad
+        expert_w2.requires_grad = self.w2.requires_grad
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())
```
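For readers skimming the change, here is a minimal, self-contained sketch of the pattern the diff introduces: slice one expert's weight out of the flat parameter, detach that slice from the autograd graph of the full tensor, and copy the parameter's `requires_grad` flag back onto the detached view before it is used in the matmuls. The toy shapes and the standalone `w1` parameter below are illustrative assumptions, not the real DBRX configuration.

```python
import torch

# Illustrative sizes only; the real DBRX dimensions are much larger.
moe_num_experts, ffn_hidden_size, hidden_size = 4, 8, 6

# Flat expert weight, analogous to self.w1 in DbrxExpertGLU.
w1 = torch.nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size))

expert_idx = 2
x = torch.randn(3, hidden_size)

# Before this commit: the indexed view stays attached to w1's autograd graph.
expert_w1_attached = w1.view(moe_num_experts, ffn_hidden_size, hidden_size)[expert_idx]

# After this commit: detach the per-expert slice, then restore the parameter's
# requires_grad flag on the (now leaf) detached tensor.
expert_w1 = w1.view(moe_num_experts, ffn_hidden_size, hidden_size)[expert_idx].detach()
expert_w1.requires_grad = w1.requires_grad

gate_proj = x.matmul(expert_w1.t())
print(expert_w1.is_leaf, gate_proj.shape)  # True torch.Size([3, 8])
```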
