
Commit

Make fc_type a dict to pass fc kwargs through (#1201)
* fc type as dict

* fc type as dict

* fc type as dict

* rework ffn fc config slightly

* rework ffn fc config slightly

* merged main

* merged main

* no circular imports

* Update llmfoundry/models/mpt/configuration_mpt.py

Co-authored-by: Daniel King <[email protected]>

* them configs

* yo

* linting man

* the deep copy

---------

Co-authored-by: Daniel King <[email protected]>
snarayan21 and dakinggg authored May 16, 2024
1 parent 3fe7f09 commit 38ae65b
Showing 8 changed files with 145 additions and 79 deletions.
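
For orientation before the per-file diffs: fc_type was previously a plain string naming the fully connected layer implementation, and after this commit it is a dict whose 'name' key selects the implementation while the remaining keys are passed through as constructor kwargs. A minimal illustrative sketch of the shape change (key names other than 'name' are taken from the diffs below; the concrete values are placeholders, not a recommended config):

# Before this commit: modules took a string fc_type.
fc_type_old = 'torch'

# After this commit: fc_type is a dict; 'name' picks the fc implementation and
# every other key rides along as a constructor kwarg.
fc_type_new = {
    'name': 'torch',
    'bias': True,
    'device': 'cuda',
}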
27 changes: 16 additions & 11 deletions llmfoundry/models/layers/attention.py
@@ -3,6 +3,7 @@
 
 """Attention layers."""
 
+import copy
 import math
 import warnings
 from typing import Any, Dict, Optional, Tuple
@@ -18,6 +19,7 @@
     attention_implementations,
 )
 from llmfoundry.models.layers.layer_builders import build_fc, build_norm
+from llmfoundry.models.utils.config_defaults import fc_type_defaults
 
 __all__ = [
     'scaled_multihead_dot_product_attention',
@@ -410,7 +412,7 @@ def __init__(
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
         device: Optional[str] = None,
         bias: bool = True,
         sliding_window_size: int = -1,
@@ -429,6 +431,13 @@
 
         self.head_dim = d_model // n_heads
 
+        # Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
+        if fc_type is None:
+            fc_type = copy.deepcopy(fc_type_defaults)
+            fc_type['bias'] = bias
+            fc_type['device'] = device
+        fc_type_name = fc_type['name']
+
         if self.kv_n_heads <= 0:
             raise ValueError('kv_n_heads should be greater than zero.')
 
@@ -449,15 +458,11 @@
         self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop
 
-        fc_kwargs: dict[str, Any] = {
-            'bias': bias,
-        }
-        fc_kwargs['device'] = device
         self.Wqkv = build_fc(
-            name=fc_type,
+            name=fc_type_name,
             in_features=self.d_model,
             out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
-            fc_kwargs=fc_kwargs,
+            fc_kwargs=fc_type,
         )
         # for param init fn; enables shape based init of fused layers
         fuse_splits = [
@@ -484,10 +489,10 @@
         self.attn_fn = attention_implementations.get(self.attn_impl)
 
         self.out_proj = build_fc(
-            name=fc_type,
+            name=fc_type_name,
             in_features=self.d_model,
             out_features=self.d_model,
-            fc_kwargs=fc_kwargs,
+            fc_kwargs=fc_type,
         )
         self.out_proj._is_residual = True
 
@@ -696,7 +701,7 @@ def __init__(
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
         device: Optional[str] = None,
         bias: bool = True,
         sliding_window_size: int = -1,
@@ -737,7 +742,7 @@
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
        device: Optional[str] = None,
         bias: bool = True,
         sliding_window_size: int = -1,
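
The defaulting logic added to the attention layer's __init__ above can be read in isolation. Below is a standalone sketch, not llm-foundry code: fc_type_defaults here is a stand-in for llmfoundry.models.utils.config_defaults.fc_type_defaults and is assumed to carry at least a 'name' key.

import copy
from typing import Any, Optional

# Stand-in for llmfoundry.models.utils.config_defaults.fc_type_defaults (assumed shape).
fc_type_defaults: dict[str, Any] = {'name': 'torch'}


def resolve_fc_type(
    fc_type: Optional[dict[str, Any]],
    bias: bool,
    device: Optional[str],
) -> dict[str, Any]:
    # Mirrors the new __init__ hunk: fall back to a copied default dict and fold the
    # module-level bias/device settings into it; a caller-provided dict passes through as-is.
    if fc_type is None:
        fc_type = copy.deepcopy(fc_type_defaults)
        fc_type['bias'] = bias
        fc_type['device'] = device
    return fc_type


print(resolve_fc_type(None, bias=True, device='cpu'))
# {'name': 'torch', 'bias': True, 'device': 'cpu'}
print(resolve_fc_type({'name': 'other_fc', 'bias': False}, bias=True, device='cpu'))
# caller-provided dict passes through unchanged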
57 changes: 26 additions & 31 deletions llmfoundry/models/layers/blocks.py
@@ -3,6 +3,7 @@
 
 """GPT Blocks used for the GPT Model."""
 
+import copy
 from typing import Any, Dict, Optional, Set, Tuple
 
 import torch
@@ -14,6 +15,10 @@
     build_ffn,
     build_norm,
 )
+from llmfoundry.models.utils.config_defaults import (
+    attn_config_defaults,
+    fc_type_defaults,
+)
 
 try:
     from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
@@ -25,32 +30,6 @@
     'FusedNormAttentionNorm',
 ]
 
-attn_config_defaults: Dict = {
-    'attn_type': 'multihead_attention',
-    'attn_pdrop': 0.0,
-    'attn_impl': 'flash',
-    'qk_ln': False,
-    'qk_gn': False,
-    'clip_qkv': None,
-    'softmax_scale': None,
-    'attn_uses_sequence_id': False,
-    'sliding_window_size': -1,
-    'alibi': False,
-    'alibi_bias_max': 8,
-    'rope': False,
-    'rope_theta': 10000,
-    'rope_impl': 'dail',
-    'rope_dail_config': {
-        'type': 'original',
-        'pos_idx_in_fp32': True,
-        'xpos_scale_base': 512,
-    },
-    'rope_hf_config': {
-        'type': 'no_scaling',
-        'factor': 1.0,
-    },
-}
 
 
 class MPTBlock(nn.Module):
@@ -63,7 +42,7 @@ def __init__(
         ffn_config: Optional[Dict] = None,
         resid_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
         device: Optional[str] = None,
         no_bias: bool = False,
         use_pad_tok_in_ffn: bool = True,
@@ -73,15 +52,25 @@
             attn_config = attn_config_defaults
 
         if ffn_config is None:
-            ffn_config = {
+            self.ffn_config: dict[str, Any] = {
                 'ffn_type': 'mptmlp',
             }
+        else:
+            self.ffn_config = ffn_config
+
+        if fc_type is None:
+            fc_type = copy.deepcopy(fc_type_defaults)
+            fc_type['bias'] = not no_bias
+            fc_type['device'] = device
+
+        self.ffn_config['fc_type'] = fc_type
+
         self.fuse_norm_attn_norm = kwargs.get('fuse_norm_attn_norm', False)
 
         del kwargs  # unused, just to capture any extra args from the config
         super().__init__()
 
-        ffn_type = ffn_config['ffn_type']
+        ffn_type = self.ffn_config['ffn_type']
         ffn_has_norm = ffn_type in ffns_with_norm
 
         if self.fuse_norm_attn_norm:
@@ -137,7 +126,7 @@ def __init__(
             expansion_ratio=expansion_ratio,
             device=device,
             bias=not no_bias,
-            ffn_kwargs=ffn_config,
+            ffn_kwargs=self.ffn_config,
         )
 
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@@ -240,7 +229,7 @@ def __init__(
         args_to_exclude_in_attn_class: Set[str],
         attn_config: Optional[Dict] = None,
         ffn_has_norm: bool = False,
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
         resid_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
         device: Optional[str] = None,
@@ -251,6 +240,12 @@
         assert attn_config is not None
         assert isinstance(attn_config['attn_type'], str)
 
+        # Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
+        if fc_type is None:
+            fc_type = copy.deepcopy(fc_type_defaults)
+            fc_type['bias'] = not no_bias
+            fc_type['device'] = device
+
         # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
         attn_config_subset_for_attn_class = {
             k: v
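
MPTBlock now owns the wiring that defaults the fc_type dict and attaches it to ffn_config before build_ffn is called, so downstream FFN modules receive it inside ffn_kwargs. A standalone sketch of that wiring, under the same assumed fc_type_defaults stand-in as above (not llm-foundry code):

import copy
from typing import Any, Optional

fc_type_defaults: dict[str, Any] = {'name': 'torch'}  # assumed stand-in default


def resolve_block_configs(
    ffn_config: Optional[dict[str, Any]],
    fc_type: Optional[dict[str, Any]],
    no_bias: bool,
    device: Optional[str],
) -> dict[str, Any]:
    # Mirrors the MPTBlock.__init__ hunk: default ffn_config, default the fc_type dict,
    # then attach fc_type so the FFN builder receives it alongside the other ffn kwargs.
    if ffn_config is None:
        ffn_config = {'ffn_type': 'mptmlp'}
    if fc_type is None:
        fc_type = copy.deepcopy(fc_type_defaults)
        fc_type['bias'] = not no_bias
        fc_type['device'] = device
    ffn_config['fc_type'] = fc_type
    return ffn_config


print(resolve_block_configs(None, None, no_bias=False, device='cpu'))
# {'ffn_type': 'mptmlp', 'fc_type': {'name': 'torch', 'bias': True, 'device': 'cpu'}}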
37 changes: 21 additions & 16 deletions llmfoundry/models/layers/ffn.py
@@ -20,6 +20,7 @@
 )
 from llmfoundry.models.layers.dmoe import dMoE
 from llmfoundry.models.layers.layer_builders import build_fc
+from llmfoundry.models.utils.config_defaults import fc_type_defaults
 
 try:
     import transformer_engine.pytorch as te
@@ -127,7 +128,7 @@ def __init__(
         self,
         d_model: int,
         expansion_ratio: Union[int, float],
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
         ffn_hidden_size: Optional[int] = None,
         act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
@@ -139,24 +140,27 @@
             expansion_ratio,
             ffn_hidden_size,
         )
-        self.fc_kwargs: dict[str, Any] = {
-            'bias': bias,
-        }
 
-        self.fc_kwargs['device'] = device
+        # Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
+        if fc_type is None:
+            fc_type = fc_type_defaults
+            fc_type['bias'] = bias
+            fc_type['device'] = device
+        self.fc_type = fc_type
+        self.fc_type_name = self.fc_type['name']
 
         self.up_proj = build_fc(
-            name=fc_type,
+            name=self.fc_type_name,
             in_features=d_model,
             out_features=ffn_hidden_size,
-            fc_kwargs=self.fc_kwargs,
+            fc_kwargs=self.fc_type,
         )
         self.act = act_fn
         self.down_proj = build_fc(
-            name=fc_type,
+            name=self.fc_type_name,
             in_features=ffn_hidden_size,
             out_features=d_model,
-            fc_kwargs=self.fc_kwargs,
+            fc_kwargs=self.fc_type,
         )
         self.down_proj._is_residual = True
 
@@ -170,7 +174,7 @@
         self,
         d_model: int,
         expansion_ratio: Union[int, float],
-        fc_type: str = 'torch',
+        fc_type: Optional[dict[str, Any]] = None,
         ffn_hidden_size: Optional[int] = None,
         act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
@@ -185,11 +189,12 @@
             device=device,
             bias=bias,
         )
+
         self.gate_proj = build_fc(
-            name=fc_type,
+            name=self.fc_type_name,
             in_features=d_model,
             out_features=self.up_proj.out_features,
-            fc_kwargs=self.fc_kwargs,
+            fc_kwargs=self.fc_type,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -199,7 +204,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def build_mptglu(
     d_model: int,
     expansion_ratio: Union[int, float],
-    fc_type: str = 'torch',
+    fc_type: Optional[dict[str, Any]] = None,
     ffn_hidden_size: Optional[int] = None,
     ffn_act_fn: Optional[dict] = None,
     device: Optional[str] = None,
@@ -219,7 +224,7 @@
 def build_mptmlp(
     d_model: int,
     expansion_ratio: Union[int, float],
-    fc_type: str = 'torch',
+    fc_type: Optional[dict[str, Any]] = None,
     ffn_hidden_size: Optional[int] = None,
     ffn_act_fn: Optional[dict] = None,
     device: Optional[str] = None,
@@ -239,7 +244,7 @@
 def build_te_ln_mlp(
     d_model: int,
     expansion_ratio: Union[int, float],
-    fc_type: str = 'torch',
+    fc_type: Optional[dict[str, Any]] = None,
     ffn_hidden_size: Optional[int] = None,
     ffn_act_fn: Optional[dict] = None,
     device: Optional[str] = None,
@@ -280,7 +285,7 @@ def build_torch_dmoe(
     moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights')
     uniform_expert_assignment = kwargs.pop('uniform_expert_assignment')
 
-    fc_type = kwargs.pop('fc_type', 'torch')
+    fc_type = kwargs.pop('fc_type', None)
     del fc_type  # Unused
 
     if len(kwargs) > 0:
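
The FFN-side change means both projections in the MLP are now built from the same fc_type dict. A plain-torch sketch of that pattern (make_fc and SketchMLP are illustrative names, not llm-foundry APIs; nn.Linear stands in for the 'torch' fc implementation):

from typing import Any

import torch
import torch.nn as nn


def make_fc(in_features: int, out_features: int, fc_type: dict[str, Any]) -> nn.Linear:
    # 'name' selects the implementation; everything else is a constructor kwarg.
    kwargs = {k: v for k, v in fc_type.items() if k != 'name'}
    return nn.Linear(in_features, out_features, **kwargs)


class SketchMLP(nn.Module):

    def __init__(self, d_model: int, hidden_size: int, fc_type: dict[str, Any]):
        super().__init__()
        self.up_proj = make_fc(d_model, hidden_size, fc_type)
        self.act = nn.GELU()
        self.down_proj = make_fc(hidden_size, d_model, fc_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))


mlp = SketchMLP(8, 32, {'name': 'torch', 'bias': False})
print(mlp(torch.randn(2, 8)).shape)  # torch.Size([2, 8])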
2 changes: 1 addition & 1 deletion llmfoundry/models/layers/layer_builders.py
@@ -107,7 +107,7 @@ def build_fc(
     kwargs = {
         'in_features': in_features,
         'out_features': out_features,
-        **fc_kwargs,
+        **{k: v for k, v in fc_kwargs.items() if k != 'name'},
     }
 
     return construct_from_registry(
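
Because the whole fc_type dict is now forwarded as fc_kwargs, build_fc has to drop the 'name' key before handing the rest to the layer constructor: 'name' is a registry lookup key, not a constructor kwarg. A standalone sketch of that split (fc_registry and build_fc_sketch are hypothetical stand-ins for the registry machinery):

from typing import Any

import torch.nn as nn

# Hypothetical registry mapping fc names to layer classes.
fc_registry: dict[str, Any] = {'torch': nn.Linear}


def build_fc_sketch(
    in_features: int,
    out_features: int,
    fc_kwargs: dict[str, Any],
) -> nn.Module:
    fc_cls = fc_registry[fc_kwargs['name']]
    kwargs = {
        'in_features': in_features,
        'out_features': out_features,
        # Strip the lookup key; everything else goes to the constructor.
        **{k: v for k, v in fc_kwargs.items() if k != 'name'},
    }
    return fc_cls(**kwargs)


layer = build_fc_sketch(16, 32, {'name': 'torch', 'bias': False, 'device': 'cpu'})
print(layer)  # Linear(in_features=16, out_features=32, bias=False)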

