diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 06b64101c3..cfe1172634 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -396,6 +396,7 @@ def __init__(self, config: MPTConfig):
         self.wte = SharedEmbedding(
             config.vocab_size,
             config.d_model,
+            padding_idx=config.pad_token_id,
             device=config.init_device,
         )
         if self.learned_pos_emb:
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index 180e7b894c..8ad6e77c57 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -224,6 +224,9 @@ def embedding_init(
             emb_init_fn_ = init_fn_
 
         emb_init_fn_(module.weight)
+        if module.padding_idx is not None:
+            with torch.no_grad():
+                module.weight[module.padding_idx].fill_(0)
 
         return True
 
diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py
index 0eaf60c869..11d9fba430 100644
--- a/tests/models/utils/test_param_init_fns.py
+++ b/tests/models/utils/test_param_init_fns.py
@@ -199,3 +199,30 @@ def test_emb_init(emb_init_cfg: Optional[tuple[str, Union[int, list[int]]]]):
             emb_init_uniform_lim,
         ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]:
             assert (model.emb.weight == emb_init_uniform_lim[0]).all()
+
+
+@pytest.mark.parametrize(
+    'padding_idx',
+    [0, 2],
+)
+def test_emb_padding_init(padding_idx: int,):
+    cfg: dict[str, Union[int, list[int]]] = {
+        'vocab_size': 64,
+        'in_features': 16,
+        'n_layers': 2,
+        'padding_idx': padding_idx,
+        'emb_init_std': 5,
+    }
+    dict_cfg = om.create(cfg)
+
+    model = nn.Embedding(
+        dict_cfg.vocab_size,
+        dict_cfg.in_features,
+        dict_cfg.padding_idx,
+    )
+
+    model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg))
+    assert isinstance(model, torch.nn.Embedding)
+
+    if dict_cfg.get('emb_init_std') is not None:
+        assert (model.weight[padding_idx] == 0).all()
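
For reference, a minimal standalone sketch (plain PyTorch, not part of the diff) of the behavior the new test exercises: `nn.Embedding` zeroes the `padding_idx` row in its default `reset_parameters`, but a custom init function overwrites that row, so `embedding_init` re-zeroes it under `torch.no_grad()`.

```python
# Minimal sketch of the padding-row re-zeroing; standard PyTorch only,
# nothing here is llm-foundry-specific.
import torch
from torch import nn

emb = nn.Embedding(64, 16, padding_idx=2)
assert (emb.weight[2] == 0).all()  # default init zeroes the padding row

# A custom init (e.g. kaiming_normal_) overwrites the whole weight matrix,
# including the padding row ...
nn.init.kaiming_normal_(emb.weight)

# ... so, as in the patched embedding_init, the padding row is re-zeroed:
if emb.padding_idx is not None:
    with torch.no_grad():
        emb.weight[emb.padding_idx].fill_(0)

assert (emb.weight[2] == 0).all()
```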