From 877d80e98de4e5f362852d8f51118a2f9649f10c Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Fri, 21 Jun 2024 16:27:22 -0700
Subject: [PATCH] add docstring

---
 llmfoundry/models/mpt/configuration_mpt.py | 21 +++++++++++++++-
 llmfoundry/models/mpt/modeling_mpt.py      | 29 ++++++++++++++--------
 2 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 7ece197973..cec1fbbb79 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -118,6 +118,23 @@ def __init__(
             also be a dictionary that specifies the fc layer name and any kwargs for the fc layer.
         tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
         use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
+        block_overrides (Optional[Dict[str, Any]]): Allows for overriding the default block configs for certain layers. This should contain two sub configs: order and overrides. order specifies the order of the different kinds of layers (default refers to a layer that does not apply any overrides), and for each kind of layer, the overrides are specified in the overrides config.
+            E.g.:
+            block_overrides:
+                order:
+                - - default
+                  - 1
+                - - reuse_kv_layer
+                  - 1
+                - - sliding_window_layer
+                  - 1
+                overrides:
+                    sliding_window_layer:
+                        attn_config:
+                            sliding_window_size: 128
+                    reuse_kv_layer:
+                        attn_config:
+                            reuse_kv_layer_idx: -2
         """
         self.d_model = d_model
         self.n_heads = n_heads
@@ -145,7 +162,9 @@ def __init__(
         self.init_config = init_config if init_config is not None else copy.deepcopy(
             init_config_defaults,
         )
-        self.block_overrides = block_overrides if block_overrides is not None else {'order': [['default', 1], ], }
+        self.block_overrides = block_overrides if block_overrides is not None else {
+            'order': [['default', 1],],
+        }
 
         if isinstance(fc_type, str):
             fc_type = {'name': fc_type}
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index ada644e9ba..cdda00f7ff 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -445,8 +445,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
         self.kv_cache_layers = set()
 
         attn_modules_order_expanded = []
-        for (name,
-             repetitions) in config['block_overrides']['order']:
+        for (name, repetitions) in config['block_overrides']['order']:
             if not isinstance(repetitions, int) or repetitions < 1:
                 raise ValueError('repetitions should be a positive integer.')
             attn_modules_order_expanded.extend([name] * repetitions)
@@ -469,12 +468,17 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
                 override_attn_config = override_config.get('attn_config', None)
                 if override_attn_config is not None and 'reuse_kv_layer_idx' in override_attn_config:
                     if override_attn_config['reuse_kv_layer_idx'] >= 0:
-                        raise ValueError(f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.')
+                        raise ValueError(
+                            f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.',
+                        )
                     reuse_kv_layer_idx = i + override_attn_config[
                         'reuse_kv_layer_idx']
                     if reuse_kv_layer_idx < 0:
-                        raise ValueError(f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.')
-                    override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx
+                        raise ValueError(
+                            f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx}, should be non-negative.',
+                        )
+                    override_attn_config['reuse_kv_layer_idx'
+                                        ] = reuse_kv_layer_idx
                     self.kv_cache_layers.add(reuse_kv_layer_idx)
 
             new_block_args = self.override_block_args(
@@ -498,12 +502,14 @@ def override_block_args(
         common_keys = override_config.keys() & block_args.keys()
         for k in common_keys:
             if type(override_config[k]) != type(block_args[k]):
-                raise ValueError(f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.')
+                raise ValueError(
+                    f'Override config should have the same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.',
+                )
             if isinstance(override_config[k], dict):
                 new_block_args = self.override_block_args(
-                    block_args[k],
-                    override_config[k],
-                )
+                        block_args[k],
+                        override_config[k],
+                    )
             else:
                 new_block_args = override_config[k]
         return new_block_args
@@ -786,7 +792,10 @@ def forward(
         layer_kv_cache_dict = {}
         prev_layer_key_value = None
         for b_idx, block in enumerate(self.blocks):
-            attn_block = block.attn if hasattr(block, 'attn') else block.norm_attn_norm.attn
+            attn_block = block.attn if hasattr(
+                block,
+                'attn',
+            ) else block.norm_attn_norm.attn
             if attn_block.reuse_kv_layer_idx is not None:
                 if block.reuse_kv_layer_idx not in layer_kv_cache_dict:
                     raise KeyError(
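
For reference, the YAML example in the new block_overrides docstring corresponds to the Python structure below. This is a minimal illustrative sketch, not part of the patch: the dict literal mirrors the docstring example, the expansion loop mirrors the construct_blocks hunk above, and the commented MPTConfig call assumes the constructor's remaining arguments have defaults.

# Illustrative sketch (not part of the patch): the docstring's YAML example as
# the equivalent Python dict that would be passed as block_overrides.
block_overrides = {
    'order': [
        ['default', 1],
        ['reuse_kv_layer', 1],
        ['sliding_window_layer', 1],
    ],
    'overrides': {
        'sliding_window_layer': {
            'attn_config': {
                'sliding_window_size': 128,
            },
        },
        'reuse_kv_layer': {
            'attn_config': {
                'reuse_kv_layer_idx': -2,
            },
        },
    },
}

# Mirror of the expansion in construct_blocks: each [name, repetitions] pair
# contributes `repetitions` consecutive layers of that kind.
attn_modules_order_expanded = []
for name, repetitions in block_overrides['order']:
    if not isinstance(repetitions, int) or repetitions < 1:
        raise ValueError('repetitions should be a positive integer.')
    attn_modules_order_expanded.extend([name] * repetitions)
print(attn_modules_order_expanded)
# ['default', 'reuse_kv_layer', 'sliding_window_layer']

# reuse_kv_layer_idx is relative: for a layer at index i, the patch converts
# it to the absolute index i + reuse_kv_layer_idx (e.g. i=4 with -2 reuses the
# kv cache of layer 2), and raises if the absolute index is negative.

# Hypothetical usage, assuming MPTConfig's other arguments have defaults:
# from llmfoundry.models.mpt.configuration_mpt import MPTConfig
# config = MPTConfig(block_overrides=block_overrides)

Encoding the order as [name, repetitions] pairs keeps the YAML compact while still letting construct_blocks derive a flat per-layer list of block kinds.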