diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index f6b36a3f94..d51558f04d 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -369,14 +369,7 @@ def __init__(self, config: MPTConfig):
         self.mb_args = None
         self.shift_labels = True

-        block_args = self.extract_block_args(config.to_dict())
-
-        self.blocks = nn.ModuleList([
-            MPTBlock(
-                device=config.init_device,
-                **block_args,
-            ) for _ in range(config.n_layers)
-        ])
+        self.blocks = self.construct_blocks(config=config,)

         # Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx
         for i, block in enumerate(self.blocks):
@@ -438,6 +431,24 @@ def __init__(self, config: MPTConfig):
         log.debug(self)
         log.debug(f'Using {self.config.init_config["name"]} initialization.')

+    def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
+        """Construct the nn.ModuleList with the Transformer blocks.
+
+        Args:
+            config (MPTConfig): The configuration object.
+
+        Returns:
+            nn.ModuleList: The list of Transformer blocks.
+        """
+        block_args = self.extract_block_args(config.to_dict())
+
+        return nn.ModuleList([
+            MPTBlock(
+                device=config.init_device,
+                **block_args,
+            ) for _ in range(config.n_layers)
+        ])
+
     def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]:
         """Sets the block args."""
         if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks:
@@ -788,7 +799,7 @@ def __init__(self, config: MPTConfig):
         super().__init__(config)
         log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

-        self.transformer: MPTModel = MPTModel(config)
+        self.transformer: MPTModel = self.backbone_model_class(config)

         self.lm_head = None
         if not config.tie_word_embeddings:
@@ -820,6 +831,10 @@ def __init__(self, config: MPTConfig):
                     )
             self.logit_scale = logit_scale

+    @property
+    def backbone_model_class(self) -> Type[MPTModel]:
+        return MPTModel
+
     def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
         return self.transformer.get_input_embeddings()

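
Note: below is a minimal sketch of how downstream code might use the two override points this diff introduces (construct_blocks on MPTModel and the backbone_model_class property on MPTForCausalLM). The NewMPTBlock, NewMPTModel, and NewMPTForCausalLM names are hypothetical, used only for illustration, and the import paths are assumed to match the current llm-foundry layout.

# Sketch only: not part of the diff. Shows how a subclass could plug in a
# custom block and backbone via the hooks added above.
from typing import Type

from torch import nn

from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM, MPTModel


class NewMPTBlock(MPTBlock):
    """Hypothetical custom Transformer block."""


class NewMPTModel(MPTModel):

    def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
        # Override the new hook to build the stack from the custom block class,
        # reusing extract_block_args for the shared per-block kwargs.
        block_args = self.extract_block_args(config.to_dict())
        return nn.ModuleList([
            NewMPTBlock(device=config.init_device, **block_args)
            for _ in range(config.n_layers)
        ])


class NewMPTForCausalLM(MPTForCausalLM):

    @property
    def backbone_model_class(self) -> Type[MPTModel]:
        # Point the causal-LM wrapper at the custom backbone; __init__ then
        # instantiates it via self.backbone_model_class(config).
        return NewMPTModel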