Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Jun 21, 2024
1 parent fcc28a1 commit 877d80e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
21 changes: 20 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ def __init__(
also be a dictionary that specifies the fc layer name and any kwargs for the fc layer.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
block_overrides: This allows for overriding the default block configs for certain layers. It should contain two sub-configs: order and overrides. order specifies the order of the different kinds of layers as a list of [name, repetitions] pairs (default refers to a layer that does not apply any overrides). For each kind of layer, specify the overrides in the overrides config.
E.g.:
block_overrides:
    order:
    - - default
      - 1
    - - reuse_kv_layer
      - 1
    - - sliding_window_layer
      - 1
    overrides:
        sliding_window_layer:
            attn_config:
                sliding_window_size: 128
        reuse_kv_layer:
            attn_config:
                reuse_kv_layer_idx: -2
"""
self.d_model = d_model
self.n_heads = n_heads
Expand Down Expand Up @@ -145,7 +162,9 @@ def __init__(
self.init_config = init_config if init_config is not None else copy.deepcopy(
init_config_defaults,
)
self.block_overrides = block_overrides if block_overrides is not None else {'order': [['default', 1], ], }
self.block_overrides = block_overrides if block_overrides is not None else {
'order': [['default', 1],],
}

if isinstance(fc_type, str):
fc_type = {'name': fc_type}
Expand Down
29 changes: 19 additions & 10 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
self.kv_cache_layers = set()

attn_modules_order_expanded = []
for (name,
repetitions) in config['block_overrides']['order']:
for (name, repetitions) in config['block_overrides']['order']:
if not isinstance(repetitions, int) or repetitions < 1:
raise ValueError('repetitions should be a positive integer.')
attn_modules_order_expanded.extend([name] * repetitions)
Expand All @@ -469,12 +468,17 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
override_attn_config = override_config.get('attn_config', None)
if override_attn_config is not None and 'reuse_kv_layer_idx' in override_attn_config:
if override_attn_config['reuse_kv_layer_idx'] >= 0:
raise ValueError(f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.')
raise ValueError(
f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.',
)
reuse_kv_layer_idx = i + override_attn_config[
'reuse_kv_layer_idx']
if reuse_kv_layer_idx < 0:
raise ValueError(f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.')
override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx
raise ValueError(
f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.',
)
override_attn_config['reuse_kv_layer_idx'
] = reuse_kv_layer_idx
self.kv_cache_layers.add(reuse_kv_layer_idx)

new_block_args = self.override_block_args(
Expand All @@ -498,12 +502,14 @@ def override_block_args(
common_keys = override_config.keys() & block_args.keys()
for k in common_keys:
if type(override_config[k]) != type(block_args[k]):
raise ValueError(f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.')
raise ValueError(
f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.',
)
if isinstance(override_config[k], dict):
new_block_args = self.override_block_args(
block_args[k],
override_config[k],
)
block_args[k],
override_config[k],
)
else:
new_block_args = override_config[k]
return new_block_args
Expand Down Expand Up @@ -786,7 +792,10 @@ def forward(
layer_kv_cache_dict = {}
prev_layer_key_value = None
for b_idx, block in enumerate(self.blocks):
attn_block = block.attn if hasattr(block, 'attn') else block.norm_attn_norm.attn
attn_block = block.attn if hasattr(
block,
'attn',
) else block.norm_attn_norm.attn
if attn_block.reuse_kv_layer_idx is not None:
if block.reuse_kv_layer_idx not in layer_kv_cache_dict:
raise KeyError(
Expand Down

0 comments on commit 877d80e

Please sign in to comment.