Added torch_dmoe defaults, bug fixes for 2D inputs (#1210)
* defaults for torch dmoe match mb dmoe

* top k proper

* permute fix

* narrow down world size bug

* blocking dimension bug

* done

* yo
snarayan21 authored May 15, 2024
1 parent 8274c6c commit b414626
Showing 2 changed files with 111 additions and 25 deletions.
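For context, the "2D inputs" in the title are token batches already flattened to (batch * seq_len, hidden_size), as opposed to the usual (batch, seq_len, hidden_size); the updated test below parametrizes over both layouts. A small sketch of the two shapes (sizes are illustrative, not taken from the diff):

import torch

batch_size, seq_len, hidden_size = 2, 3, 256
x_3d = torch.randn(batch_size, seq_len, hidden_size)   # standard 3D input: (batch, seq, hidden)
x_2d = torch.randn(batch_size * seq_len, hidden_size)  # flattened 2D input: (batch * seq, hidden)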
45 changes: 25 additions & 20 deletions llmfoundry/models/layers/dmoe.py
@@ -1,9 +1,11 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Optional
+from functools import partial
+from typing import Callable, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 
 __all__ = [
     'dMoE',
@@ -13,6 +15,8 @@
     'DroplessMLP',
 ]
 
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
 
 # Add option to route tokens uniformly across experts. We use
 # a custom autograd op router backwards is still run for benchmarking.
@@ -36,17 +40,18 @@ def __init__(
         hidden_size: int,
         moe_num_experts: int,
         moe_top_k: int,
-        moe_jitter_eps: float,
-        moe_normalize_expert_weights: bool,
+        moe_jitter_eps: Optional[float],
+        moe_normalize_expert_weights: Optional[Union[int, float]],
         uniform_expert_assignment: bool,
         device: Optional[torch.device],
     ) -> None:
         super().__init__()
         self.hidden_size: int = hidden_size
         self.moe_num_experts: int = moe_num_experts
         self.moe_top_k: int = moe_top_k
-        self.moe_jitter_eps: float = moe_jitter_eps
-        self.moe_normalize_expert_weights: bool = moe_normalize_expert_weights
+        self.moe_jitter_eps: Optional[float] = moe_jitter_eps
+        self.moe_normalize_expert_weights: Optional[Union[
+            int, float]] = moe_normalize_expert_weights
         self.uniform_expert_assignment: bool = uniform_expert_assignment
 
         self.layer: torch.nn.Module = torch.nn.Linear(
@@ -57,6 +62,7 @@ def __init__(
         )
 
     def jitter(self, x: torch.Tensor) -> torch.Tensor:
+        assert self.moe_jitter_eps is not None
         low: float = 1.0 - self.moe_jitter_eps
         high: float = 1.0 + self.moe_jitter_eps
         noise: torch.Tensor = torch.rand(
@@ -66,16 +72,15 @@ def jitter(self, x: torch.Tensor) -> torch.Tensor:
         )
         return low + noise * (high - low)
 
-    def _top_k(self, scores: torch.Tensor) -> torch.Tensor:
+    def _top_k(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.moe_top_k == 1:
-            return scores.max(
-                dim=-1,
-            ) # pyright: ignore[reportGeneralTypeIssues]
+            values, indices = scores.max(dim=-1,)
+            return values.unsqueeze(-1), indices.unsqueeze(-1)
         return torch.topk(
             scores,
             self.moe_top_k,
             dim=-1,
-        ) # pyright: ignore[reportGeneralTypeIssues]
+        )
 
     def forward(self, x: torch.Tensor):
         if self.training and self.moe_jitter_eps is not None:
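The _top_k change above makes the top-1 fast path return the same shapes as torch.topk: Tensor.max(dim=-1) drops the k dimension, so the values and indices are unsqueezed to stay compatible. A standalone shape check (tensor sizes are illustrative, not from the diff):

import torch

scores = torch.rand(6, 8)  # (num_tokens, num_experts), illustrative sizes
topk_values, topk_indices = torch.topk(scores, 1, dim=-1)
max_values, max_indices = scores.max(dim=-1)
assert topk_values.shape == (6, 1)  # torch.topk keeps the k dimension
assert max_values.shape == (6,)     # Tensor.max(dim=-1) drops it
assert max_values.unsqueeze(-1).shape == topk_values.shape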
@@ -288,17 +293,17 @@ class dMoE(torch.nn.Module):
 
     def __init__(
         self,
-        hidden_size: int,
-        ffn_hidden_size: int,
-        moe_num_experts: int,
-        moe_top_k: int,
-        mlp_type: str,
-        activation_fn: Callable,
-        moe_jitter_eps: float,
-        moe_normalize_expert_weights: bool,
-        uniform_expert_assignment: bool,
-        bias: bool,
         device: Optional[torch.device],
+        hidden_size: int = 1024,
+        ffn_hidden_size: int = 4096,
+        moe_num_experts: int = 1,
+        moe_top_k: int = 1,
+        mlp_type: str = 'mlp',
+        activation_fn: Callable = DEFAULT_ACTIVATION_FN,
+        moe_jitter_eps: Optional[float] = None,
+        moe_normalize_expert_weights: Optional[Union[int, float]] = None,
+        uniform_expert_assignment: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
 
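With the defaults added above, the torch dMoE can be constructed without specifying every hyperparameter. A minimal usage sketch, assuming the default (CPU) device and illustrative tensor shapes not taken from the diff:

import torch

from llmfoundry.models.layers.dmoe import dMoE

# Defaults from the diff above: hidden_size=1024, ffn_hidden_size=4096,
# moe_num_experts=1, moe_top_k=1, mlp_type='mlp', bias=True.
moe = dMoE(device=None)  # device=None builds parameters on the default device

x_3d = torch.randn(2, 3, 1024)  # (batch, seq, hidden)
x_2d = torch.randn(6, 1024)     # (batch * seq, hidden), the 2D case this commit fixes
y_3d = moe(x_3d)
y_2d = moe(x_2d)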
91 changes: 86 additions & 5 deletions tests/models/layers/test_dmoe.py
@@ -4,7 +4,7 @@
 import copy
 from contextlib import nullcontext
 from functools import partial
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import pytest
 import torch
@@ -63,21 +63,26 @@ def _get_torch_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]:
 )
 @pytest.mark.gpu
 @pytest.mark.world_size(2)
-@pytest.mark.parametrize('moe_num_experts', [8])
+@pytest.mark.parametrize('moe_num_experts', [1, 2, 8])
 @pytest.mark.parametrize('mlp_type', ['glu', 'mlp'])
 @pytest.mark.parametrize('moe_world_size', [1, 2])
+@pytest.mark.parametrize('moe_normalize_expert_weights', [1, 2.0])
 @pytest.mark.parametrize('two_d_input', [True, False])
 def test_dmoe(
     moe_num_experts: int,
     mlp_type: str,
     moe_world_size: int,
+    moe_normalize_expert_weights: Union[float, int],
     two_d_input: bool,
 ):
+    if moe_world_size > moe_num_experts or moe_num_experts % moe_world_size != 0:
+        pytest.skip('Mismatch between moe_world_size and moe_num_experts.')
+    moe_top_k = min(2, moe_num_experts)
     # Generate inputs
     rank = dist.get_rank()
     batch_size = 2
     seq_len = 3
-    hidden_size = 128
+    hidden_size = 256
     if two_d_input:
         input_shape = [batch_size * seq_len, hidden_size]
     else:
@@ -92,10 +97,10 @@ def test_dmoe(
     common_args = {
         'hidden_size': hidden_size,
         'ffn_hidden_size': hidden_size,
-        'moe_top_k': 2,
+        'moe_top_k': moe_top_k,
         'activation_fn': partial(F.gelu, approximate='none'),
         'moe_jitter_eps': 0.0, # Disable randomization
-        'moe_normalize_expert_weights': 1,
+        'moe_normalize_expert_weights': moe_normalize_expert_weights,
         'uniform_expert_assignment': False,
         'bias': False,
         'device': device,
@@ -197,6 +202,82 @@ def test_dmoe(
     torch.testing.assert_close(torch_y, mb_y)
 
 
+@pytest.mark.skipif(
+    not is_megablocks_imported,
+    reason='This test needs megablocks module',
+)
+@pytest.mark.gpu
+@pytest.mark.world_size(2)
+@pytest.mark.parametrize('two_d_input', [True, False])
+def test_dmoe_defaults(two_d_input: bool,):
+    rank = dist.get_rank()
+    fp16 = False
+    bf16 = True
+    dtype = _get_torch_dtype(fp16, bf16)
+
+    # Construct DDP torch dMoE. torch_dmoe does not currently support bias.
+    device = torch.device(f'cuda:{dist.get_rank()}')
+    common_args = {
+        'device': device,
+        'bias': False,
+    }
+
+    torch_dmoe = dMoE(**common_args).to(device, dtype=dtype)
+    torch_dmoe = DDP(
+        torch_dmoe,
+        device_ids=[rank],
+    )
+    torch_dmoe_optimizer = optim.SGD(torch_dmoe.parameters(), lr=0.1)
+
+    # Construct TP MB dMoE
+    mp_dmoe_args = copy.deepcopy(common_args)
+    extra_args = {
+        'fp16': fp16,
+        'bf16': bf16,
+        'init_method': partial(torch.nn.init.uniform_, a=-1.0, b=1.0),
+    }
+
+    # Expert parallelism is not enabled by default
+    mp_dmoe_args.update(extra_args)
+    args = megablocks.layers.arguments.Arguments(**mp_dmoe_args,)
+    mb_dmoe = megablocks.layers.dmoe.dMoE(args).to(device)
+    mb_dmoe.router = DDP(mb_dmoe.router, device_ids=[rank])
+
+    mb_dmoe.experts = DDP(mb_dmoe.experts, device_ids=[rank])
+    mb_dmoe_state_dict = get_model_state_dict(
+        mb_dmoe,
+        options=StateDictOptions(full_state_dict=True,),
+    )
+    mb_dmoe_optimizer = optim.SGD(mb_dmoe.parameters(), lr=0.1)
+
+    # Generate inputs based on hidden_size in megablocks arguments
+    batch_size = 2
+    seq_len = 3
+    hidden_size = args.hidden_size
+    if two_d_input:
+        input_shape = [batch_size * seq_len, hidden_size]
+    else:
+        input_shape = [batch_size, seq_len, hidden_size]
+
+    x = _get_all_inputs(input_shape, dtype)[rank]
+
+    # Load mb_dmoe state dict to torch dmoe
+    torch_dmoe.module.load_state_dict(mb_dmoe_state_dict, strict=True)
+
+    # Run train_step check
+    torch_y = torch_dmoe(x)
+    mb_y = mb_dmoe(x)
+
+    torch_y.sum().backward()
+    mb_y.sum().backward()
+    torch_dmoe_optimizer.step()
+    mb_dmoe_optimizer.step()
+
+    torch_y = torch_dmoe(x)
+    mb_y = mb_dmoe(x)
+    torch.testing.assert_close(torch_y, mb_y)
+
+
 @pytest.mark.skipif(
     not is_megablocks_imported,
     reason='This test needs megablocks module',
