Merge pull request #4 from mvpatel2000/mvpatel2000/dbrx-chunk
Chunk tensors to avoid mem usage spike
mvpatel2000 authored Apr 15, 2024
2 parents a8237bd + ab9d85f commit 8c320f1
Showing 1 changed file with 9 additions and 10 deletions.
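The diff below implements the rationale stated in the commit message and in its inline comment: reshape and chunk the stacked expert parameters once, before the per-expert loop, instead of slicing the full parameter tensor on every iteration. The following is a minimal standalone sketch of that pattern, not code from the commit; the toy sizes E, F, H, the 8-token batch, and the loop body are illustrative assumptions.

import torch
import torch.nn as nn

E, F, H = 4, 32, 16                       # num experts, ffn hidden size, hidden size (toy values)
w1 = nn.Parameter(torch.randn(E * F, H))  # experts stored stacked in a single parameter, as in DbrxExpertGLU
x = torch.randn(8, H)                     # a small batch of token embeddings

# Reshape and chunk once, outside the expert loop, so each expert gets its own (F, H) slice.
w1_chunked = [c.squeeze(dim=0) for c in w1.view(E, F, H).chunk(E, dim=0)]

for expert_idx in range(E):
    gate_proj = x.matmul(w1_chunked[expert_idx].t())  # (8, F) gate projection for this expert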
src/transformers/models/dbrx/modeling_dbrx.py: 9 additions & 10 deletions
@@ -753,15 +753,7 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
             raise ValueError(f"FFN activation function has unhandled kwargs {ffn_act_fn=}")
         self.activation_fn = ACT2FN[act_fn_name]
 
-    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        # Detach experts to avoid backpropagating through the expert selection
-        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_w1.requires_grad = self.w1.requires_grad
-        expert_v1.requires_grad = self.v1.requires_grad
-        expert_w2.requires_grad = self.w2.requires_grad
-
+    def forward(self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor) -> torch.Tensor:
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())
         gate_proj = self.activation_fn(gate_proj)
@@ -789,6 +781,13 @@ def forward(
         out = torch.zeros_like(x)
 
         expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        # Chunk experts at once to avoid storing full parameter multiple times in autograd
+        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
+        v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
+        w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
+        w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+        v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+        w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
         for expert_idx in range(0, self.moe_num_experts):
             topk_idx, token_idx = torch.where(expert_mask[expert_idx])
             if token_idx.shape[0] == 0:
@@ -798,7 +797,7 @@ def forward(
                 topk_list = topk_idx
 
             expert_tokens = x[None, token_list].reshape(-1, hidden_size)
-            expert_out = self.mlp(expert_tokens, expert_idx) * top_weights[token_list, topk_list, None]
+            expert_out = self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) * top_weights[token_list, topk_list, None]
 
             out.index_add_(0, token_idx, expert_out)

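A side note on the chunking call used above: torch.Tensor.chunk returns views of its input, so building all per-expert slices up front does not copy the parameter data. The quick standalone check below demonstrates this; it is a sketch with assumed toy sizes, not part of the commit, and assumes a recent PyTorch.

import torch
import torch.nn as nn

E, F, H = 4, 8, 16
w1 = nn.Parameter(torch.randn(E * F, H))
w1_chunked = [c.squeeze(dim=0) for c in w1.view(E, F, H).chunk(E, dim=0)]

# An in-place edit to the parameter is visible through the chunk, showing the
# slices are views of the same storage rather than copies.
with torch.no_grad():
    w1[0, 0] = 123.0
print(w1_chunked[0][0, 0].item())  # prints 123.0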
