Modularize components of megablocks layer builder (#1224)
dakinggg authored May 22, 2024
1 parent 77f9ab1 commit 001e7c3
Showing 2 changed files with 135 additions and 39 deletions.
114 changes: 90 additions & 24 deletions llmfoundry/models/layers/ffn.py
@@ -310,7 +310,7 @@ def build_torch_dmoe(
)


def _mb_setup_args(
def mb_setup_args(
d_model: int,
expansion_ratio: Union[int, float],
ffn_hidden_size: Optional[int],
@@ -319,6 +319,21 @@ def _mb_setup_args(
bias: bool,
kwargs: dict[str, Any],
) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]:
"""Setup the MegaBlocks args.
Args:
d_model (int): The dimension of the input and output of the FFN.
expansion_ratio (Union[int, float]): The expansion ratio of the FFN.
ffn_hidden_size (Optional[int]): The hidden size of the FFN.
ffn_act_fn (Optional[dict]): The activation function of the FFN.
device (Optional[str]): The device to run the FFN on.
bias (bool): Whether to include bias in the FFN.
kwargs (dict[str, Any]): Additional kwargs.
Returns:
tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]:
The MegaBlocks args, the MoE world size, and the expert parallel group.
"""
if megablocks is None:
raise RuntimeError(
'Requirements for megablocks not installed; see install instructions in `README.md`.',
@@ -350,18 +365,39 @@ def _mb_setup_args(
return args, moe_world_size, expert_parallel_group
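
For context (not part of this diff): dropping the leading underscore makes the setup helper importable by downstream builders. A minimal usage sketch with illustrative values; the keys read from `kwargs` (e.g. `moe_num_experts`, `moe_world_size`) are assumptions here, not a documented contract:

    # Hypothetical call to the now-public helper (values are illustrative only).
    args, moe_world_size, expert_parallel_group = mb_setup_args(
        d_model=1024,
        expansion_ratio=4,
        ffn_hidden_size=None,  # resolved from d_model * expansion_ratio when None
        ffn_act_fn=None,       # helper falls back to its default activation config
        device='cuda',
        bias=False,
        kwargs={'moe_num_experts': 8, 'moe_world_size': 1},
    )
    # args is a megablocks.layers.arguments.Arguments instance; the other two
    # values feed attach_ffn_mb_args / set_ffn_device_mesh below.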


def _patch_ffn_mb(
def attach_ffn_mb_args(
ffn: nn.Module,
moe_world_size: int,
expert_parallel_group: ProcessGroup,
device_mesh: DeviceMesh,
args: 'megablocks.layers.arguments.Arguments',
):
# Attach args to MLP directly for use in param_init_fn
"""Attach arguments used in parameter initialization to the FFN.
Args:
ffn (nn.Module): The FFN module.
expert_parallel_group (ProcessGroup): The expert parallel process group.
args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks.
"""
ffn.experts.mlp.hidden_size = args.ffn_hidden_size
ffn.experts.mlp.expert_parallel_group = expert_parallel_group
ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group


def set_ffn_device_mesh(
ffn: nn.Module,
moe_world_size: int,
device_mesh: DeviceMesh,
):
"""Sets the device mesh in FSDP kwargs.
Args:
ffn (nn.Module): The FFN module.
moe_world_size (int): The MoE world size.
device_mesh (DeviceMesh): The full device mesh.
Raises:
RuntimeError: If the device mesh is 3D.
ValueError: If the device mesh is not 2D or 3D.
"""
if moe_world_size > 1:
expert_mesh = device_mesh['expert_parallel']
expert_placements: List[Placement] = [Shard(0)]
@@ -389,6 +425,15 @@ def _patch_ffn_mb(
}


def moe_fused_init_setup(ffn: nn.Module,):
"""Attach the _stack_dim attribute to the FFN.
Args:
ffn (nn.Module): The FFN module.
"""
ffn.experts.mlp._stack_dim = 0


def build_mb_moe(
d_model: int,
expansion_ratio: Union[int, float],
@@ -403,7 +448,7 @@ def build_mb_moe(
'Requirements for megablocks not installed; see install instructions in `README.md`.',
)

args, moe_world_size, expert_parallel_group = _mb_setup_args(
args, moe_world_size, expert_parallel_group = mb_setup_args(
d_model=d_model,
expansion_ratio=expansion_ratio,
ffn_hidden_size=ffn_hidden_size,
@@ -415,21 +460,42 @@ def build_mb_moe(

ffn = megablocks.layers.moe.MoE(args)

# Fused initialization setup
# For param_init_fn, enables shape based init of stacked layers
ffn.experts.mlp._stack_dim = 0

_patch_ffn_mb(
moe_fused_init_setup(ffn=ffn,)
attach_ffn_mb_args(
ffn=ffn,
moe_world_size=moe_world_size,
expert_parallel_group=expert_parallel_group,
device_mesh=kwargs['device_mesh'],
args=args,
)
set_ffn_device_mesh(
ffn=ffn,
moe_world_size=moe_world_size,
device_mesh=kwargs['device_mesh'],
)

return ffn
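
With the builder reduced to a composition of public helpers, a project-specific variant can reuse most of the sequence and swap a single step. A hedged sketch, not part of this commit; it assumes the imports at the top of ffn.py, the parameter list mirrors `build_mb_moe`, and the skipped FSDP wiring is a placeholder:

    def build_mb_moe_custom(
        d_model, expansion_ratio, ffn_hidden_size, ffn_act_fn,
        device, bias, **kwargs
    ):
        # Assumes megablocks is installed, mirroring the check in build_mb_moe.
        args, moe_world_size, expert_parallel_group = mb_setup_args(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            ffn_hidden_size=ffn_hidden_size,
            ffn_act_fn=ffn_act_fn,
            device=device,
            bias=bias,
            kwargs=kwargs,
        )
        ffn = megablocks.layers.moe.MoE(args)
        moe_fused_init_setup(ffn=ffn)
        attach_ffn_mb_args(
            ffn=ffn,
            expert_parallel_group=expert_parallel_group,
            args=args,
        )
        # Deliberately skip set_ffn_device_mesh (which would use moe_world_size)
        # and apply project-specific FSDP wrapping here instead (placeholder).
        return ffn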


def dmoe_fused_init_setup(
ffn: nn.Module,
args: 'megablocks.layers.arguments.Arguments',
moe_world_size: int,
):
"""Attach the _fused attribute to the dMoE model.
This is used for parameter initialization.
Args:
ffn (nn.Module): The FFN module.
args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks.
moe_world_size (int): The MoE world size.
"""
n_exp = min(1, args.moe_num_experts // moe_world_size)
ffn.experts.mlp._fused = (
0,
[(n + 1) * args.ffn_hidden_size for n in range(n_exp - 1)],
)
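
To make the fused layout concrete, here is the comprehension evaluated on hypothetical numbers (not values from this commit):

    ffn_hidden_size, n_exp = 4096, 4  # illustrative only
    splits = [(n + 1) * ffn_hidden_size for n in range(n_exp - 1)]
    assert splits == [4096, 8192, 12288]
    # _fused would then be (0, [4096, 8192, 12288]): dimension 0 plus the
    # boundaries between consecutive expert slices of the stacked weight,
    # which the shape-based param_init_fn uses to initialize each slice.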


def build_mb_dmoe(
d_model: int,
expansion_ratio: Union[int, float],
@@ -444,7 +510,7 @@ def build_mb_dmoe(
'Requirements for megablocks not installed; see install instructions in `README.md`.',
)

args, moe_world_size, expert_parallel_group = _mb_setup_args(
args, moe_world_size, expert_parallel_group = mb_setup_args(
d_model=d_model,
expansion_ratio=expansion_ratio,
ffn_hidden_size=ffn_hidden_size,
@@ -456,21 +522,21 @@ def build_mb_dmoe(

ffn = megablocks.layers.dmoe.dMoE(args)

# Fused initialization setup
# For param_init_fn, enables shape based init of fused layers
n_exp = min(1, args.moe_num_experts // moe_world_size)
ffn.experts.mlp._fused = (
0,
[(n + 1) * args.ffn_hidden_size for n in range(n_exp - 1)],
)

_patch_ffn_mb(
dmoe_fused_init_setup(
ffn=ffn,
args=args,
moe_world_size=moe_world_size,
)
attach_ffn_mb_args(
ffn=ffn,
expert_parallel_group=expert_parallel_group,
device_mesh=kwargs['device_mesh'],
args=args,
)
set_ffn_device_mesh(
ffn=ffn,
moe_world_size=moe_world_size,
device_mesh=kwargs['device_mesh'],
)

return ffn
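
For orientation (not part of the diff): these builders are normally reached through the FFN type in the model config rather than called directly. A hedged sketch of such a config; the key names are assumptions based on the surrounding codebase and should be checked against the registry:

    # Hypothetical ffn_config routing to the MegaBlocks dMoE builder.
    ffn_config = {
        'ffn_type': 'mb_dmoe',   # assumed registry name for build_mb_dmoe
        'moe_num_experts': 8,
        'moe_world_size': 1,
        'moe_top_k': 1,
    }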

60 changes: 45 additions & 15 deletions llmfoundry/models/utils/config_moe_args.py
@@ -4,11 +4,12 @@
"""Helper function to configure MPT with MoEs."""

import inspect
from typing import Union
from typing import Callable, Optional, Union

import torch
from packaging import version
from torch import distributed
from torch.distributed._tensor import DeviceMesh

from llmfoundry.layers_registry import ffns_with_megablocks
from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size
@@ -64,11 +65,47 @@ def create_set_process_group(k: int):
return create_process_group_ranks(ranks)


def get_megablocks_device_mesh(
device_mesh_cfg: Optional[tuple[int]],
moe_world_size: int,
world_size: int,
) -> DeviceMesh:
"""Helper function to get the device mesh for MegaBlocks MoE.
Args:
device_mesh_cfg (Optional[tuple[int]]): The device mesh configuration specification.
moe_world_size (int): The MoE world size.
world_size (int): The world size.
Raises:
ValueError: If the device mesh configuration is not valid.
Returns:
The device mesh for MegaBlocks MoE.
"""
from torch.distributed._tensor.device_mesh import init_device_mesh

if device_mesh_cfg is None or len(device_mesh_cfg) == 1:
if device_mesh_cfg is not None:
world_size = device_mesh_cfg[0]
sharding_group_dim = world_size // moe_world_size
device_mesh = init_device_mesh(
'cuda',
(sharding_group_dim, moe_world_size),
mesh_dim_names=('weight_parallel', 'expert_parallel'),
)
else:
raise ValueError(f'{device_mesh_cfg=} must be length 1')

return device_mesh
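
A quick numeric reading of the mesh this helper builds (assumes torch.distributed is initialized and CUDA is available; the numbers are illustrative):

    # With world_size == 8 and moe_world_size == 2, the call reduces to
    # init_device_mesh('cuda', (4, 2),
    #                  mesh_dim_names=('weight_parallel', 'expert_parallel')),
    # i.e. 4 weight-parallel groups of 2 expert-parallel ranks each.
    device_mesh = get_megablocks_device_mesh(
        device_mesh_cfg=None,
        moe_world_size=2,
        world_size=8,
    )
    # The 'expert_parallel' sub-mesh is what config_megablocks_moe_args later
    # uses to derive the expert-parallel process group.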


def config_megablocks_moe_args(
ffn_config: dict,
d_model: int,
expansion_ratio: Union[int, float],
n_layers: int,
get_device_mesh: Callable,
) -> dict:
"""Configures `ffn_config` for MegaBlocks MoE.
@@ -80,6 +117,7 @@ def config_megablocks_moe_args(
d_model (int): Hidden size of the network.
expansion_ratio (Union[int, float]): Expansion ratio in FFN.
n_layers (int): Number of blocks used in the network.
get_device_mesh (Callable): Function to get the device mesh. Takes in the device mesh config and the MoE world size.
Returns:
ffn_config (dict): FFN configuration with MegaBlocks MoE configured.
@@ -112,26 +150,17 @@ def config_megablocks_moe_args(
'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.',
)

from torch.distributed._tensor.device_mesh import init_device_mesh

world_size = distributed.get_world_size()
if world_size < moe_world_size or world_size % moe_world_size:
raise ValueError(
f'Invalid world size configuration: {world_size=} and {moe_world_size=}',
)

# FSDP
if device_mesh_cfg is None or len(device_mesh_cfg) == 1:
if device_mesh_cfg is not None:
world_size = device_mesh_cfg[0]
sharding_group_dim = world_size // moe_world_size
device_mesh = init_device_mesh(
'cuda',
(sharding_group_dim, moe_world_size),
mesh_dim_names=('weight_parallel', 'expert_parallel'),
)
else:
raise ValueError(f'{device_mesh_cfg=} must be length 1')
device_mesh = get_device_mesh(
device_mesh_cfg=device_mesh_cfg,
moe_world_size=moe_world_size,
world_size=world_size,
)

ffn_config['moe_expert_model_parallelism'] = True
ffn_config['expert_parallel_group'] = device_mesh[
@@ -202,6 +231,7 @@ def config_moe_args(
d_model=d_model,
expansion_ratio=expansion_ratio,
n_layers=n_layers,
get_device_mesh=get_megablocks_device_mesh,
)
else:
raise ValueError(f'Invalid ffn_type ({ffn_config["ffn_type"]}).')
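
Because the mesh construction is now injected through `get_device_mesh`, callers can substitute their own factory (for tests or alternative layouts) without touching this function. A hedged sketch; `my_device_mesh` and the numeric values are hypothetical:

    def my_device_mesh(device_mesh_cfg, moe_world_size, world_size):
        # Same contract as get_megablocks_device_mesh: return a 2D DeviceMesh
        # with ('weight_parallel', 'expert_parallel') dimensions.
        return get_megablocks_device_mesh(
            device_mesh_cfg=device_mesh_cfg,
            moe_world_size=moe_world_size,
            world_size=world_size,
        )

    ffn_config = config_megablocks_moe_args(
        ffn_config=ffn_config,
        d_model=1024,        # illustrative
        expansion_ratio=4,   # illustrative
        n_layers=24,         # illustrative
        get_device_mesh=my_device_mesh,
    )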
