FFN layer registry (#1095)
dakinggg authored Apr 12, 2024
1 parent 3729ba3 commit cb0de4f
Showing 20 changed files with 453 additions and 191 deletions.
4 changes: 1 addition & 3 deletions llmfoundry/__init__.py
@@ -28,7 +28,7 @@
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.ffn import MPTMLP
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
@@ -37,9 +37,7 @@
'build_finetuning_dataloader',
'Seq2SeqFinetuningCollator',
'MPTBlock',
'FFN_CLASS_REGISTRY',
'MPTMLP',
'build_ffn',
'MPTConfig',
'MPTPreTrainedModel',
'MPTModel',
31 changes: 31 additions & 0 deletions llmfoundry/layers_registry.py
@@ -26,6 +26,34 @@
entry_points=True,
description=_fc_description)

_ffns_description = (
'The ffns registry is used to register functions that build ffn layers.' +
' See ffn.py for examples.')
ffns = create_registry('llmfoundry',
'ffns',
generic_type=Callable,
entry_points=True,
description=_ffns_description)

_ffns_with_norm_description = (
'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.'
+ ' See ffn.py for examples.')
ffns_with_norm = create_registry('llmfoundry',
'ffns_with_norm',
generic_type=Callable,
entry_points=True,
description=_ffns_with_norm_description)

_ffns_with_megablocks_description = (
'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.'
+ ' See ffn.py for examples.')
ffns_with_megablocks = create_registry(
'llmfoundry',
'ffns_with_megablocks',
generic_type=Callable,
entry_points=True,
description=_ffns_with_megablocks_description)

_attention_classes_description = (
'The attention_classes registry is used to register classes that implement attention layers. See '
+ 'attention.py for expected constructor signature.')
@@ -47,6 +75,9 @@

__all__ = [
'norms',
'ffns',
'ffns_with_norm',
'ffns_with_megablocks',
'attention_classes',
'attention_implementations',
'fcs',
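The three registries added above replace the old FFN_CLASS_REGISTRY dict. A minimal sketch of how a custom FFN builder might be registered with the new ffns registry (the register call style, the 'my_ffn' name, and the builder signature are assumptions inferred from the build_ffn call sites in this diff, not part of the commit):

from typing import Optional

import torch.nn as nn

from llmfoundry.layers_registry import ffns


def build_my_ffn(
    d_model: int,
    expansion_ratio: float,
    device: Optional[str] = None,
    bias: bool = True,
    **kwargs: object,
) -> nn.Module:
    # Hypothetical builder mirroring the keyword arguments blocks.py forwards
    # through build_ffn: d_model, expansion_ratio, device, bias, plus the rest
    # of ffn_config via ffn_kwargs.
    hidden_size = int(expansion_ratio * d_model)
    return nn.Sequential(
        nn.Linear(d_model, hidden_size, bias=bias, device=device),
        nn.GELU(),
        nn.Linear(hidden_size, d_model, bias=bias, device=device),
    )


# 'my_ffn' is an illustrative name; entry_points=True also lets external
# packages register builders through their own entry points.
ffns.register('my_ffn', func=build_my_ffn)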
4 changes: 1 addition & 3 deletions llmfoundry/models/layers/__init__.py
@@ -8,7 +8,7 @@
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import *
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.ffn import MPTMLP
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
@@ -24,6 +24,4 @@
'MPTBlock',
'LPLayerNorm',
'SharedEmbedding',
'FFN_CLASS_REGISTRY',
'build_ffn',
]
23 changes: 13 additions & 10 deletions llmfoundry/models/layers/blocks.py
@@ -8,9 +8,9 @@
import torch
import torch.nn as nn

from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.layers_registry import ffns_with_norm
from llmfoundry.models.layers.layer_builders import (build_attention_layer,
build_norm)
build_ffn, build_norm)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -73,12 +73,15 @@ def __init__(
del kwargs # unused, just to capture any extra args from the config
super().__init__()

ffn_type = ffn_config['ffn_type']
ffn_has_norm = ffn_type in ffns_with_norm

if self.fuse_norm_attn_norm:
self.norm_attn_norm = FusedNormAttentionNorm(
d_model=d_model,
n_heads=n_heads,
attn_config=attn_config,
ffn_config=ffn_config,
ffn_has_norm=ffn_has_norm,
fc_type=fc_type,
resid_pdrop=resid_pdrop,
norm_type=norm_type,
@@ -116,21 +119,22 @@ def __init__(
},
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
'_has_norm', False):
if not ffn_has_norm:
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)

self.ffn = build_ffn(
name=ffn_type,
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
**ffn_config,
ffn_kwargs=ffn_config,
)

self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
@@ -198,7 +202,7 @@ def __init__(
d_model: int,
n_heads: int,
attn_config: Optional[Dict] = None,
ffn_config: Optional[Dict] = None,
ffn_has_norm: bool = False,
fc_type: str = 'torch',
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -208,7 +212,6 @@
):
super().__init__()
assert attn_config is not None
assert ffn_config is not None
assert isinstance(attn_config['attn_type'], str)

# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
@@ -238,9 +241,9 @@ def __init__(
**attn_config_subset_for_attn_class
},
)

self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
if not ffn_has_norm:
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
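With the registry in place, the block selects its FFN by name instead of indexing FFN_CLASS_REGISTRY. A rough sketch of the call pattern the diff above establishes (the 'mptmlp' name and the concrete d_model and expansion_ratio values are illustrative assumptions; build_ffn is imported from layer_builders per the new import):

from llmfoundry.layers_registry import ffns_with_norm
from llmfoundry.models.layers.layer_builders import build_ffn

ffn_config = {'ffn_type': 'mptmlp'}  # 'mptmlp' assumed to be the default registered name
ffn_type = ffn_config['ffn_type']

# Membership in ffns_with_norm replaces the old `_has_norm` attribute check:
# if the FFN applies its own normalization, the block skips building norm_2.
ffn_has_norm = ffn_type in ffns_with_norm

ffn = build_ffn(
    name=ffn_type,
    d_model=768,
    expansion_ratio=4,
    device=None,
    bias=True,
    ffn_kwargs=ffn_config,
)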
51 changes: 36 additions & 15 deletions llmfoundry/models/layers/dmoe.py
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Callable
from typing import Callable, Optional

import torch

@@ -24,7 +24,8 @@ class LearnedRouter(torch.nn.Module):

def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
moe_jitter_eps: float, moe_normalize_expert_weights: bool,
uniform_expert_assignment: bool, device: torch.device) -> None:
uniform_expert_assignment: bool,
device: Optional[torch.device]) -> None:
super().__init__()
self.hidden_size: int = hidden_size
self.moe_num_experts: int = moe_num_experts
@@ -84,7 +85,7 @@ def __init__(
ffn_hidden_size: int,
moe_num_experts: int,
activation_fn: Callable,
device: torch.device,
device: Optional[torch.device],
) -> None:
super().__init__()

@@ -117,9 +118,14 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:

class GLU(torch.nn.Module):

def __init__(self, hidden_size: int, ffn_hidden_size: int,
moe_num_experts: int, activation_fn: Callable,
device: torch.device):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
moe_num_experts: int,
activation_fn: Callable,
device: Optional[torch.device],
):
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
@@ -157,9 +163,16 @@ def forward(self, x: torch.Tensor, expert_idx: torch.Tensor):

class DroplessMLP(torch.nn.Module):

def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str,
moe_num_experts: int, activation_fn: Callable, bias: bool,
device: torch.device):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
mlp_type: str,
moe_num_experts: int,
activation_fn: Callable,
bias: bool,
device: Optional[torch.device],
):
super().__init__()
self.moe_num_experts = moe_num_experts

@@ -209,12 +222,20 @@ def forward(self, x: torch.Tensor, scores: torch.Tensor,

class dMoE(torch.nn.Module):

def __init__(self, hidden_size: int, ffn_hidden_size: int,
moe_num_experts: int, moe_top_k: int, mlp_type: str,
activation_fn: Callable, moe_jitter_eps: float,
moe_normalize_expert_weights: bool,
uniform_expert_assignment: bool, bias: bool,
device: torch.device):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
moe_num_experts: int,
moe_top_k: int,
mlp_type: str,
activation_fn: Callable,
moe_jitter_eps: float,
moe_normalize_expert_weights: bool,
uniform_expert_assignment: bool,
bias: bool,
device: Optional[torch.device],
):
super().__init__()

# Token router.
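The dmoe.py changes are signature-only: device becomes Optional[torch.device] and the constructors are reflowed, so the MoE modules can be built without pinning a device up front. A hedged construction sketch using only the parameter names visible in this diff (all concrete values, mlp_type='glu', and the forward signature are illustrative assumptions):

import torch

from llmfoundry.models.layers.dmoe import dMoE

moe = dMoE(
    hidden_size=256,
    ffn_hidden_size=512,
    moe_num_experts=4,
    moe_top_k=2,
    mlp_type='glu',  # assumed valid given the GLU class defined in this file
    activation_fn=torch.nn.functional.gelu,
    moe_jitter_eps=0.0,
    moe_normalize_expert_weights=False,
    uniform_expert_assignment=False,
    bias=True,
    device=None,  # now permitted: device is Optional[torch.device]
)

# Assumed input: a (batch, seq, hidden_size) activation tensor.
out = moe(torch.randn(2, 8, 256))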