Enable tie_word_embeddings config setting to enable / disable weight tied embeddings #728

Merged · 20 commits · Nov 13, 2023
8 changes: 7 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -59,6 +59,7 @@ def __init__(
use_cache: bool = False,
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
+        tie_word_embeddings: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -128,6 +129,7 @@ def __init__(
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
+            tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -164,7 +166,11 @@ def __init__(
warnings.warn(
f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
)
-        super().__init__(**kwargs)
+        # tie_word_embeddings is set in Huggingface's PretrainedConfig __init__
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )

self._validate_config()
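With this change, weight tying becomes an ordinary config field. A minimal sketch of how the flag might be used, with toy hyperparameters chosen only for illustration:

from llmfoundry.models.mpt import MPTConfig

# The flag is forwarded to transformers' PretrainedConfig, so it is stored on
# the config object and serialized alongside the other HF config fields.
cfg = MPTConfig(d_model=128, n_heads=2, n_layers=2, tie_word_embeddings=False)
assert cfg.tie_word_embeddings is False
assert MPTConfig().tie_word_embeddings is True  # the default remains tied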

72 changes: 52 additions & 20 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -231,10 +231,11 @@ def __init__(self, config: MPTConfig):
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

-    def get_input_embeddings(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
        return self.wte

-    def set_input_embeddings(self, value: nn.Embedding) -> None:
+    def set_input_embeddings(
+            self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
self.wte = value

@torch.no_grad()
@@ -574,14 +575,20 @@ class MPTForCausalLM(MPTPreTrainedModel):

def __init__(self, config: MPTConfig):
super().__init__(config)
-        if not config.tie_word_embeddings:
-            raise ValueError(
-                'MPTForCausalLM only supports tied word embeddings')

log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer: MPTModel = MPTModel(config)

+        self.lm_head = None
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(
+                config.d_model,
+                config.vocab_size,
+                bias=False,
+                device=config.init_device,
+            )
+            self.lm_head._fsdp_wrap = True

for child in self.transformer.children():
if isinstance(child, torch.nn.ModuleList):
continue
@@ -602,19 +609,38 @@ def __init__(self, config: MPTConfig):
)
self.logit_scale = logit_scale

-    def get_input_embeddings(self) -> nn.Embedding:
-        return self.transformer.wte
+    def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
+        return self.transformer.get_input_embeddings()

def set_input_embeddings(
self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
-        self.transformer.wte = value
+        self.transformer.set_input_embeddings(value)

-    def get_output_embeddings(self) -> nn.Embedding:
-        return self.transformer.wte
+    def get_output_embeddings(
+            self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
+        if self.lm_head is not None:
+            return self.lm_head
+        return self.transformer.get_input_embeddings()

def set_output_embeddings(
-            self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
-        self.transformer.wte = new_embeddings
+        self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
+                                    nn.Linear]) -> None:
+        if self.lm_head is not None:
+            self.lm_head = new_embeddings
+        else:
+            if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
+                raise ValueError(
+                    'new_embeddings must be an instance of SharedEmbedding ' +
+                    f'or nn.Embedding, but got {type(new_embeddings)}.')
+            warnings.warn(
+                'Using `set_output_embeddings` to set the embedding layer of ' +
+                'MPTForCausalLM with tied weights. Given weights are tied, ' +
+                'using `set_input_embeddings` is recommended over using ' +
+                '`set_output_embeddings`.')
+            self.transformer.set_input_embeddings(new_embeddings)

+    def tie_weights(self) -> None:
+        self.lm_head = None

def set_decoder(self, decoder: MPTModel) -> None:
self.transformer = decoder
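Taken together, the accessors above follow Hugging Face's embedding contract: with tied weights, lm_head stays None, the shared wte is reported as the output embedding, and the tie_weights override simply clears any lm_head; with untied weights, the accessors operate on the dedicated nn.Linear head. A rough sketch of the resulting behavior, assuming plain construction through the constructor (loading via from_pretrained may invoke tie_weights again, depending on the transformers version); the toy hyperparameters and the attn_impl override are illustrative only:

import torch.nn as nn
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

kwargs = dict(d_model=128, n_heads=2, n_layers=2, max_seq_len=256,
              vocab_size=50368, attn_config={'attn_impl': 'torch'})

tied = MPTForCausalLM(MPTConfig(**kwargs))
# Tied: the reported output embedding is the shared input embedding.
assert tied.get_output_embeddings() is tied.get_input_embeddings()

untied = MPTForCausalLM(MPTConfig(tie_word_embeddings=False, **kwargs))
# Untied: the accessors operate on the dedicated linear head instead.
assert isinstance(untied.get_output_embeddings(), nn.Linear)
assert untied.get_output_embeddings() is untied.lm_head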
@@ -658,12 +684,14 @@ def forward(
use_cache=use_cache,
)

-        # move outputs to same device as weights for token embedding
-        # needed to support HF `device_map`
-        logits = self.transformer.wte(
-            outputs.last_hidden_state.to(self.transformer.wte.weight.device),
-            True,
-        )
+        if self.lm_head is not None:
+            logits = self.lm_head(outputs.last_hidden_state)
+        else:
+            # move outputs to same device as weights for token embedding
+            # needed to support HF `device_map`
+            out = outputs.last_hidden_state
+            out = out.to(self.transformer.wte.weight.device)
+            logits = self.transformer.wte(out, True)

if self.logit_scale is not None:
if self.logit_scale == 0:
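In the tied branch above, the second argument passed to self.transformer.wte switches the shared embedding from a token lookup to an unembedding projection that reuses the same weight matrix. A minimal, generic re-implementation of that pattern (for illustration only; not llm-foundry's exact SharedEmbedding class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedEmbedding(nn.Embedding):
    """Embedding whose weight doubles as the output projection."""

    def forward(self, x: torch.Tensor, unembed: bool = False) -> torch.Tensor:
        if unembed:
            # (batch, seq, d_model) times weight^T -> (batch, seq, vocab_size)
            return F.linear(x, self.weight)
        return super().forward(x)

wte = TiedEmbedding(num_embeddings=50368, embedding_dim=128)
tokens = torch.randint(0, 50368, (2, 16))
hidden = wte(tokens)        # lookup: (2, 16, 128)
logits = wte(hidden, True)  # unembedding: (2, 16, 50368)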
@@ -859,7 +887,11 @@ def flops_per_batch(self, batch: Mapping) -> int:
# assume the backward pass is approximately 2x the forward pass

bs, msl = batch['input_ids'].shape[0:2]
-        params_flops_per_token = 2 * self.n_active_params
+        params = self.n_active_params
+        if not self.model.transformer.config.tie_word_embeddings:
+            # embedding layers are lookup tables, therefore are not counted in the FLOP computation
+            params -= self.model.transformer.wte.weight.numel()
+        params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
(self.model.config.d_model * (msl**2)))
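The reasoning behind the subtraction: with tied weights the single wte matrix also serves as the output projection, so counting it once in n_active_params is correct; with untied weights the extra embedding copy is only a lookup table and contributes no matmul FLOPs, so it is removed while the lm_head copy stays counted. A standalone sketch of the estimate this produces, where the function name and the final combination are illustrative and the factor of 3 follows the assumption noted above that the backward pass is approximately 2x the forward pass:

def mpt_flops_per_batch(n_active_params: int, wte_numel: int,
                        tie_word_embeddings: bool, n_layers: int,
                        d_model: int, batch_size: int, seq_len: int) -> int:
    params = n_active_params
    if not tie_word_embeddings:
        # The separate input embedding is a lookup table: no matmul FLOPs.
        params -= wte_numel
    params_flops_per_token = 2 * params
    params_flops_per_seq = params_flops_per_token * seq_len
    # Attention score and context matmuls: 2 matmuls, 2 FLOPs per MAC.
    attn_flops_per_seq = n_layers * 2 * 2 * (d_model * seq_len**2)
    # Forward cost, times 3 to approximate forward plus backward.
    return (params_flops_per_seq + attn_flops_per_seq) * 3 * batch_size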
41 changes: 29 additions & 12 deletions tests/test_hf_conversion_script.py
@@ -248,20 +248,21 @@ def test_callback_inits_with_defaults():

@pytest.mark.world_size(2)
@pytest.mark.gpu
-@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
+@pytest.mark.parametrize(
+    'model,tie_word_embeddings',
+    [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
+)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
@patch('os.cpu_count', MagicMock(return_value=None))
-def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
-                                         fsdp_state_dict_type: Optional[str],
-                                         log_to_mlflow: bool,
-                                         hf_save_interval: str,
-                                         save_interval: str, max_duration: str,
-                                         expected_hf_checkpoints: int,
-                                         expected_normal_checkpoints: int):
+def test_huggingface_conversion_callback(
+        model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool,
+        fsdp_state_dict_type: Optional[str], log_to_mlflow: bool,
+        hf_save_interval: str, save_interval: str, max_duration: str,
+        expected_hf_checkpoints: int, expected_normal_checkpoints: int):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
@@ -298,9 +299,11 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
'attn_impl': 'torch',
},
'loss_fn': 'torch_crossentropy',
+        'tie_word_embeddings': tie_word_embeddings,
}
tokenizer_name = 'EleutherAI/gpt-neox-20b'
elif model == 'neo':
+        assert tie_word_embeddings is None
model_cfg = {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M',
@@ -313,6 +316,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
}
tokenizer_name = 'EleutherAI/gpt-neo-125M'
elif model == 'llama2':
+        assert tie_word_embeddings is None
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -489,19 +493,26 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
delete_transformers_cache()


-@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
-def test_convert_and_generate(model: str, tmp_path: pathlib.Path):
+@pytest.mark.parametrize(
+    'model,tie_word_embeddings',
+    [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
+)
+def test_convert_and_generate(model: str, tie_word_embeddings: bool,
+                              tmp_path: pathlib.Path):
delete_transformers_cache()

om_cfg = None
if model == 'mpt':
om_cfg = get_config(
conf_path='scripts/train/yamls/pretrain/testing.yaml')
+        om_cfg['tie_word_embeddings'] = tie_word_embeddings
elif model == 'neo':
+        assert tie_word_embeddings is None
om_cfg = get_config(
conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml')
om_cfg['model']['config_overrides']['hidden_size'] = 36
elif model == 'llama2':
+        assert tie_word_embeddings is None
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -562,11 +573,14 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path):


@pytest.mark.gpu
-def test_convert_and_generate_triton(tmp_path: pathlib.Path):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_convert_and_generate_triton(tie_word_embeddings: str,
+                                     tmp_path: pathlib.Path):
delete_transformers_cache()

cfg = get_config()
cfg['model']['init_device'] = 'cpu'
+    cfg['tie_word_embeddings'] = tie_word_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
'EleutherAI/gpt-neox-20b')
model = ComposerMPTCausalLM(cfg['model'], tokenizer)
@@ -602,7 +616,9 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
delete_transformers_cache()


-def test_convert_and_generate_meta(tmp_path: pathlib.Path):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_convert_and_generate_meta(tie_word_embeddings: str,
+                                   tmp_path: pathlib.Path):
delete_transformers_cache()

from composer.utils import dist
@@ -612,6 +628,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')

om_cfg['model']['init_device'] = 'cpu'
+    om_cfg['tie_word_embeddings'] = tie_word_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
om_cfg.tokenizer.name)
original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](