From 00e19b4fe16014c73ceba4c61732994c27f337ab Mon Sep 17 00:00:00 2001
From: coco58323
Date: Mon, 5 Aug 2024 21:19:34 +0800
Subject: [PATCH 1/3] Skip non-selected experts for mixtral and qwen2_moe

---
 src/transformers/models/mixtral/modeling_mixtral.py     | 3 ++-
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 522b6db7bcc768..0603ca4220568a 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -728,7 +728,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
+        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0]
+        for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 8cf5c200d8455f..dae602fc3b7349 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -719,7 +719,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
+        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0]
+        for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 

From 774f2c80215a0b9a9c6b2f305d607707ef2abdc7 Mon Sep 17 00:00:00 2001
From: coco58323
Date: Mon, 5 Aug 2024 22:45:04 +0800
Subject: [PATCH 2/3] Fix: tensor tolist()

---
 src/transformers/models/mixtral/modeling_mixtral.py     | 2 +-
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 0603ca4220568a..642e285030562f 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -728,7 +728,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0]
+        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0].tolist()
         for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index dae602fc3b7349..c986a54c037000 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -719,7 +719,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0]
+        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0].tolist()
         for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])

From 7c092a6b49fb8ede19979b0914042639b8204b2b Mon Sep 17 00:00:00 2001
From: coco58323
Date: Mon, 5 Aug 2024 23:40:03 +0800
Subject: [PATCH 3/3] WIP: tokenization test

---
 src/transformers/models/mixtral/modeling_mixtral.py     | 2 +-
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 642e285030562f..c72ab7e4bd8c7f 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -728,7 +728,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0].tolist()
+        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
         for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index c986a54c037000..71c2f0560ed2a3 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -719,7 +719,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        expert_hitted = torch.where(expert_mask.sum(dim=-1).sum(dim=-1) > 0)[0].tolist()
+        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
         for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
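For reference, below is a minimal, self-contained sketch of the optimization these patches apply to the sparse MoE forward pass: instead of looping over range(self.num_experts), only the experts that actually received tokens (per the one-hot routing mask) are visited. The tensor shapes, the toy nn.Linear experts, and the random router logits are illustrative assumptions and not the real MixtralSparseMoeBlock/Qwen2MoeSparseMoeBlock modules; only the expert_hitted line mirrors the final patched code.

import torch
import torch.nn as nn

# Toy configuration (assumed values, not taken from the patches).
num_experts, top_k, hidden_dim, num_tokens = 8, 2, 16, 5

experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))
hidden_states = torch.randn(num_tokens, hidden_dim)

# Router: random logits stand in for the gating network; pick top_k experts per token.
router_logits = torch.randn(num_tokens, num_experts)
routing_weights, selected_experts = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

# One-hot routing mask shaped (num_experts, top_k, num_tokens), as in the patched code.
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

final_hidden_states = torch.zeros_like(hidden_states)

# The change these patches make: visit only experts that received at least one token.
expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
for expert_idx in expert_hitted:
    idx, top_x = torch.where(expert_mask[expert_idx])  # (top-k slot, token index) pairs
    current_state = hidden_states[top_x]
    current_hidden_states = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, current_hidden_states)

print(f"experts hit: {expert_hitted} out of {num_experts}")

With top_k = 2 and 5 tokens there are at most 10 expert/token assignments, so typically only a handful of the 8 experts are hit and the rest are skipped entirely; during single-token decoding at most top_k experts can be hit per layer, which is where the saving over looping across all experts comes from.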