Enable tie_word_embeddings config setting to enable / disable weight tied embeddings #728

Merged · 20 commits · Nov 13, 2023

Changes from 11 commits
4 changes: 4 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -59,6 +59,7 @@ def __init__(
use_cache: bool = False,
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
tie_word_embeddings: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -128,6 +129,7 @@ def __init__(
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -164,6 +166,8 @@ def __init__(
warnings.warn(
f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
)
# tie_word_embeddings is set in Huggingface's PretrainedConfig __init__
kwargs['tie_word_embeddings'] = tie_word_embeddings
super().__init__(**kwargs)

self._validate_config()
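For context, a minimal usage sketch of the new config flag (the variable names and the small model sizes are illustrative; `attn_config={'attn_impl': 'torch'}` is only there to keep the example CPU-friendly). The kwarg is forwarded to Hugging Face's `PretrainedConfig.__init__`, so the usual HF config machinery sees it:

```python
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

common = dict(d_model=128, n_heads=4, n_layers=2, vocab_size=50368,
              attn_config={'attn_impl': 'torch'})

# Default: input and output embeddings stay tied, as before this PR.
tied_cfg = MPTConfig(**common)
assert tied_cfg.tie_word_embeddings is True

# Opt out of weight tying; PretrainedConfig stores the flag, so it is
# serialized by save_pretrained and respected by HF's tie-weights hooks.
untied_cfg = MPTConfig(**common, tie_word_embeddings=False)
assert untied_cfg.tie_word_embeddings is False
```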
64 changes: 44 additions & 20 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -231,10 +231,11 @@ def __init__(self, config: MPTConfig):
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

def get_input_embeddings(self) -> nn.Embedding:
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
return self.wte

def set_input_embeddings(self, value: nn.Embedding) -> None:
def set_input_embeddings(
self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
self.wte = value

@torch.no_grad()
@@ -574,14 +575,20 @@ class MPTForCausalLM(MPTPreTrainedModel):

def __init__(self, config: MPTConfig):
super().__init__(config)
if not config.tie_word_embeddings:
raise ValueError(
'MPTForCausalLM only supports tied word embeddings')

log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer: MPTModel = MPTModel(config)

self.lm_head = None
if config.tie_word_embeddings is False:
self.lm_head = nn.Linear(
config.d_model,
config.vocab_size,
bias=False,
device=config.init_device,
)
self.lm_head._fsdp_wrap = True

for child in self.transformer.children():
if isinstance(child, torch.nn.ModuleList):
continue
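The constructor change above replaces the old hard error with an optional, untied head. A rough sketch of the resulting objects, reusing the hypothetical `tied_cfg` / `untied_cfg` from the earlier example:

```python
import torch
from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM

tied_model = MPTForCausalLM(tied_cfg)      # tie_word_embeddings=True (default)
assert tied_model.lm_head is None          # logits will come from the shared wte

untied_model = MPTForCausalLM(untied_cfg)  # tie_word_embeddings=False
# A standalone projection with its own weights, marked for FSDP wrapping.
assert isinstance(untied_model.lm_head, torch.nn.Linear)
assert untied_model.lm_head.weight.shape == (untied_cfg.vocab_size, untied_cfg.d_model)
assert untied_model.lm_head.weight is not untied_model.transformer.wte.weight
```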
@@ -602,19 +609,30 @@ def __init__(self, config: MPTConfig):
)
self.logit_scale = logit_scale

def get_input_embeddings(self) -> nn.Embedding:
return self.transformer.wte
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
return self.transformer.get_input_embeddings()

def set_input_embeddings(
self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
self.transformer.wte = value
self.transformer.set_input_embeddings(value)

def get_output_embeddings(self) -> nn.Embedding:
return self.transformer.wte
def get_output_embeddings(
self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
return self.lm_head or self.transformer.get_input_embeddings()

def set_output_embeddings(
self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
self.transformer.wte = new_embeddings
self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
nn.Linear]) -> None:
if self.lm_head is not None:
self.lm_head = new_embeddings
else:
assert isinstance(new_embeddings, (SharedEmbedding, nn.Embedding))
self.transformer.set_input_embeddings(new_embeddings)

def tie_weights(self) -> None:
if self.lm_head is not None:
del self.lm_head
self.lm_head = None

def set_decoder(self, decoder: MPTModel) -> None:
self.transformer = decoder
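Taken together, the accessors and `tie_weights` keep Hugging Face's embedding-resizing and weight-tying hooks consistent with the optional head. A sketch of the expected behaviour, continuing the hypothetical models from the previous example:

```python
# Untied: the output embedding is the standalone head.
assert untied_model.get_output_embeddings() is untied_model.lm_head

# Tied: there is no separate head, so the output embedding falls back to the
# shared input table, and tie_weights() simply drops any stale lm_head.
assert tied_model.get_output_embeddings() is tied_model.transformer.wte
tied_model.tie_weights()
assert tied_model.lm_head is None
```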
@@ -658,12 +676,14 @@ def forward(
use_cache=use_cache,
)

# move outputs to same device as weights for token embedding
# needed to support HF `device_map`
logits = self.transformer.wte(
outputs.last_hidden_state.to(self.transformer.wte.weight.device),
True,
)
if self.lm_head is not None:
logits = self.lm_head(outputs.last_hidden_state)
else:
# move outputs to same device as weights for token embedding
# needed to support HF `device_map`
out = outputs.last_hidden_state
out = out.to(self.transformer.wte.weight.device)
logits = self.transformer.wte(out, True)

if self.logit_scale is not None:
if self.logit_scale == 0:
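The branch above is the heart of the change: with a head, logits are a plain linear projection; without one, `SharedEmbedding` (this repository's `nn.Embedding` subclass) reuses the input embedding table as the unembedding when called with `unembed=True`. A shape-only sketch of the two equivalent computations, using random stand-in weights:

```python
import torch
import torch.nn.functional as F

bsz, seqlen, d_model, vocab = 2, 8, 16, 32
hidden = torch.randn(bsz, seqlen, d_model)

# Untied path: a dedicated projection.
lm_head = torch.nn.Linear(d_model, vocab, bias=False)
logits_untied = lm_head(hidden)

# Tied path: project the hidden states against the embedding table itself,
# which is what SharedEmbedding's unembed=True path does internally.
wte_weight = torch.randn(vocab, d_model)   # stand-in for transformer.wte.weight
logits_tied = F.linear(hidden, wte_weight)

assert logits_untied.shape == logits_tied.shape == (bsz, seqlen, vocab)
```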
@@ -859,7 +879,11 @@ def flops_per_batch(self, batch: Mapping) -> int:
# assume the backward pass is approximately 2x the forward pass

bs, msl = batch['input_ids'].shape[0:2]
params_flops_per_token = 2 * self.n_active_params
params = self.n_active_params
if self.model.transformer.config.tie_word_embeddings is False:
# embedding layers are lookup tables, therefore are not counted in the FLOP computation
params -= self.model.transformer.wte.weight.numel()
params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
(self.model.config.d_model * (msl**2)))
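The adjustment above only subtracts the embedding table when it is not reused as the output projection: untied, `wte` is a pure lookup and contributes no matmul FLOPs (the separate `lm_head` is already counted), while tied, the same matrix performs the output matmul and stays in the count. A small illustrative recomputation with made-up numbers; the last line reflects how the surrounding method combines the terms (forward plus an assumed 2x backward, times batch size), which is not shown in this hunk:

```python
# All numbers below are illustrative, not taken from the repository.
n_active_params = 1_000_000_000          # hypothetical active parameter count
wte_params = 50_368 * 4_096              # hypothetical vocab_size * d_model
bs, msl = 8, 2048                        # batch size, max sequence length
n_layers, d_model = 32, 4_096
tie_word_embeddings = False

params = n_active_params
if not tie_word_embeddings:
    # The untied embedding is a lookup table, so it does no matmul FLOPs.
    params -= wte_params

params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = n_layers * 2 * 2 * (d_model * msl ** 2)
# forward pass + approximately 2x for the backward pass
flops_per_batch = (params_flops_per_seq + attn_flops_per_seq) * 3 * bs
```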