diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py
index e467ce227f..6190dbc6ea 100644
--- a/llmfoundry/models/layers/dmoe.py
+++ b/llmfoundry/models/layers/dmoe.py
@@ -1,9 +1,11 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Optional
+from functools import partial
+from typing import Callable, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 
 __all__ = [
     'dMoE',
@@ -13,6 +15,8 @@
     'DroplessMLP',
 ]
 
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
 # Add option to route tokens uniformly across experts. We use
 # a custom autograd op router backwards is still run for benchmarking.
 
@@ -36,8 +40,8 @@ def __init__(
         hidden_size: int,
         moe_num_experts: int,
         moe_top_k: int,
-        moe_jitter_eps: float,
-        moe_normalize_expert_weights: bool,
+        moe_jitter_eps: Optional[float],
+        moe_normalize_expert_weights: Optional[Union[int, float]],
         uniform_expert_assignment: bool,
         device: Optional[torch.device],
     ) -> None:
@@ -45,8 +49,9 @@ def __init__(
         self.hidden_size: int = hidden_size
         self.moe_num_experts: int = moe_num_experts
         self.moe_top_k: int = moe_top_k
-        self.moe_jitter_eps: float = moe_jitter_eps
-        self.moe_normalize_expert_weights: bool = moe_normalize_expert_weights
+        self.moe_jitter_eps: Optional[float] = moe_jitter_eps
+        self.moe_normalize_expert_weights: Optional[Union[
+            int, float]] = moe_normalize_expert_weights
         self.uniform_expert_assignment: bool = uniform_expert_assignment
 
         self.layer: torch.nn.Module = torch.nn.Linear(
@@ -57,6 +62,7 @@ def __init__(
         )
 
     def jitter(self, x: torch.Tensor) -> torch.Tensor:
+        assert self.moe_jitter_eps is not None
         low: float = 1.0 - self.moe_jitter_eps
         high: float = 1.0 + self.moe_jitter_eps
         noise: torch.Tensor = torch.rand(
@@ -66,16 +72,15 @@ def jitter(self, x: torch.Tensor) -> torch.Tensor:
         )
         return low + noise * (high - low)
 
-    def _top_k(self, scores: torch.Tensor) -> torch.Tensor:
+    def _top_k(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.moe_top_k == 1:
-            return scores.max(
-                dim=-1,
-            )  # pyright: ignore[reportGeneralTypeIssues]
+            values, indices = scores.max(dim=-1,)
+            return values.unsqueeze(-1), indices.unsqueeze(-1)
         return torch.topk(
             scores,
             self.moe_top_k,
             dim=-1,
-        )  # pyright: ignore[reportGeneralTypeIssues]
+        )
 
     def forward(self, x: torch.Tensor):
         if self.training and self.moe_jitter_eps is not None:
@@ -288,17 +293,17 @@ class dMoE(torch.nn.Module):
 
     def __init__(
         self,
-        hidden_size: int,
-        ffn_hidden_size: int,
-        moe_num_experts: int,
-        moe_top_k: int,
-        mlp_type: str,
-        activation_fn: Callable,
-        moe_jitter_eps: float,
-        moe_normalize_expert_weights: bool,
-        uniform_expert_assignment: bool,
-        bias: bool,
         device: Optional[torch.device],
+        hidden_size: int = 1024,
+        ffn_hidden_size: int = 4096,
+        moe_num_experts: int = 1,
+        moe_top_k: int = 1,
+        mlp_type: str = 'mlp',
+        activation_fn: Callable = DEFAULT_ACTIVATION_FN,
+        moe_jitter_eps: Optional[float] = None,
+        moe_normalize_expert_weights: Optional[Union[int, float]] = None,
+        uniform_expert_assignment: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
 
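A note on the `_top_k` change above: `scores.max(dim=-1)` squeezes away the expert dimension, while `torch.topk` keeps it, so the old top-1 fast path returned differently shaped results than the k > 1 path. Unsqueezing the max results gives both paths the same `(values, indices)` contract with shape `(..., k)`, which is what lets the new `Tuple` return annotation hold. A minimal standalone sketch of that contract:

import torch

scores = torch.randn(4, 8)  # (tokens, experts)

# Top-1 fast path: max() drops the expert dim, so it is restored.
values, indices = scores.max(dim=-1)
values, indices = values.unsqueeze(-1), indices.unsqueeze(-1)
assert values.shape == indices.shape == (4, 1)

# k > 1 path: torch.topk already keeps the k dim.
topk_values, topk_indices = torch.topk(scores, k=2, dim=-1)
assert topk_values.shape == topk_indices.shape == (4, 2)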
diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py
index 3a561b3d3c..8668aa2ec9 100644
--- a/tests/models/layers/test_dmoe.py
+++ b/tests/models/layers/test_dmoe.py
@@ -4,7 +4,7 @@
 import copy
 from contextlib import nullcontext
 from functools import partial
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import pytest
 import torch
@@ -63,21 +63,26 @@ def _get_torch_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]:
 )
 @pytest.mark.gpu
 @pytest.mark.world_size(2)
-@pytest.mark.parametrize('moe_num_experts', [8])
+@pytest.mark.parametrize('moe_num_experts', [1, 2, 8])
 @pytest.mark.parametrize('mlp_type', ['glu', 'mlp'])
 @pytest.mark.parametrize('moe_world_size', [1, 2])
+@pytest.mark.parametrize('moe_normalize_expert_weights', [1, 2.0])
 @pytest.mark.parametrize('two_d_input', [True, False])
 def test_dmoe(
     moe_num_experts: int,
     mlp_type: str,
     moe_world_size: int,
+    moe_normalize_expert_weights: Union[float, int],
     two_d_input: bool,
 ):
+    if moe_world_size > moe_num_experts or moe_num_experts % moe_world_size != 0:
+        pytest.skip('Mismatch between moe_world_size and moe_num_experts.')
+    moe_top_k = min(2, moe_num_experts)
     # Generate inputs
     rank = dist.get_rank()
     batch_size = 2
     seq_len = 3
-    hidden_size = 128
+    hidden_size = 256
     if two_d_input:
         input_shape = [batch_size * seq_len, hidden_size]
     else:
@@ -92,10 +97,10 @@ def test_dmoe(
     common_args = {
         'hidden_size': hidden_size,
         'ffn_hidden_size': hidden_size,
-        'moe_top_k': 2,
+        'moe_top_k': moe_top_k,
         'activation_fn': partial(F.gelu, approximate='none'),
         'moe_jitter_eps': 0.0,  # Disable randomiztion
-        'moe_normalize_expert_weights': 1,
+        'moe_normalize_expert_weights': moe_normalize_expert_weights,
         'uniform_expert_assignment': False,
         'bias': False,
         'device': device,
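The parametrization over `moe_normalize_expert_weights` values `[1, 2.0]` pairs with the `Optional[Union[int, float]]` annotation in the layer diff: the value reads naturally as the `p` of a p-norm applied to the selected top-k routing weights, with `None` disabling normalization. A rough sketch of that intended effect, an assumption based on the annotation rather than a copy of the router code:

import torch

top_k_weights = torch.tensor([[0.6, 0.4]])

# p=1: weights are rescaled to sum to 1.
l1 = top_k_weights / top_k_weights.norm(p=1, dim=-1, keepdim=True)

# p=2.0: weights are rescaled to unit L2 norm instead.
l2 = top_k_weights / top_k_weights.norm(p=2.0, dim=-1, keepdim=True)

print(l1)  # tensor([[0.6000, 0.4000]])
print(l2)  # tensor([[0.8321, 0.5547]])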
@@ -197,6 +202,82 @@ def test_dmoe(
     torch.testing.assert_close(torch_y, mb_y)
 
 
+@pytest.mark.skipif(
+    not is_megablocks_imported,
+    reason='This test needs megablocks module',
+)
+@pytest.mark.gpu
+@pytest.mark.world_size(2)
+@pytest.mark.parametrize('two_d_input', [True, False])
+def test_dmoe_defaults(two_d_input: bool,):
+    rank = dist.get_rank()
+    fp16 = False
+    bf16 = True
+    dtype = _get_torch_dtype(fp16, bf16)
+
+    # Construct DDP torch dMoE. torch_dmoe does not currently support bias.
+    device = torch.device(f'cuda:{dist.get_rank()}')
+    common_args = {
+        'device': device,
+        'bias': False,
+    }
+
+    torch_dmoe = dMoE(**common_args).to(device, dtype=dtype)
+    torch_dmoe = DDP(
+        torch_dmoe,
+        device_ids=[rank],
+    )
+    torch_dmoe_optimizer = optim.SGD(torch_dmoe.parameters(), lr=0.1)
+
+    # Construct TP MB dMoE
+    mp_dmoe_args = copy.deepcopy(common_args)
+    extra_args = {
+        'fp16': fp16,
+        'bf16': bf16,
+        'init_method': partial(torch.nn.init.uniform_, a=-1.0, b=1.0),
+    }
+
+    # Expert parallelism is not enabled by default
+    mp_dmoe_args.update(extra_args)
+    args = megablocks.layers.arguments.Arguments(**mp_dmoe_args,)
+    mb_dmoe = megablocks.layers.dmoe.dMoE(args).to(device)
+    mb_dmoe.router = DDP(mb_dmoe.router, device_ids=[rank])
+
+    mb_dmoe.experts = DDP(mb_dmoe.experts, device_ids=[rank])
+    mb_dmoe_state_dict = get_model_state_dict(
+        mb_dmoe,
+        options=StateDictOptions(full_state_dict=True,),
+    )
+    mb_dmoe_optimizer = optim.SGD(mb_dmoe.parameters(), lr=0.1)
+
+    # Generate inputs based on hidden_size in megablocks arguments
+    batch_size = 2
+    seq_len = 3
+    hidden_size = args.hidden_size
+    if two_d_input:
+        input_shape = [batch_size * seq_len, hidden_size]
+    else:
+        input_shape = [batch_size, seq_len, hidden_size]
+
+    x = _get_all_inputs(input_shape, dtype)[rank]
+
+    # Load mb_dmoe state dict to torch dmoe
+    torch_dmoe.module.load_state_dict(mb_dmoe_state_dict, strict=True)
+
+    # Run train_step check
+    torch_y = torch_dmoe(x)
+    mb_y = mb_dmoe(x)
+
+    torch_y.sum().backward()
+    mb_y.sum().backward()
+    torch_dmoe_optimizer.step()
+    mb_dmoe_optimizer.step()
+
+    torch_y = torch_dmoe(x)
+    mb_y = mb_dmoe(x)
+    torch.testing.assert_close(torch_y, mb_y)
+
+
 @pytest.mark.skipif(
     not is_megablocks_imported,
     reason='This test needs megablocks module',
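For context, the keyword defaults introduced in `dMoE.__init__` are what allow `test_dmoe_defaults` to build the layer from only `device` and `bias`. A minimal usage sketch under those defaults (single process, no DDP; assumes a CUDA device is available):

import torch
from llmfoundry.models.layers.dmoe import dMoE

layer = dMoE(device=torch.device('cuda:0'), bias=False)

# hidden_size defaults to 1024, moe_num_experts and moe_top_k to 1, so this
# behaves like a single dense MLP with the tanh-approximated GELU default.
x = torch.randn(2, 3, 1024, device='cuda:0')
y = layer(x)
assert y.shape == x.shape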