From c89f1a154d513aec7727dc9b4d705e3fa140d67b Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 11 Apr 2024 20:38:44 +0000
Subject: [PATCH 1/9] slice

---
 src/transformers/models/dbrx/modeling_dbrx.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index e13d7d09e6cd68..9c7e373bff3467 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -754,9 +754,14 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
         self.activation_fn = ACT2FN[act_fn_name]
 
     def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        # Slice in no grad context to avoid storing the entire param in backward pass
+        with torch.no_grad():
+            expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+            expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+            expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        expert_w1.requires_grad = True
+        expert_v1.requires_grad = True
+        expert_w2.requires_grad = True
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())

From dbd8b147b60c08a3911885b4b6836a5244389410 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 11 Apr 2024 20:42:55 +0000
Subject: [PATCH 2/9] fix requires grad

---
 src/transformers/models/dbrx/modeling_dbrx.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 9c7e373bff3467..e99dc47126065a 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -759,9 +759,9 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
             expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
             expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
             expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        expert_w1.requires_grad = True
-        expert_v1.requires_grad = True
-        expert_w2.requires_grad = True
+        expert_w1.requires_grad = expert_w1.requires_grad
+        expert_v1.requires_grad = expert_v1.requires_grad
+        expert_w2.requires_grad = expert_w2.requires_grad
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())

From a7ee563164cb818b4c9d3ead059fc03b1d527258 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 12 Apr 2024 13:09:00 +0000
Subject: [PATCH 3/9] remove grad

---
 src/transformers/models/dbrx/modeling_dbrx.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index e99dc47126065a..8773aa4598bbb4 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -759,9 +759,9 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
             expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
             expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
             expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        expert_w1.requires_grad = expert_w1.requires_grad
-        expert_v1.requires_grad = expert_v1.requires_grad
-        expert_w2.requires_grad = expert_w2.requires_grad
+        # expert_w1.requires_grad = expert_w1.requires_grad
+        # expert_v1.requires_grad = expert_v1.requires_grad
+        # expert_w2.requires_grad = expert_w2.requires_grad
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())

From 2e3bd86aa664a83c12841bbcc542f67484a22850 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 12 Apr 2024 13:56:21 +0000
Subject: [PATCH 4/9] disconnect differently

---
 src/transformers/models/dbrx/modeling_dbrx.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 8773aa4598bbb4..cee00c5d7c228a 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -755,13 +755,16 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
 
     def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
         # Slice in no grad context to avoid storing the entire param in backward pass
-        with torch.no_grad():
-            expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-            expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-            expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        # expert_w1.requires_grad = expert_w1.requires_grad
-        # expert_v1.requires_grad = expert_v1.requires_grad
-        # expert_w2.requires_grad = expert_w2.requires_grad
+        # with torch.no_grad():
+        #     expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        #     expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        #     expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
+        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
+        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
+        expert_w1.requires_grad = self.w1.requires_grad
+        expert_v1.requires_grad = self.v1.requires_grad
+        expert_w2.requires_grad = self.w2.requires_grad
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())

From 826947df1a24806ea11994412599903db1930284 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 12 Apr 2024 16:46:07 +0000
Subject: [PATCH 5/9] remove grad

---
 src/transformers/models/dbrx/modeling_dbrx.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index cee00c5d7c228a..77b519c879680b 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -762,9 +762,9 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
         expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
         expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
         expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_w1.requires_grad = self.w1.requires_grad
-        expert_v1.requires_grad = self.v1.requires_grad
-        expert_w2.requires_grad = self.w2.requires_grad
+        # expert_w1.requires_grad = self.w1.requires_grad
+        # expert_v1.requires_grad = self.v1.requires_grad
+        # expert_w2.requires_grad = self.w2.requires_grad
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())

From 35aca3a69ad06a6c22b4e507efb70af02645478d Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 12 Apr 2024 17:00:25 +0000
Subject: [PATCH 6/9] enable grads

---
 src/transformers/models/dbrx/modeling_dbrx.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 77b519c879680b..cee00c5d7c228a 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -762,9 +762,9 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
         expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
         expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
         expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        # expert_w1.requires_grad = self.w1.requires_grad
-        # expert_v1.requires_grad = self.v1.requires_grad
-        # expert_w2.requires_grad = self.w2.requires_grad
+        expert_w1.requires_grad = self.w1.requires_grad
+        expert_v1.requires_grad = self.v1.requires_grad
+        expert_w2.requires_grad = self.w2.requires_grad
 
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())

From 7ffb9f85d82218fd312a68e3b539b30e06c8b9f2 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sat, 13 Apr 2024 00:13:05 +0000
Subject: [PATCH 7/9] patch

---
 src/transformers/models/dbrx/modeling_dbrx.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index cee00c5d7c228a..b44d3463aca6ea 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -754,11 +754,7 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
         self.activation_fn = ACT2FN[act_fn_name]
 
     def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        # Slice in no grad context to avoid storing the entire param in backward pass
-        # with torch.no_grad():
-        #     expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        #     expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
-        #     expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        # Detach experts to avoid backpropagating through the expert selection
         expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
         expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
         expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()

From 99eba889bee2ee64d6c45be4af010636cc8a6c33 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sun, 14 Apr 2024 21:21:40 +0000
Subject: [PATCH 8/9] nissan al ghaib

---
 src/transformers/models/dbrx/modeling_dbrx.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index b44d3463aca6ea..263bffb2521d5f 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -753,15 +753,7 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
             raise ValueError(f"FFN activation function has unhandled kwargs {ffn_act_fn=}")
         self.activation_fn = ACT2FN[act_fn_name]
 
-    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        # Detach experts to avoid backpropagating through the expert selection
-        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx].detach()
-        expert_w1.requires_grad = self.w1.requires_grad
-        expert_v1.requires_grad = self.v1.requires_grad
-        expert_w2.requires_grad = self.w2.requires_grad
-
+    def forward(self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor) -> torch.Tensor:
         gate_proj = x.matmul(expert_w1.t())
         up_proj = x.matmul(expert_v1.t())
         gate_proj = self.activation_fn(gate_proj)
@@ -789,6 +781,13 @@ def forward(
 
         out = torch.zeros_like(x)
         expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        # Chunk weights and experts at once to avoid storing full parameter multiple times in autograd
+        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
+        v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
+        w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
+        w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+        v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+        w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
         for expert_idx in range(0, self.moe_num_experts):
             topk_idx, token_idx = torch.where(expert_mask[expert_idx])
             if token_idx.shape[0] == 0:
@@ -798,7 +797,7 @@ def forward(
             topk_list = topk_idx
 
             expert_tokens = x[None, token_list].reshape(-1, hidden_size)
-            expert_out = self.mlp(expert_tokens, expert_idx) * top_weights[token_list, topk_list, None]
+            expert_out = self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) * top_weights[token_list, topk_list, None]
 
             out.index_add_(0, token_idx, expert_out)

From ab9d85f5f726b4bd62c6acff2ea764d0c6c1f820 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 15 Apr 2024 07:23:27 -0700
Subject: [PATCH 9/9] Update modeling_dbrx.py

---
 src/transformers/models/dbrx/modeling_dbrx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 263bffb2521d5f..27488ba864a5c8 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -781,7 +781,7 @@ def forward(
 
         out = torch.zeros_like(x)
         expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
-        # Chunk weights and experts at once to avoid storing full parameter multiple times in autograd
+        # Chunk experts at once to avoid storing full parameter multiple times in autograd
        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
         v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
         w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(self.moe_num_experts, dim=0)
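Note on the end state of the series (not part of the patches themselves): patches 8 and 9 stop indexing a view of the full stacked parameter inside the per-expert loop and instead chunk each parameter once, passing the per-expert slices into DbrxExpertGLU.forward. The standalone sketch below, with made-up sizes and no DBRX code, illustrates the autograd difference the commit comments describe: per-iteration indexing adds one select node per expert, whose backward materializes a full-parameter-sized gradient buffer each time, while a single chunk adds one split node whose backward stitches the per-expert gradients together once.

import torch

num_experts, ffn_hidden_size, hidden_size = 4, 256, 128
w1 = torch.nn.Parameter(torch.randn(num_experts * ffn_hidden_size, hidden_size))
x = torch.randn(8, hidden_size)

# Indexing a view of the full parameter inside the loop: each `[i]` is a separate
# select op, and each select backward allocates a gradient shaped like the whole
# (num_experts, ffn_hidden_size, hidden_size) view.
out = sum(
    x.matmul(w1.view(num_experts, ffn_hidden_size, hidden_size)[i].t()).sum()
    for i in range(num_experts)
)
out.backward()
w1.grad = None

# Chunking once outside the loop (the pattern adopted in patches 8/9): a single
# split node whose backward concatenates the per-expert gradients one time.
chunked = [
    c.squeeze(dim=0)
    for c in w1.view(num_experts, ffn_hidden_size, hidden_size).chunk(num_experts, dim=0)
]
out = sum(x.matmul(chunked[i].t()).sum() for i in range(num_experts))
out.backward()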