Skip to content

Commit

Permalink
enable correct padding_idx for embedding layers (#1527)
Browse files Browse the repository at this point in the history
  • Loading branch information
gupta-abhay authored Sep 17, 2024
1 parent 83ab9c3 commit 0114f33
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def __init__(self, config: MPTConfig):
self.wte = SharedEmbedding(
config.vocab_size,
config.d_model,
padding_idx=config.pad_token_id,
device=config.init_device,
)
if self.learned_pos_emb:
Expand Down
3 changes: 3 additions & 0 deletions llmfoundry/models/utils/param_init_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def embedding_init(
emb_init_fn_ = init_fn_

emb_init_fn_(module.weight)
if module.padding_idx is not None:
with torch.no_grad():
module.weight[module.padding_idx].fill_(0)

return True

Expand Down
27 changes: 27 additions & 0 deletions tests/models/utils/test_param_init_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,30 @@ def test_emb_init(emb_init_cfg: Optional[tuple[str, Union[int, list[int]]]]):
emb_init_uniform_lim,
) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]:
assert (model.emb.weight == emb_init_uniform_lim[0]).all()


@pytest.mark.parametrize(
    'padding_idx',
    [0, 2],
)
def test_emb_padding_init(padding_idx: int):
    """Verify param init zeroes the embedding row at ``padding_idx``.

    Builds an ``nn.Embedding`` with a padding index and a nonzero
    ``emb_init_std`` (so rows are not zero by construction), applies the
    ``kaiming_normal_`` param-init fn, and asserts that only the padding
    row ends up all-zero.
    """
    cfg: dict[str, Union[int, list[int]]] = {
        'vocab_size': 64,
        'in_features': 16,
        'n_layers': 2,
        'padding_idx': padding_idx,
        'emb_init_std': 5,
    }
    dict_cfg = om.create(cfg)

    model = nn.Embedding(
        dict_cfg.vocab_size,
        dict_cfg.in_features,
        dict_cfg.padding_idx,
    )

    model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg))
    assert isinstance(model, torch.nn.Embedding)

    # 'emb_init_std' is hard-coded in cfg above, so the original
    # `if dict_cfg.get('emb_init_std') is not None:` guard was always true;
    # assert unconditionally instead of hiding the check behind dead logic.
    assert (model.weight[padding_idx] == 0).all()
    # Sanity check: init actually ran — with std=5 the non-padding rows
    # are (almost surely) not all zero.
    assert not (model.weight == 0).all()

0 comments on commit 0114f33

Please sign in to comment.