Modularize backbone class and block creation (#1229)

dakinggg authored May 23, 2024
1 parent 9cc945c commit 9120c27
Showing 1 changed file with 24 additions and 9 deletions.
llmfoundry/models/mpt/modeling_mpt.py: 24 additions & 9 deletions
@@ -369,14 +369,7 @@ def __init__(self, config: MPTConfig):
         self.mb_args = None
         self.shift_labels = True

-        block_args = self.extract_block_args(config.to_dict())
-
-        self.blocks = nn.ModuleList([
-            MPTBlock(
-                device=config.init_device,
-                **block_args,
-            ) for _ in range(config.n_layers)
-        ])
+        self.blocks = self.construct_blocks(config=config,)

         # Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx
         for i, block in enumerate(self.blocks):
@@ -438,6 +431,24 @@ def __init__(self, config: MPTConfig):
         log.debug(self)
         log.debug(f'Using {self.config.init_config["name"]} initialization.')

+    def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
+        """Construct the nn.ModuleList with the Transformer blocks.
+
+        Args:
+            config (MPTConfig): The configuration object.
+
+        Returns:
+            nn.ModuleList: The list of Transformer blocks.
+        """
+        block_args = self.extract_block_args(config.to_dict())
+
+        return nn.ModuleList([
+            MPTBlock(
+                device=config.init_device,
+                **block_args,
+            ) for _ in range(config.n_layers)
+        ])
+
     def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]:
         """Sets the block args."""
         if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks:
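
Factoring block creation into an overridable construct_blocks method lets a subclass swap in its own block type without copying the rest of __init__. A minimal sketch of what this hunk enables, using the names defined in modeling_mpt.py (MyBlock and MyMPTModel are hypothetical illustration names, not part of this commit):

    import torch.nn as nn

    class MyMPTModel(MPTModel):

        def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
            # Reuse the parent's config parsing, but instantiate a custom
            # block class (MyBlock is assumed to accept the same kwargs
            # as MPTBlock).
            block_args = self.extract_block_args(config.to_dict())
            return nn.ModuleList([
                MyBlock(device=config.init_device, **block_args)
                for _ in range(config.n_layers)
            ])
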
@@ -788,7 +799,7 @@ def __init__(self, config: MPTConfig):
         super().__init__(config)
         log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

-        self.transformer: MPTModel = MPTModel(config)
+        self.transformer: MPTModel = self.backbone_model_class(config)

         self.lm_head = None
         if not config.tie_word_embeddings:
Expand Down Expand Up @@ -820,6 +831,10 @@ def __init__(self, config: MPTConfig):
)
self.logit_scale = logit_scale

@property
def backbone_model_class(self) -> Type[MPTModel]:
return MPTModel

def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
return self.transformer.get_input_embeddings()

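
Similarly, exposing the backbone through a backbone_model_class property means a causal-LM subclass only has to override one property to use a different backbone, since __init__ now calls self.backbone_model_class(config) instead of hard-coding MPTModel. A sketch, reusing the hypothetical MyMPTModel from above:

    from typing import Type

    class MyMPTForCausalLM(MPTForCausalLM):

        @property
        def backbone_model_class(self) -> Type[MPTModel]:
            # MPTForCausalLM.__init__ instantiates this class as
            # self.transformer, so this one-line override is sufficient.
            return MyMPTModel
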