
Commit

updt with descriptive var names
vchiley committed Nov 9, 2023
1 parent c5391c3 commit 6f0eae3
Showing 2 changed files with 14 additions and 13 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -59,7 +59,7 @@ def __init__(
use_cache: bool = False,
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
- tie_embd: bool = True,
+ tie_word_embeddings: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -129,7 +129,7 @@ def __init__(
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
- tie_embd (bool): Whether to tie the input embedding and output layers.
+ tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -150,7 +150,7 @@
self.use_cache = use_cache
self.init_config = init_config
self.fc_type = fc_type
- self.tie_embd = tie_embd
+ self.tie_word_embeddings = tie_word_embeddings
if verbose is not None:
warnings.warn(
DeprecationWarning(
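Below is a minimal usage sketch (not part of the diff) showing the renamed config flag; the import path and the other constructor arguments are assumptions based on the files this commit touches.

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Build a small MPT config with untied input/output embeddings.
# `tie_word_embeddings` replaces the old `tie_embd` argument.
config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    vocab_size=50368,
    tie_word_embeddings=False,  # untied: the model builds a separate lm_head
)
model = MPTForCausalLM(config)  # no longer restricted to tied embeddings after this commit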
21 changes: 11 additions & 10 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -183,12 +183,13 @@ def __init__(self, config: MPTConfig):
])
self.norm_f = norm_class(config.d_model, device=config.init_device)

- self.unembed = None
- if config.tie_embd:
- self.unembed = nn.Linear(config.d_model,
+ self.lm_head = None
+ if config.tie_word_embeddings is False:
+ self.lm_head = nn.Linear(config.d_model,
config.vocab_size,
bias=False,
device=config.init_device)
+ self.lm_head._fsdp_wrap = True

self.rope = config.attn_config['rope']
self.rope_impl = None
@@ -581,10 +582,6 @@ class MPTForCausalLM(MPTPreTrainedModel):

def __init__(self, config: MPTConfig):
super().__init__(config)
- if not config.tie_word_embeddings:
- raise ValueError(
- 'MPTForCausalLM only supports tied word embeddings')

log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer: MPTModel = MPTModel(config)
@@ -666,8 +663,8 @@ def forward(
)

out = outputs.last_hidden_state.to(self.transformer.wte.weight.device)
- if self.transformer.unembed is not None:
- logits = self.transformer.unembed(out)
+ if self.transformer.lm_head is not None:
+ logits = self.transformer.lm_head(out)
else:
# move outputs to same device as weights for token embedding
# needed to support HF `device_map`
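For orientation, here is a hedged sketch (not from the repository) of the two logit paths this branch selects between; the tied path shown assumes the conventional reuse of the input embedding matrix as the output projection, since the actual tied-path call sits in the collapsed lines of this hunk.

from typing import Optional
import torch
import torch.nn.functional as F

def project_to_logits(hidden: torch.Tensor,
                      wte_weight: torch.Tensor,
                      lm_head: Optional[torch.nn.Linear] = None) -> torch.Tensor:
    # hidden: (batch, seq, d_model); wte_weight: (vocab_size, d_model)
    if lm_head is not None:
        # untied embeddings: a dedicated output projection produces the logits
        return lm_head(hidden)
    # tied embeddings: reuse the embedding matrix as the unembedding matrix
    return F.linear(hidden, wte_weight)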
@@ -867,7 +864,11 @@ def flops_per_batch(self, batch: Mapping) -> int:
# assume the backward pass is approximately 2x the forward pass

bs, msl = batch['input_ids'].shape[0:2]
- params_flops_per_token = 2 * self.n_active_params
+ params = self.n_active_params
+ if self.model.transformer.config.tie_word_embeddings is False:
+ # embedding layers are lookup tables, therefore are not counted in the FLOP computation
+ params -= self.model.transformer.wte.weight.numel()
+ params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
(self.model.config.d_model * (msl**2)))
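As a back-of-the-envelope check, the hunk above can be read as the following sketch; the final combination of the two per-sequence terms is an assumption, since the return statement sits in the collapsed lines below the hunk.

def estimate_flops_per_batch(n_active_params: int, wte_numel: int,
                             tie_word_embeddings: bool, n_layers: int,
                             d_model: int, bs: int, msl: int) -> int:
    params = n_active_params
    if not tie_word_embeddings:
        # an untied input embedding is a pure lookup table, so it adds no matmul FLOPs
        params -= wte_numel
    params_flops_per_token = 2 * params
    params_flops_per_seq = params_flops_per_token * msl
    attn_flops_per_seq = n_layers * 2 * 2 * (d_model * (msl ** 2))
    # assumed combination: backward pass ~2x forward pass, hence the factor of 3
    return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs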
