Skip to content

Commit

Permalink
MOE gate fixes and enhancements (#5156)
Browse files Browse the repository at this point in the history
Fixes the following issues:
- Fix capacity when using TP for non-MoE by aligning the capacity to TP
- Fix TopKGate.wg (gate weight) when using ZeRO with fp16 or bf16
- Fix top2 aux loss to be similar to top1 aux loss

Following are few configurable enhancements:
- Support top2 with disable token dropping
- Support disable top2 2nd expert sampling

---------

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
3 people authored Mar 7, 2024
1 parent db70c18 commit 5a2e705
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
7 changes: 5 additions & 2 deletions deepspeed/moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MoE(nn.Module):
use_rts (bool, optional): default=True, whether to use Random Token Selection.
use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
top2_2nd_expert_sampling (bool, optional): default=True, whether to perform sampling for 2nd expert
"""

def __init__(self,
Expand All @@ -48,7 +49,8 @@ def __init__(self,
drop_tokens: bool = True,
use_rts: bool = True,
use_tutel: bool = False,
enable_expert_tensor_parallelism: bool = False) -> None:
enable_expert_tensor_parallelism: bool = False,
top2_2nd_expert_sampling: bool = True) -> None:

super(MoE, self).__init__()

Expand All @@ -69,7 +71,8 @@ def __init__(self,

experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
min_capacity, noisy_gate_policy, drop_tokens, use_rts,
top2_2nd_expert_sampling),
experts,
self.expert_group_name,
self.ep_size,
Expand Down
61 changes: 41 additions & 20 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ def top1gating(logits: Tensor,
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
# This is since we are going to activate drop_tokens() to drop duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity

# Compute l_aux
Expand Down Expand Up @@ -275,23 +280,27 @@ def top1gating(logits: Tensor,
return l_aux, combine_weights, dispatch_mask, exp_counts


def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
def top2gating(logits: Tensor,
capacity_factor: float,
min_capacity: int,
drop_tokens: bool = True,
top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)

capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)

# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
if top2_2nd_expert_sampling:
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits += gumbel_rsample(logits.shape, device=logits.device)

# Replace top-expert with min value
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)

Expand All @@ -301,17 +310,29 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
# Update 2nd's location by accounting for locations of 1st
locations2 += torch.sum(mask1, dim=0, keepdim=True)

# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts

# Remove locations outside capacity from mask
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
# gating decisions
exp_counts = torch.sum(mask1 + mask2, dim=0)

if drop_tokens:
# Calculate configured capacity and remove locations outside capacity from mask
capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
else:
# Do not drop tokens - set capacity according to current expert assignments
new_capacity = torch.max(exp_counts)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
# This is since we are going to activate drop_tokens() to drop duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity

# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
Expand All @@ -338,7 +359,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts
return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')


class TopKGate(Module):
Expand Down Expand Up @@ -368,13 +389,14 @@ def __init__(self,
min_capacity: int = 8,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True) -> None:
use_rts: bool = True,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()

# Only top-1 and top-2 are supported at the moment.
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.k = k
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
Expand All @@ -385,6 +407,7 @@ def __init__(self,
self.gate_time = 0.0
self.drop_tokens = drop_tokens
self.use_rts = use_rts
self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

def forward(self,
input: torch.Tensor,
Expand All @@ -394,13 +417,11 @@ def forward(self,
if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).start()

if self.wg.weight.dtype != torch.float32:
self.wg = self.wg.float()
input_fp32 = input.float()
# input jittering
if self.noisy_gate_policy == 'Jitter' and self.training:
input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
logits = self.wg(input_fp32)
logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None)

if self.k == 1:
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
Expand All @@ -409,7 +430,7 @@ def forward(self,

else:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)

if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()
Expand Down

0 comments on commit 5a2e705

Please sign in to comment.