
Commit 5fdcc43
clean up resolve_ffn_hidden_and_exp_ratio (#801)
* remove superfluous return; add doc str

* address PR comments; add test
vchiley authored Dec 13, 2023
1 parent 0797aa6 commit 5fdcc43
Showing 3 changed files with 46 additions and 20 deletions.
llmfoundry/models/layers/ffn.py: 26 changes (18 additions & 8 deletions)
@@ -19,11 +19,21 @@
 log = logging.getLogger(__name__)
 
 
-def _resolve_ffn_hidden_and_exp_ratio(
+def resolve_ffn_hidden_size(
     d_model: int,
     expansion_ratio: Union[int, float],
     ffn_hidden_size: Optional[int] = None,
-) -> tuple[Union[int, float], int]:
+) -> int:
+    """Resolve the hidden size of the feed-forward network.
+
+    Args:
+        d_model (int): The dimension of the input and output of the feed-forward network.
+        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
+        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
+
+    Returns:
+        int: The hidden size of the feed-forward network.
+    """
     if ffn_hidden_size is not None:
         log.info(
             f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.'
@@ -32,9 +42,9 @@ def _resolve_ffn_hidden_and_exp_ratio(
         ffn_hidden_size = int(d_model * expansion_ratio)
         if ffn_hidden_size != d_model * expansion_ratio:
             raise ValueError(
-                f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.'
+                f'`d_model * expansion_ratio` must be an integer ({d_model=}; {expansion_ratio=}; {d_model * expansion_ratio=}).'
             )
-    return expansion_ratio, ffn_hidden_size
+    return ffn_hidden_size
 
 
 class MPTMLP(nn.Module):
@@ -49,8 +59,8 @@ def __init__(
         bias: bool = True,
     ):
         super().__init__()
-        expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
-            d_model, expansion_ratio, ffn_hidden_size)
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
+                                                  ffn_hidden_size)
         self.fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
@@ -138,8 +148,8 @@ def build_ffn(
         )
     elif ffn_type == 'te_ln_mlp':
        assert te is not None
-        _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
-            d_model, expansion_ratio, ffn_hidden_size)
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
+                                                  ffn_hidden_size)
         return te.LayerNormMLP(
             hidden_size=d_model,
             ffn_hidden_size=ffn_hidden_size,
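Not part of the commit: a minimal usage sketch of the renamed helper, assuming the import path implied by the file changed above and following the behavior shown in the diff (an explicit ffn_hidden_size wins; otherwise d_model * expansion_ratio is used and must be an integer). The 512 below is an arbitrary illustrative value.

# Sketch only, based on the diff above; not author-provided example code.
from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size

print(resolve_ffn_hidden_size(d_model=128, expansion_ratio=2))    # 256
print(resolve_ffn_hidden_size(128, 2, ffn_hidden_size=512))       # 512; expansion_ratio ignored (logged)

try:
    resolve_ffn_hidden_size(d_model=128, expansion_ratio=1.231)   # 128 * 1.231 = 157.568
except ValueError as err:
    print(err)                                                    # non-integer product raises ValueError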
llmfoundry/models/mpt/configuration_mpt.py: 2 changes (1 addition & 1 deletion)
@@ -70,7 +70,7 @@ def __init__(
             d_model (int): The size of the embedding dimension of the model.
             n_heads (int): The number of attention heads.
             n_layers (int): The number of layers in the model.
-            expansion_ratio (int, float): The ratio of the up/down scale in the ffn.
+            expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn.
             max_seq_len (int): The maximum sequence length of the model.
             vocab_size (int): The size of the vocabulary.
             resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
tests/models/test_model.py: 38 changes (27 additions & 11 deletions)
@@ -514,14 +514,21 @@ def test_opt_wrapping():
 @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
 @pytest.mark.parametrize('no_bias', [False, True])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
+@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [
+    (2, None),
+    (1.231, None),
+    (2, 128),
+    (2, 256),
+])
+def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
+                      expansion_ratio: Union[int, float], ffn_hidden_size: int):
     # Test that the config constructs the model as expected.
     hf_config = MPTConfig(
         init_device='cpu',
         d_model=128,
         n_heads=4,
         n_layers=2,
-        expansion_ratio=2,
+        expansion_ratio=expansion_ratio,
         max_seq_len=2048,
         emb_pdrop=0.1,
         resid_pdrop=0.2,
@@ -531,13 +538,24 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
         norm_type=norm_type,
         no_bias=no_bias,
         tie_word_embeddings=tie_word_embeddings,
+        ffn_config={
+            'ffn_type': 'mptmlp',
+            'ffn_hidden_size': ffn_hidden_size,
+        },
     )
+    if hf_config.d_model * hf_config.expansion_ratio != int(
+            hf_config.d_model * hf_config.expansion_ratio):
+        pytest.xfail('d_model * expansion_ratio must be an integer.')
+
     mpt = MPTForCausalLM(hf_config)
 
     assert mpt.config.d_model == 128
     assert mpt.config.n_heads == 4
     assert mpt.config.n_layers == 2
-    assert mpt.config.expansion_ratio == 2
+    if ffn_hidden_size is None:
+        assert mpt.config.expansion_ratio == expansion_ratio
+    else:
+        assert mpt.config.ffn_config['ffn_hidden_size'] == ffn_hidden_size
     assert mpt.config.max_seq_len == 2048
 
     assert mpt.transformer.wte.weight.shape == torch.Size(
@@ -551,21 +569,19 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert len(mpt.transformer.blocks) == 2
 
     d_model = hf_config.d_model
+    if ffn_hidden_size is None:
+        ffn_hidden_size = int(hf_config.d_model * hf_config.expansion_ratio)
     for block in mpt.transformer.blocks:
         assert isinstance(block, MPTBlock)
         assert block.norm_1.weight.shape == torch.Size([d_model])
         assert block.norm_2 is not None
         assert block.norm_2.weight.shape == torch.Size([d_model])
         assert isinstance(block.ffn.up_proj, nn.Linear)
-        assert block.ffn.up_proj.weight.shape == torch.Size([
-            int(hf_config.d_model * hf_config.expansion_ratio),
-            hf_config.d_model
-        ])
+        assert block.ffn.up_proj.weight.shape == torch.Size(
+            [ffn_hidden_size, hf_config.d_model])
         assert isinstance(block.ffn.down_proj, nn.Linear)
-        assert block.ffn.down_proj.weight.shape == torch.Size([
-            hf_config.d_model,
-            int(hf_config.d_model * hf_config.expansion_ratio)
-        ])
+        assert block.ffn.down_proj.weight.shape == torch.Size(
+            [hf_config.d_model, ffn_hidden_size])
         assert block.resid_attn_dropout.p == 0.2
         assert block.resid_ffn_dropout.p == 0.2
 
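As a side note (not from the commit), only the fractional expansion ratio in the new parametrization trips the xfail guard above; with the test's d_model of 128, a quick check of the same condition:

# Sketch of the integer check the test uses to decide whether to xfail.
d_model = 128
for expansion_ratio, ffn_hidden_size in [(2, None), (1.231, None), (2, 128), (2, 256)]:
    product = d_model * expansion_ratio
    xfails = product != int(product)  # True only for 1.231: 128 * 1.231 = 157.568
    print(expansion_ratio, ffn_hidden_size, 'xfail' if xfails else 'runs')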

0 comments on commit 5fdcc43
