From 880447232661651e9cbb35ac03ca81b175b65d96 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 13:07:57 -0700 Subject: [PATCH 01/69] [WIP] Allows interweaving of arbitrary kinds of 'attention' layers, like RNN, sliding window etc. --- llmfoundry/models/layers/attention.py | 15 ++- llmfoundry/models/layers/blocks.py | 12 +++ llmfoundry/models/mpt/modeling_mpt.py | 101 +++++++++++++++++++-- llmfoundry/models/utils/config_defaults.py | 3 + 4 files changed, 123 insertions(+), 8 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9b34190edf..4a976e70c8 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -416,6 +416,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_prev_layer_kv: bool = False, ): super().__init__() @@ -428,6 +429,7 @@ def __init__( self.n_heads = n_heads self.kv_n_heads = kv_n_heads self.sliding_window_size = sliding_window_size + self.reuse_prev_layer_kv = reuse_prev_layer_kv self.head_dim = d_model // n_heads @@ -507,9 +509,11 @@ def forward( needs_weights: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: - query, key, value = self.get_qkv(x) + query, key, value = self.get_qkv(x, prev_layer_key_value) if rotary_emb_w_meta_info is not None: query, key, value = self._apply_rotary_embeddings( @@ -546,6 +550,7 @@ def forward( def get_qkv( self, x: torch.Tensor, + prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Computes and returns the query, key, and value tensors. @@ -582,6 +587,14 @@ def get_qkv( query = self.q_ln(query).to(dtype).view(q_shape) key = self.k_ln(key).to(dtype).view(k_shape) + if self.reuse_prev_layer_kv: + # TODO: We still compute key and values in the code above, even if we end up reusing previous layer's kv cache. We should avoid this wasteful computation. + if prev_layer_key_value is None: + raise ValueError( + 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', + ) + key, value = prev_layer_key_value # TODO: We should not even compute key, value for this layer if we are just reusing prev layer's. Also, W_qkv should just be W_q for this case. 
+ return query, key, value def _apply_rotary_embeddings( diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 59aa497b78..0a9829d3b1 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -48,6 +48,10 @@ def __init__( use_pad_tok_in_ffn: bool = True, **kwargs: Any, ): + self.reuse_kv_layer_idx = None if attn_config is None else attn_config.get( + 'reuse_kv_layer_idx', + None, + ) if attn_config is None: attn_config = attn_config_defaults @@ -145,6 +149,8 @@ def args_to_exclude_in_attn_class(self): 'rope_impl', 'rope_dail_config', 'rope_hf_config', + 'attention_modules', + 'reuse_kv_layer_idx', } def forward( @@ -158,6 +164,8 @@ def forward( output_attentions: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: if self.fuse_norm_attn_norm: @@ -171,6 +179,7 @@ def forward( output_attentions=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + prev_layer_key_value=prev_layer_key_value, ) else: a = self.norm_1(x) @@ -308,6 +317,8 @@ def forward( output_attentions: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -321,6 +332,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + prev_layer_key_value=prev_layer_key_value, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a7cfac1724..a2260966a9 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy import math import warnings from functools import cached_property @@ -440,13 +441,82 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. 
""" block_args = self.extract_block_args(config.to_dict()) + module_list = [] + self.kv_cache_layers = set() + + attn_modules_order_expanded = [] + for (name, + repetitions) in config.attn_config['attention_modules']['order']: + if name == 'order': + raise ValueError('The name of module cannot be "order".') + if not isinstance(repetitions, int) or repetitions < 1: + raise ValueError('repetitions should be a positive integer.') + attn_modules_order_expanded.extend([name] * repetitions) + if config.n_layers % len(attn_modules_order_expanded) != 0: + raise ValueError( + 'Number of layers should be divisible by attention modules order.', + ) + attn_modules_order_expanded = attn_modules_order_expanded * ( + config.n_layers // len(attn_modules_order_expanded) + ) - return nn.ModuleList([ - self.block_class( - device=config.init_device, - **block_args, - ) for _ in range(config.n_layers) - ]) + for i in range(config.n_layers): + module_name = attn_modules_order_expanded[i] + + override_config = {} + if module_name != 'default': + override_config = copy.deepcopy( + config.attn_config['attention_modules'][module_name], + ) + if 'reuse_kv_layer_idx' in override_config: + assert override_config['reuse_kv_layer_idx'] < 0 + reuse_kv_layer_idx = i + override_config['reuse_kv_layer_idx'] + assert reuse_kv_layer_idx >= 0 + override_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx + self.kv_cache_layers.add(reuse_kv_layer_idx) + + orig_config, new_keys = self.update_block_args( + block_args, + override_config, + ) + module_list.append( + MPTBlock( + device=config.init_device, + **block_args, + ), + ) + self.reset_block_args( + block_args, + orig_config, + new_keys, + ) + return nn.ModuleList(module_list) + + def update_block_args( + self, + block_args: Dict[str, Any], + override_config: Dict[str, Any], + ) -> Tuple[Dict[str, Any], List[str]]: + orig_config = {} + new_keys = [] + for k, v in override_config.items(): + if k in block_args['attn_config']: + orig_config[k] = block_args['attn_config'][k] + else: + new_keys.append(k) + block_args['attn_config'][k] = v + return orig_config, new_keys + + def reset_block_args( + self, + block_args: Dict[str, Any], + orig_config: Dict[str, Any], + new_keys: List[str], + ) -> None: + for k in new_keys: + del block_args['attn_config'][k] + for k, v in orig_config.items(): + block_args['attn_config'][k] = v def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: """Sets the block args.""" @@ -706,7 +776,7 @@ def forward( # initialize the past key values cache if it should be used presents = () if use_cache else None - if use_cache and past_key_values is None: + if past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) ] # type: ignore @@ -723,7 +793,12 @@ def forward( attention_mask, ) + layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): + if block.reuse_kv_layer_idx is not None: + if block.reuse_kv_layer_idx not in layer_kv_cache_dict: + raise KeyError(f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.') + prev_layer_key_value = layer_kv_cache_dict[block.reuse_kv_layer_idx] if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) @@ -740,9 +815,21 @@ def forward( output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + prev_layer_key_value=prev_layer_key_value, ) if presents is not None: presents += (present,) + if b_idx in self.kv_cache_layers: + 
if self.attn_impl != 'torch': + layer_kv_cache_dict[b_idx] = [ + present[0][:, past_position:], + present[1][:, past_position:], + ] + else: + layer_kv_cache_dict[b_idx] = [ + present[0][:, :, :, past_position:], + present[1][:, :, :, past_position:], + ] if output_attentions: assert all_self_attns is not None # pyright diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 2b6fc2f7c7..539962cf48 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -33,6 +33,9 @@ 'type': 'no_scaling', 'factor': 1.0, }, + 'attention_modules': { + 'order': [['default', 1]], + }, } init_config_defaults: Dict = { From a50755ec36c3646552256e68cf3595f453e2c246 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 13:12:53 -0700 Subject: [PATCH 02/69] lint --- llmfoundry/models/mpt/modeling_mpt.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a2260966a9..e2daca1e09 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -470,7 +470,8 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: ) if 'reuse_kv_layer_idx' in override_config: assert override_config['reuse_kv_layer_idx'] < 0 - reuse_kv_layer_idx = i + override_config['reuse_kv_layer_idx'] + reuse_kv_layer_idx = i + override_config[ + 'reuse_kv_layer_idx'] assert reuse_kv_layer_idx >= 0 override_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) @@ -794,11 +795,15 @@ def forward( ) layer_kv_cache_dict = {} + prev_layer_key_value = None for b_idx, block in enumerate(self.blocks): if block.reuse_kv_layer_idx is not None: if block.reuse_kv_layer_idx not in layer_kv_cache_dict: - raise KeyError(f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.') - prev_layer_key_value = layer_kv_cache_dict[block.reuse_kv_layer_idx] + raise KeyError( + f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.', + ) + prev_layer_key_value = layer_kv_cache_dict[ + block.reuse_kv_layer_idx] if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) From fcc28a10b26e661db304487e0247d293c1f02a7c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 16:21:15 -0700 Subject: [PATCH 03/69] applying overrides to blocks rather than just attentions --- llmfoundry/models/layers/attention.py | 6 +- llmfoundry/models/layers/blocks.py | 6 -- llmfoundry/models/mpt/configuration_mpt.py | 2 + llmfoundry/models/mpt/modeling_mpt.py | 66 +++++++++------------- llmfoundry/models/utils/config_defaults.py | 3 - 5 files changed, 33 insertions(+), 50 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 4a976e70c8..02f370280b 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -416,7 +416,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, - reuse_prev_layer_kv: bool = False, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__() @@ -429,7 +429,7 @@ def __init__( self.n_heads = n_heads self.kv_n_heads = kv_n_heads self.sliding_window_size = sliding_window_size - self.reuse_prev_layer_kv = reuse_prev_layer_kv + self.reuse_kv_layer_idx = reuse_kv_layer_idx self.head_dim = d_model // 
n_heads @@ -587,7 +587,7 @@ def get_qkv( query = self.q_ln(query).to(dtype).view(q_shape) key = self.k_ln(key).to(dtype).view(k_shape) - if self.reuse_prev_layer_kv: + if self.reuse_kv_layer_idx != None: # TODO: We still compute key and values in the code above, even if we end up reusing previous layer's kv cache. We should avoid this wasteful computation. if prev_layer_key_value is None: raise ValueError( diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 0a9829d3b1..706395c02a 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -48,10 +48,6 @@ def __init__( use_pad_tok_in_ffn: bool = True, **kwargs: Any, ): - self.reuse_kv_layer_idx = None if attn_config is None else attn_config.get( - 'reuse_kv_layer_idx', - None, - ) if attn_config is None: attn_config = attn_config_defaults @@ -149,8 +145,6 @@ def args_to_exclude_in_attn_class(self): 'rope_impl', 'rope_dail_config', 'rope_hf_config', - 'attention_modules', - 'reuse_kv_layer_idx', } def forward( diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 9205c0e505..7ece197973 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -48,6 +48,7 @@ def __init__( fc_type: Union[str, Dict] = 'torch', tie_word_embeddings: bool = True, use_pad_tok_in_ffn: bool = True, + block_overrides: Optional[Dict[str, Any]] = None, **kwargs: Any, ): """The MPT configuration class. @@ -144,6 +145,7 @@ def __init__( self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, ) + self.block_overrides = block_overrides if block_overrides is not None else {'order': [['default', 1], ], } if isinstance(fc_type, str): fc_type = {'name': fc_type} diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index e2daca1e09..ada644e9ba 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -446,9 +446,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: attn_modules_order_expanded = [] for (name, - repetitions) in config.attn_config['attention_modules']['order']: - if name == 'order': - raise ValueError('The name of module cannot be "order".') + repetitions) in config['block_overrides']['order']: if not isinstance(repetitions, int) or repetitions < 1: raise ValueError('repetitions should be a positive integer.') attn_modules_order_expanded.extend([name] * repetitions) @@ -466,58 +464,49 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: override_config = {} if module_name != 'default': override_config = copy.deepcopy( - config.attn_config['attention_modules'][module_name], + config['block_overrides']['overrides'][module_name], ) - if 'reuse_kv_layer_idx' in override_config: - assert override_config['reuse_kv_layer_idx'] < 0 - reuse_kv_layer_idx = i + override_config[ + override_attn_config = override_config.get('attn_config', None) + if override_attn_config is not None and 'reuse_kv_layer_idx' in override_attn_config: + if override_attn_config['reuse_kv_layer_idx'] >= 0: + raise ValueError(f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.') + reuse_kv_layer_idx = i + override_attn_config[ 'reuse_kv_layer_idx'] - assert reuse_kv_layer_idx >= 0 - override_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx + if reuse_kv_layer_idx < 0: + raise ValueError(f'The absolute index of kv layer to 
reuse, {reuse_kv_layer_idx} should be non-negative.') + override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) - orig_config, new_keys = self.update_block_args( + new_block_args = self.override_block_args( block_args, override_config, ) module_list.append( MPTBlock( device=config.init_device, - **block_args, + **new_block_args, ), ) - self.reset_block_args( - block_args, - orig_config, - new_keys, - ) return nn.ModuleList(module_list) - def update_block_args( + def override_block_args( self, block_args: Dict[str, Any], override_config: Dict[str, Any], - ) -> Tuple[Dict[str, Any], List[str]]: - orig_config = {} - new_keys = [] - for k, v in override_config.items(): - if k in block_args['attn_config']: - orig_config[k] = block_args['attn_config'][k] + ) -> Dict[str, Any]: + new_block_args = override_config | block_args + common_keys = override_config.keys() & block_args.keys() + for k in common_keys: + if type(override_config[k]) != type(block_args[k]): + raise ValueError(f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.') + if isinstance(override_config[k], dict): + new_block_args = self.override_block_args( + block_args[k], + override_config[k], + ) else: - new_keys.append(k) - block_args['attn_config'][k] = v - return orig_config, new_keys - - def reset_block_args( - self, - block_args: Dict[str, Any], - orig_config: Dict[str, Any], - new_keys: List[str], - ) -> None: - for k in new_keys: - del block_args['attn_config'][k] - for k, v in orig_config.items(): - block_args['attn_config'][k] = v + new_block_args = override_config[k] + return new_block_args def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: """Sets the block args.""" @@ -797,7 +786,8 @@ def forward( layer_kv_cache_dict = {} prev_layer_key_value = None for b_idx, block in enumerate(self.blocks): - if block.reuse_kv_layer_idx is not None: + attn_block = block.attn if hasattr(block, 'attn') else block.norm_attn_norm.attn + if attn_block.reuse_kv_layer_idx is not None: if block.reuse_kv_layer_idx not in layer_kv_cache_dict: raise KeyError( f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.', diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 539962cf48..2b6fc2f7c7 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -33,9 +33,6 @@ 'type': 'no_scaling', 'factor': 1.0, }, - 'attention_modules': { - 'order': [['default', 1]], - }, } init_config_defaults: Dict = { From 877d80e98de4e5f362852d8f51118a2f9649f10c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 16:27:22 -0700 Subject: [PATCH 04/69] add docstring --- llmfoundry/models/mpt/configuration_mpt.py | 21 +++++++++++++++- llmfoundry/models/mpt/modeling_mpt.py | 29 ++++++++++++++-------- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 7ece197973..cec1fbbb79 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -118,6 +118,23 @@ def __init__( also be a dictionary that specifies the fc layer name and any kwargs for the fc layer. tie_word_embeddings (bool): Whether to tie the input embedding and output layers. 
use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. + block_overrides: The allows for overriding default block configs for certain layers. This should contain two sub configs: order and overrides. order specifies the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the overrides in the overrides config. + Eg: + block_overrides: + order: + - - default + - 1 + - - reuse_kv_layer + - 1 + - - sliding_window_layer + - 1 + overrides: + sliding_window_layer: + attn_config: + sliding_window_size: 128 + reuse_kv_layer: + attn_config: + reuse_kv_layer_idx: -2 """ self.d_model = d_model self.n_heads = n_heads @@ -145,7 +162,9 @@ def __init__( self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, ) - self.block_overrides = block_overrides if block_overrides is not None else {'order': [['default', 1], ], } + self.block_overrides = block_overrides if block_overrides is not None else { + 'order': [['default', 1],], + } if isinstance(fc_type, str): fc_type = {'name': fc_type} diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index ada644e9ba..cdda00f7ff 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -445,8 +445,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: self.kv_cache_layers = set() attn_modules_order_expanded = [] - for (name, - repetitions) in config['block_overrides']['order']: + for (name, repetitions) in config['block_overrides']['order']: if not isinstance(repetitions, int) or repetitions < 1: raise ValueError('repetitions should be a positive integer.') attn_modules_order_expanded.extend([name] * repetitions) @@ -469,12 +468,17 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: override_attn_config = override_config.get('attn_config', None) if override_attn_config is not None and 'reuse_kv_layer_idx' in override_attn_config: if override_attn_config['reuse_kv_layer_idx'] >= 0: - raise ValueError(f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.') + raise ValueError( + f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', + ) reuse_kv_layer_idx = i + override_attn_config[ 'reuse_kv_layer_idx'] if reuse_kv_layer_idx < 0: - raise ValueError(f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.') - override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx + raise ValueError( + f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', + ) + override_attn_config['reuse_kv_layer_idx' + ] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) new_block_args = self.override_block_args( @@ -498,12 +502,14 @@ def override_block_args( common_keys = override_config.keys() & block_args.keys() for k in common_keys: if type(override_config[k]) != type(block_args[k]): - raise ValueError(f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.') + raise ValueError( + f'Override config should have same value types as the original config. 
Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', + ) if isinstance(override_config[k], dict): new_block_args = self.override_block_args( - block_args[k], - override_config[k], - ) + block_args[k], + override_config[k], + ) else: new_block_args = override_config[k] return new_block_args @@ -786,7 +792,10 @@ def forward( layer_kv_cache_dict = {} prev_layer_key_value = None for b_idx, block in enumerate(self.blocks): - attn_block = block.attn if hasattr(block, 'attn') else block.norm_attn_norm.attn + attn_block = block.attn if hasattr( + block, + 'attn', + ) else block.norm_attn_norm.attn if attn_block.reuse_kv_layer_idx is not None: if block.reuse_kv_layer_idx not in layer_kv_cache_dict: raise KeyError( From dd1c64b693d3bc764c2d674383dd588a0f9f8eae Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 16:34:54 -0700 Subject: [PATCH 05/69] minor --- llmfoundry/models/mpt/configuration_mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index cec1fbbb79..dda189f2cf 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -130,11 +130,11 @@ def __init__( - 1 overrides: sliding_window_layer: - attn_config: - sliding_window_size: 128 + attn_config: + sliding_window_size: 128 reuse_kv_layer: - attn_config: - reuse_kv_layer_idx: -2 + attn_config: + reuse_kv_layer_idx: -2 """ self.d_model = d_model self.n_heads = n_heads From fc1bf0b7e9393b073891f769b2bd0194edec1bb3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 17:54:33 -0700 Subject: [PATCH 06/69] changing yaml specification style --- llmfoundry/models/mpt/configuration_mpt.py | 41 +++++++++++++++------- llmfoundry/models/mpt/modeling_mpt.py | 37 +++++++++++++------ 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index dda189f2cf..4833fa35d7 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -119,22 +119,36 @@ def __init__( tie_word_embeddings (bool): Whether to tie the input embedding and output layers. use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. block_overrides: The allows for overriding default block configs for certain layers. This should contain two sub configs: order and overrides. order specifies the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the overrides in the overrides config. 
- Eg: + To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed: block_overrides: - order: - - - default - - 1 - - - reuse_kv_layer - - 1 - - - sliding_window_layer - - 1 + prefix: + - name: default + repeat: 1 + pattern: + - name: sliding_window_layer + repeat: 1 + - name: sliding_window_layer_reuse + repeat: 1 + - name: sliding_window_layer + repeat: 1 + - name: sliding_window_layer_reuse + repeat: 2 + - name: reuse_kv_layer + repeat: 1 + suffix: + - name: default + repeat: 0 overrides: sliding_window_layer: attn_config: - sliding_window_size: 128 + sliding_window_size: 1024 + sliding_window_layer_reuse: + attn_config: + sliding_window_size: 1024 + reuse_kv_layer_idx: -1 reuse_kv_layer: attn_config: - reuse_kv_layer_idx: -2 + reuse_kv_layer_idx: -6 """ self.d_model = d_model self.n_heads = n_heads @@ -163,8 +177,11 @@ def __init__( init_config_defaults, ) self.block_overrides = block_overrides if block_overrides is not None else { - 'order': [['default', 1],], - } + 'repeating_pattern': [{ + 'name': 'default', + 'repeat': 1, + }], + } # TODO: Raise warning if using block overrides that this is experimental, and the yaml design may change in the future. if isinstance(fc_type, str): fc_type = {'name': fc_type} diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cdda00f7ff..c0f409cf2a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -444,18 +444,33 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: module_list = [] self.kv_cache_layers = set() - attn_modules_order_expanded = [] - for (name, repetitions) in config['block_overrides']['order']: - if not isinstance(repetitions, int) or repetitions < 1: - raise ValueError('repetitions should be a positive integer.') - attn_modules_order_expanded.extend([name] * repetitions) - if config.n_layers % len(attn_modules_order_expanded) != 0: + modules_order_expanded = {} + for type in 'start', 'repeating_pattern', 'end': + modules_order_expanded[type] = [] + for block in config['block_overrides'][type]: + if not isinstance(block['repetitions'], + int) or block['repetitions'] < 1: + raise ValueError( + 'repetitions should be a positive integer.', + ) + modules_order_expanded[type].extend([block['name']] * + block['repetitions']) + + start_len = len(modules_order_expanded['start']) + repeating_pattern_len = len(modules_order_expanded['repeating_pattern']) + end_len = len(modules_order_expanded['end']) + + if ( + config.n_layers - (start_len + end_len) + ) % repeating_pattern_len != 0: raise ValueError( - 'Number of layers should be divisible by attention modules order.', + 'Number of layers should be divisible by the specified custom modules order.', ) - attn_modules_order_expanded = attn_modules_order_expanded * ( - config.n_layers // len(attn_modules_order_expanded) - ) + attn_modules_order_expanded = modules_order_expanded[ + 'start'] + modules_order_expanded['repeating_pattern'] * ( + (config.n_layers - + (start_len + end_len)) // repeating_pattern_len + ) + modules_order_expanded['end'] for i in range(config.n_layers): module_name = attn_modules_order_expanded[i] @@ -1323,6 +1338,8 @@ def flops_per_batch(self, batch: Mapping): # that the dataset has been constructed without padding. Additionally, we # assume the backward pass is approximately 2x the forward pass + # TODO: Raise warning and set to 0 if using mixed attention modules. 
+ bs, msl = batch['input_ids'].shape[0:2] params = self.n_active_params params_flops_per_token = 2 * params From 81e29305f7de14afa4910b5745e02f0eea53a163 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 17:55:21 -0700 Subject: [PATCH 07/69] .. --- llmfoundry/models/mpt/configuration_mpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 4833fa35d7..726342c142 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -121,10 +121,10 @@ def __init__( block_overrides: The allows for overriding default block configs for certain layers. This should contain two sub configs: order and overrides. order specifies the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the overrides in the overrides config. To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed: block_overrides: - prefix: + start: - name: default repeat: 1 - pattern: + repeating_pattern: - name: sliding_window_layer repeat: 1 - name: sliding_window_layer_reuse @@ -135,7 +135,7 @@ def __init__( repeat: 2 - name: reuse_kv_layer repeat: 1 - suffix: + end: - name: default repeat: 0 overrides: From 9b3d81384be39d1ebbaf43192593515ce606eddc Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 18:46:37 -0700 Subject: [PATCH 08/69] fixes --- llmfoundry/models/mpt/modeling_mpt.py | 37 ++++++++++++++------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c0f409cf2a..bc243acaac 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -441,20 +441,19 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. 
""" block_args = self.extract_block_args(config.to_dict()) - module_list = [] - self.kv_cache_layers = set() modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': modules_order_expanded[type] = [] - for block in config['block_overrides'][type]: - if not isinstance(block['repetitions'], - int) or block['repetitions'] < 1: - raise ValueError( - 'repetitions should be a positive integer.', - ) - modules_order_expanded[type].extend([block['name']] * - block['repetitions']) + if type in config['block_overrides']: + for block in config['block_overrides'][type]: + if not isinstance(block['repeat'], + int) or block['repeat'] < 1: + raise ValueError( + 'repeat should be a positive integer.', + ) + modules_order_expanded[type].extend([block['name']] * + block['repeat']) start_len = len(modules_order_expanded['start']) repeating_pattern_len = len(modules_order_expanded['repeating_pattern']) @@ -466,14 +465,15 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: raise ValueError( 'Number of layers should be divisible by the specified custom modules order.', ) - attn_modules_order_expanded = modules_order_expanded[ + model_modules_order_expanded = modules_order_expanded[ 'start'] + modules_order_expanded['repeating_pattern'] * ( - (config.n_layers - - (start_len + end_len)) // repeating_pattern_len - ) + modules_order_expanded['end'] + config.n_layers - (start_len + end_len) + ) // repeating_pattern_len + modules_order_expanded['end'] + self.kv_cache_layers = set() + module_list = [] for i in range(config.n_layers): - module_name = attn_modules_order_expanded[i] + module_name = model_modules_order_expanded[i] override_config = {} if module_name != 'default': @@ -521,12 +521,12 @@ def override_block_args( f'Override config should have same value types as the original config. 
Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', ) if isinstance(override_config[k], dict): - new_block_args = self.override_block_args( + new_block_args[k] = self.override_block_args( block_args[k], override_config[k], ) else: - new_block_args = override_config[k] + new_block_args[k] = override_config[k] return new_block_args def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: @@ -805,7 +805,6 @@ def forward( ) layer_kv_cache_dict = {} - prev_layer_key_value = None for b_idx, block in enumerate(self.blocks): attn_block = block.attn if hasattr( block, @@ -818,6 +817,8 @@ def forward( ) prev_layer_key_value = layer_kv_cache_dict[ block.reuse_kv_layer_idx] + else: + prev_layer_key_value = None if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) From aafcebbfdb925117fd922d322752bbcc288f9649 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 18:50:10 -0700 Subject: [PATCH 09/69] fix --- llmfoundry/models/mpt/modeling_mpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bc243acaac..768b515cc2 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -445,8 +445,8 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': modules_order_expanded[type] = [] - if type in config['block_overrides']: - for block in config['block_overrides'][type]: + if type in config.block_overrides: + for block in config.block_overrides[type]: if not isinstance(block['repeat'], int) or block['repeat'] < 1: raise ValueError( @@ -478,7 +478,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: override_config = {} if module_name != 'default': override_config = copy.deepcopy( - config['block_overrides']['overrides'][module_name], + config.block_overrides['overrides'][module_name], ) override_attn_config = override_config.get('attn_config', None) if override_attn_config is not None and 'reuse_kv_layer_idx' in override_attn_config: From b46756fc9c3302a75fe8d40a75a068b34d5434eb Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 19:32:46 -0700 Subject: [PATCH 10/69] fix --- llmfoundry/models/mpt/modeling_mpt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 768b515cc2..0af9640fe0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -467,8 +467,9 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: ) model_modules_order_expanded = modules_order_expanded[ 'start'] + modules_order_expanded['repeating_pattern'] * ( - config.n_layers - (start_len + end_len) - ) // repeating_pattern_len + modules_order_expanded['end'] + (config.n_layers - + (start_len + end_len)) // repeating_pattern_len + ) + modules_order_expanded['end'] self.kv_cache_layers = set() module_list = [] From ad6ba32041f5a8644684424bac08240b3a9835d4 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 21 Jun 2024 21:07:54 -0700 Subject: [PATCH 11/69] fix --- llmfoundry/models/layers/attention.py | 6 ++++-- llmfoundry/models/mpt/modeling_mpt.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/layers/attention.py 
b/llmfoundry/models/layers/attention.py index 02f370280b..c1b53e8518 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -516,12 +516,14 @@ def forward( query, key, value = self.get_qkv(x, prev_layer_key_value) if rotary_emb_w_meta_info is not None: - query, key, value = self._apply_rotary_embeddings( + query, key_rotated, value_rotated = self._apply_rotary_embeddings( rotary_emb_w_meta_info, query, key, value, ) + if self.reuse_kv_layer_idx is None: # TODO: We should not rotate if using prev layer kv, that wastes computation + key, value = key_rotated, value_rotated extra_attn_kwargs = self.get_implementation_specific_args( attention_mask, @@ -587,7 +589,7 @@ def get_qkv( query = self.q_ln(query).to(dtype).view(q_shape) key = self.k_ln(key).to(dtype).view(k_shape) - if self.reuse_kv_layer_idx != None: + if self.reuse_kv_layer_idx is not None: # TODO: We still compute key and values in the code above, even if we end up reusing previous layer's kv cache. We should avoid this wasteful computation. if prev_layer_key_value is None: raise ValueError( diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0af9640fe0..eaff673d61 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -448,9 +448,9 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: if type in config.block_overrides: for block in config.block_overrides[type]: if not isinstance(block['repeat'], - int) or block['repeat'] < 1: + int) or block['repeat'] < 0: raise ValueError( - 'repeat should be a positive integer.', + 'repeat should be a non-negative integer.', ) modules_order_expanded[type].extend([block['name']] * block['repeat']) @@ -812,12 +812,12 @@ def forward( 'attn', ) else block.norm_attn_norm.attn if attn_block.reuse_kv_layer_idx is not None: - if block.reuse_kv_layer_idx not in layer_kv_cache_dict: + if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict: raise KeyError( f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.', ) prev_layer_key_value = layer_kv_cache_dict[ - block.reuse_kv_layer_idx] + attn_block.reuse_kv_layer_idx] else: prev_layer_key_value = None if output_hidden_states: From 3ea79fd36ed6dd6077101e418ddd14b2eed7794a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 09:51:50 -0700 Subject: [PATCH 12/69] refactoring --- llmfoundry/models/layers/attention.py | 2 +- llmfoundry/models/mpt/configuration_mpt.py | 12 ++++++------ llmfoundry/models/mpt/modeling_mpt.py | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c1b53e8518..9bfbcffce3 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -595,7 +595,7 @@ def get_qkv( raise ValueError( 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', ) - key, value = prev_layer_key_value # TODO: We should not even compute key, value for this layer if we are just reusing prev layer's. Also, W_qkv should just be W_q for this case. 
+ key, value = prev_layer_key_value return query, key, value diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 726342c142..3f53a3df89 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -176,12 +176,12 @@ def __init__( self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, ) - self.block_overrides = block_overrides if block_overrides is not None else { - 'repeating_pattern': [{ - 'name': 'default', - 'repeat': 1, - }], - } # TODO: Raise warning if using block overrides that this is experimental, and the yaml design may change in the future. + + if block_overrides is not None: + warnings.warn( + 'Warning, this is an experimental feature. The YAML design may change in the future.', + ) + self.block_overrides = block_overrides if isinstance(fc_type, str): fc_type = {'name': fc_type} diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index eaff673d61..b290d33f64 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -442,6 +442,21 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: """ block_args = self.extract_block_args(config.to_dict()) + if config.block_overrides is not None: + return self._construct_blocks_with_overrides(config, block_args) + + return nn.ModuleList([ + self.block_class( + device=config.init_device, + **block_args, + ) for _ in range(config.n_layers) + ]) + + def _construct_blocks_with_overrides( + self, + config: MPTConfig, + block_args: Dict[str, Any], + ) -> nn.ModuleList: modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': modules_order_expanded[type] = [] @@ -507,6 +522,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: **new_block_args, ), ) + return nn.ModuleList(module_list) def override_block_args( From 13802cb9f37941641778a86f484a7e2de5c98ac1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 10:02:54 -0700 Subject: [PATCH 13/69] add warning --- llmfoundry/models/mpt/configuration_mpt.py | 2 +- llmfoundry/models/mpt/modeling_mpt.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 3f53a3df89..32819e4528 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -179,7 +179,7 @@ def __init__( if block_overrides is not None: warnings.warn( - 'Warning, this is an experimental feature. The YAML design may change in the future.', + 'Warning, block_overrides is an experimental feature. The YAML design may change in the future.', ) self.block_overrides = block_overrides diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b290d33f64..55886c9aa8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1356,7 +1356,11 @@ def flops_per_batch(self, batch: Mapping): # that the dataset has been constructed without padding. Additionally, we # assume the backward pass is approximately 2x the forward pass - # TODO: Raise warning and set to 0 if using mixed attention modules. + if self.model.config.block_overrides is not None: + warnings.warn( + 'Warning, flop computation is not supported when using block overrides. 
Returning 0 flops per batch.', + ) + return 0 bs, msl = batch['input_ids'].shape[0:2] params = self.n_active_params From 9b6ae9c2ff4918efa94637d397b9f94e7b865e8a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 10:24:21 -0700 Subject: [PATCH 14/69] compute only query vector when reusing kv --- llmfoundry/models/layers/attention.py | 79 ++++++++++++++++++--------- 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9bfbcffce3..01832fc823 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -460,18 +460,29 @@ def __init__( self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop - self.Wqkv = build_fc( - name=fc_type_name, - in_features=self.d_model, - out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, - fc_kwargs=fc_type, - ) - # for param init fn; enables shape based init of fused layers - fuse_splits = [ - i * self.head_dim - for i in range(1, self.n_heads + 2 * self.kv_n_heads) - ] - self.Wqkv._fused = (0, fuse_splits) + if self.reuse_kv_layer_idx is None: + self.Wqkv = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [ + i * self.head_dim + for i in range(1, self.n_heads + 2 * self.kv_n_heads) + ] + self.Wqkv._fused = (0, fuse_splits) + else: + self.Wq = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + self.Wq._fused = (0, fuse_splits) if self.qk_ln or self.qk_gn: norm_size = self.head_dim if qk_gn else d_model @@ -480,13 +491,14 @@ def __init__( normalized_shape=norm_size, device=device, ) - if qk_ln: - norm_size = self.head_dim * kv_n_heads - self.k_ln = build_norm( - name=norm_type.lower(), - normalized_shape=norm_size, - device=device, - ) + if self.reuse_kv_layer_idx is None: + if qk_ln: + norm_size = self.head_dim * kv_n_heads + self.k_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + device=device, + ) self.attn_fn = attention_implementations.get(self.attn_impl) @@ -564,6 +576,27 @@ def get_qkv( key (torch.Tensor): The key tensor. value (torch.Tensor): The value tensor. """ + if self.reuse_kv_layer_idx is not None: + if prev_layer_key_value is None: + raise ValueError( + 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', + ) + key, value = prev_layer_key_value + + query = self.Wq(x) + if self.clip_qkv: + query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + if self.qk_ln or self.qk_gn: + # Applying layernorm to qk + q_shape = query.shape + if self.qk_gn: + b, s = query.shape[:2] + query = query.view(b, s, self.n_heads, -1) + dtype = query.dtype + query = self.q_ln(query).to(dtype).view(q_shape) + return query, key, value + qkv = self.Wqkv(x) if self.clip_qkv: @@ -589,14 +622,6 @@ def get_qkv( query = self.q_ln(query).to(dtype).view(q_shape) key = self.k_ln(key).to(dtype).view(k_shape) - if self.reuse_kv_layer_idx is not None: - # TODO: We still compute key and values in the code above, even if we end up reusing previous layer's kv cache. We should avoid this wasteful computation. 
- if prev_layer_key_value is None: - raise ValueError( - 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', - ) - key, value = prev_layer_key_value - return query, key, value def _apply_rotary_embeddings( From c774a4b026f18a2a007bc393fad8d67eac4d81ec Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 10:49:27 -0700 Subject: [PATCH 15/69] refactor --- llmfoundry/models/layers/attention.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 01832fc823..2ba75a95ed 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -528,14 +528,12 @@ def forward( query, key, value = self.get_qkv(x, prev_layer_key_value) if rotary_emb_w_meta_info is not None: - query, key_rotated, value_rotated = self._apply_rotary_embeddings( + query, key, value = self._apply_rotary_embeddings( rotary_emb_w_meta_info, query, key, value, ) - if self.reuse_kv_layer_idx is None: # TODO: We should not rotate if using prev layer kv, that wastes computation - key, value = key_rotated, value_rotated extra_attn_kwargs = self.get_implementation_specific_args( attention_mask, @@ -631,6 +629,10 @@ def _apply_rotary_embeddings( key: torch.Tensor, value: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.reuse_kv_layer_idx is not None: + orig_key, orig_value = key, value + key, value = torch.zeros_like(key), torch.zeros_like(value) + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] offset_info = rotary_emb_w_meta_info['offset_info'] @@ -642,6 +644,7 @@ def _apply_rotary_embeddings( value = value.view(bsz, seqlen, -1, self.head_dim) kv = torch.stack([key, value], dim=2) + # Note: Rotates in place (https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/flash_attn/layers/rotary.py#L429) query, kv = rotary_emb( query, kv, @@ -692,6 +695,8 @@ def _apply_rotary_embeddings( query = query.view(bsz, seqlen, -1) key = key.view(bsz, seqlen, -1) + if self.reuse_kv_layer_idx is not None: + return query, orig_key, orig_value return query, key, value def get_implementation_specific_args( From 8dee35e291e736f48e9326bd6305e15d407bfdaa Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 12:09:29 -0700 Subject: [PATCH 16/69] fixing --- llmfoundry/models/layers/attention.py | 4 ++-- llmfoundry/models/mpt/modeling_mpt.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 2ba75a95ed..3458e99432 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -631,7 +631,7 @@ def _apply_rotary_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.reuse_kv_layer_idx is not None: orig_key, orig_value = key, value - key, value = torch.zeros_like(key), torch.zeros_like(value) + key, value = torch.empty_like(key), torch.empty_like(value) rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] @@ -696,7 +696,7 @@ def _apply_rotary_embeddings( query = query.view(bsz, seqlen, -1) key = key.view(bsz, seqlen, -1) if self.reuse_kv_layer_idx is not None: - return query, orig_key, orig_value + return query, orig_key, orig_value # type: ignore return query, key, value def get_implementation_specific_args( diff --git a/llmfoundry/models/mpt/modeling_mpt.py 
b/llmfoundry/models/mpt/modeling_mpt.py index 55886c9aa8..bfef0700c3 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -441,6 +441,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. """ block_args = self.extract_block_args(config.to_dict()) + self.kv_cache_layers = None if config.block_overrides is not None: return self._construct_blocks_with_overrides(config, block_args) @@ -457,6 +458,10 @@ def _construct_blocks_with_overrides( config: MPTConfig, block_args: Dict[str, Any], ) -> nn.ModuleList: + if config.block_overrides is None: + raise ValueError( + 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', + ) modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': modules_order_expanded[type] = [] @@ -512,7 +517,7 @@ def _construct_blocks_with_overrides( ] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) - new_block_args = self.override_block_args( + new_block_args = self._override_block_args( block_args, override_config, ) @@ -525,7 +530,7 @@ def _construct_blocks_with_overrides( return nn.ModuleList(module_list) - def override_block_args( + def _override_block_args( self, block_args: Dict[str, Any], override_config: Dict[str, Any], @@ -538,7 +543,7 @@ def override_block_args( f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', ) if isinstance(override_config[k], dict): - new_block_args[k] = self.override_block_args( + new_block_args[k] = self._override_block_args( block_args[k], override_config[k], ) @@ -856,7 +861,7 @@ def forward( ) if presents is not None: presents += (present,) - if b_idx in self.kv_cache_layers: + if self.kv_cache_layers is not None and b_idx in self.kv_cache_layers: if self.attn_impl != 'torch': layer_kv_cache_dict[b_idx] = [ present[0][:, past_position:], From 8ff15b4bf879f6be0f79a1246723db6dca34bde8 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 17:54:29 -0700 Subject: [PATCH 17/69] adding test for reusing previous layer kv cache --- tests/models/layers/test_flash_torch.py | 221 ++++++++++++++++++++++++ 1 file changed, 221 insertions(+) diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 669a6a93a1..dae044f03e 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -537,3 +537,224 @@ def test_grouped_query_invalid_heads(): with pytest.raises(ValueError, match=expected_error): _ = attention.GroupedQueryAttention(**cfg) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }], +) +def test_reuse_prev_layer_kv_cache( + pos_emb_config: dict, + device: str = 'cuda', +): + """Checks reusing previous layer's kv cache.""" + alibi = pos_emb_config['alibi'] + rope = pos_emb_config['rope'] + if alibi and not ( + check_alibi_support('flash') and check_alibi_support('flash') + ): + pytest.skip('flash attention below v2.4.2 does not support alibi.') + if rope and (pos_emb_config['rope_impl'] + == 'dail') and (not is_flash_v2_installed()): + pytest.skip('dail implementation of rope requires flash attention 2.') + + if (not 
is_flash_v2_installed(v2_version='v2.1.2')): + pytest.skip( + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', + ) + + cfg = om.create({ + 'attn_impl': 'flash', + 'd_model': 64, + 'n_heads': 4, + 'attn_pdrop': 0, + 'clip_qkv': True, + }) + + n, s, f = 2, 4, cfg.d_model + assert cfg.d_model % cfg.n_heads == 0 + cfg.kv_n_heads = 2 + + sequence_id = torch.LongTensor([ + [0] * 2 + [1] * (s - 2), + [0] * 4 + [1] * (s - 4), + ]).to(device=device) + + # Computes its own kv cache + cfg.reuse_kv_layer_idx = None + attn0 = build_attention_layer( + name='grouped_query_attention', + attn_kwargs=om.to_container(cfg), # type: ignore + ).to(device) + + # Reuses layer 0's kv cache + cfg.reuse_kv_layer_idx = 0 + attn1 = build_attention_layer( + name='grouped_query_attention', + attn_kwargs=om.to_container(cfg), # type: ignore + ).to(device) + attn0_sd = attn0.state_dict() + attn0_sd['Wq.weight'] = attn0_sd['Wqkv.weight'][:cfg.d_model] + attn0_sd['Wq.bias'] = attn0_sd['Wqkv.bias'][:cfg.d_model] + del attn0_sd['Wqkv.weight'] + del attn0_sd['Wqkv.bias'] + attn1.load_state_dict(attn0_sd) + + attention_mask = torch.ones(n, s).to(device).bool() + + def gen_bias(attn_impl: str): + causal = True + attn_bias = None + bs = attention.attn_bias_shape( + attn_impl, + cfg.n_heads, + s, + alibi, + use_sequence_id=True, + causal=causal, + ) + if bs is not None: + attn_bias = torch.zeros(*bs, device=device) + attn_bias = attention.build_attn_bias( + attn_impl, + attn_bias, + cfg.n_heads, + s, + causal=causal, + alibi=alibi, + alibi_bias_max=8, + ) + + return attn_bias + + attention_mask_in_length = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=True, + attn_impl='flash', + attention_mask=attention_mask, + ) + + flash_attn_padding_info = gen_flash_attn_padding_info( + n, + s, + 0, + torch.device(device), + attention_mask_in_length, + attention_mask, + ) + + x0 = torch.randn(n, s, f).to(device) + x1 = x0.clone().detach() + x0.requires_grad = True + x1.requires_grad = True + + with torch.autocast(x0.device.type): + attn_bias_0 = gen_bias('flash') + alibi_slopes_0 = None + if alibi: + alibi_slopes_0 = gen_slopes( + n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + rotary_emb_w_meta_info = None + if rope: + rotary_embedding = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=pos_emb_config['rope_impl'], + rope_theta=pos_emb_config['rope_theta'], + rope_dail_config=pos_emb_config.get('rope_dail_config', {}), + rope_hf_config=pos_emb_config.get('rope_hf_config', {}), + max_seq_len=s, + ).to(device) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + rotary_emb_w_meta_info = { + 'impl': + pos_emb_config['rope_impl'], + 'rotary_emb': + rotary_embedding, + 'offset_info': + pos if (pos_emb_config['rope_impl'] == 'hf') else 0, + 'seq_len': + s, + } + + y0, _, prev_layer_key_value = attn0( + x0, + past_key_value=(), + attn_bias=attn_bias_0, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + alibi_slopes=alibi_slopes_0, + ) + attn_bias_1 = gen_bias('flash') + alibi_slopes_1 = None + if alibi: + alibi_slopes_1 = gen_slopes( + n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + + 
prev_layer_key_value = [ + t.clone().detach() for t in prev_layer_key_value + ] + y1, _, _ = attn1( + x1, + past_key_value=None, + attn_bias=attn_bias_1, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + alibi_slopes=alibi_slopes_1, + prev_layer_key_value=prev_layer_key_value, + ) + y0 *= attention_mask.unsqueeze(-1) + y1 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + loss1 = y1.sum() + + loss0.backward() + loss1.backward() + assert allclose_helper(y0, y1) + + torch_name_param_map = dict(attn1.named_parameters()) + for n, p in attn0.named_parameters(): + if 'Wq' in n: + tp = torch_name_param_map[n.replace('Wqkv', 'Wq')] + assert p.grad is not None + assert tp.grad is not None + assert allclose_helper(p[:cfg.d_model], tp) + assert allclose_helper(p.grad[:cfg.d_model], tp.grad) + else: + tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None + assert allclose_helper(p, tp) + assert allclose_helper(p.grad, tp.grad) From 04e9888f2b1cd9852555d6e11369c4258a0a3988 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 19:43:58 -0700 Subject: [PATCH 18/69] adding error messages --- llmfoundry/models/mpt/configuration_mpt.py | 23 ++++++++++++++++++++++ llmfoundry/models/mpt/modeling_mpt.py | 14 ++++--------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 32819e4528..bded7eb56f 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -177,10 +177,33 @@ def __init__( init_config_defaults, ) + if 'reuse_kv_layer_idx' in self.attn_config and self.attn_config[ + 'attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) if block_overrides is not None: warnings.warn( 'Warning, block_overrides is an experimental feature. 
The YAML design may change in the future.', ) + if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides: + raise ValueError( + 'either start, repeating_pattern, or end should be defined in block_overrides', + ) + if 'overrides' not in block_overrides: + raise ValueError( + 'overrides should be defined in block_overrides', + ) + for name, override in block_overrides['overrides'].items(): + if name == 'default': + raise ValueError( + 'block overrides cannot be named "default".', + ) + if 'attn_config' in override and 'reuse_kv_layer_idx' in override[ + 'attn_config'] and self.attn_config['attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) self.block_overrides = block_overrides if isinstance(fc_type, str): diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bfef0700c3..618df54748 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -862,16 +862,10 @@ def forward( if presents is not None: presents += (present,) if self.kv_cache_layers is not None and b_idx in self.kv_cache_layers: - if self.attn_impl != 'torch': - layer_kv_cache_dict[b_idx] = [ - present[0][:, past_position:], - present[1][:, past_position:], - ] - else: - layer_kv_cache_dict[b_idx] = [ - present[0][:, :, :, past_position:], - present[1][:, :, :, past_position:], - ] + layer_kv_cache_dict[b_idx] = [ + present[0][:, past_position:], + present[1][:, past_position:], + ] if output_attentions: assert all_self_attns is not None # pyright From 5eee910429cd0a54b287ead3616a862252a79ca6 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 20:41:24 -0700 Subject: [PATCH 19/69] .. 
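
For reference, a block_overrides config of the shape validated above could look like the following sketch, written as the Python dict the tests use. The layer names ('sliding_window_layer', 'kv_reuse_layer'), the window size, and the assumed 8-layer model are illustrative only, not values taken from this patch series:

    # Hypothetical block_overrides for an 8-layer model (illustrative values).
    # 'start' and 'end' each cover one layer; the 3-layer repeating pattern
    # fills the remaining 6 layers, so the expanded order matches n_layers.
    block_overrides = {
        'start': [
            {'name': 'default', 'repeat': 1},
        ],
        'repeating_pattern': [
            {'name': 'sliding_window_layer', 'repeat': 2},
            {'name': 'kv_reuse_layer', 'repeat': 1},
        ],
        'end': [
            {'name': 'default', 'repeat': 1},
        ],
        'overrides': {
            'sliding_window_layer': {
                'attn_config': {
                    'sliding_window_size': 1024,
                },
            },
            # Reuses the kv cache of the immediately preceding layer
            # (reuse_kv_layer_idx is a negative, relative index).
            'kv_reuse_layer': {
                'attn_config': {
                    'sliding_window_size': 1024,
                    'reuse_kv_layer_idx': -1,
                },
            },
        },
    }

Note that 'default' refers to a block with no overrides and cannot itself be used as an override name, and reuse_kv_layer_idx is rejected when attn_impl is 'torch'.
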
--- llmfoundry/models/mpt/modeling_mpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 618df54748..a3926a92b8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -444,7 +444,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: self.kv_cache_layers = None if config.block_overrides is not None: - return self._construct_blocks_with_overrides(config, block_args) + return self.construct_blocks_with_overrides(config, block_args) return nn.ModuleList([ self.block_class( @@ -453,14 +453,14 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: ) for _ in range(config.n_layers) ]) - def _construct_blocks_with_overrides( + def construct_blocks_with_overrides( self, config: MPTConfig, block_args: Dict[str, Any], ) -> nn.ModuleList: if config.block_overrides is None: raise ValueError( - 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', + 'config.block_overrides should not be None when calling construct_blocks_with_overrides.', ) modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': From 2a6c9864ec583ab49e28a201f461b74383cd5c5d Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 22 Jun 2024 23:41:17 -0700 Subject: [PATCH 20/69] adding test --- llmfoundry/models/mpt/configuration_mpt.py | 9 +- llmfoundry/models/mpt/modeling_mpt.py | 35 ++++--- tests/models/test_model.py | 109 ++++++++++++++++++++- 3 files changed, 139 insertions(+), 14 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index bded7eb56f..42602bcac9 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -184,7 +184,7 @@ def __init__( ) if block_overrides is not None: warnings.warn( - 'Warning, block_overrides is an experimental feature. The YAML design may change in the future.', + 'block_overrides is an experimental feature. 
The YAML design may change in the future.', ) if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides: raise ValueError( @@ -396,3 +396,10 @@ def _validate_config(self) -> None: ) self.validate_attention_config() + + @property + def allowed_block_overrides(self): + return { + 'sliding_window_size', + 'reuse_kv_layer_idx', + } diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a3926a92b8..41ea6433ec 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -444,7 +444,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: self.kv_cache_layers = None if config.block_overrides is not None: - return self.construct_blocks_with_overrides(config, block_args) + return self._construct_blocks_with_overrides(config, block_args) return nn.ModuleList([ self.block_class( @@ -453,14 +453,14 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: ) for _ in range(config.n_layers) ]) - def construct_blocks_with_overrides( + def _construct_blocks_with_overrides( self, config: MPTConfig, block_args: Dict[str, Any], ) -> nn.ModuleList: if config.block_overrides is None: raise ValueError( - 'config.block_overrides should not be None when calling construct_blocks_with_overrides.', + 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', ) modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': @@ -479,17 +479,23 @@ def construct_blocks_with_overrides( repeating_pattern_len = len(modules_order_expanded['repeating_pattern']) end_len = len(modules_order_expanded['end']) - if ( - config.n_layers - (start_len + end_len) - ) % repeating_pattern_len != 0: - raise ValueError( - 'Number of layers should be divisible by the specified custom modules order.', - ) - model_modules_order_expanded = modules_order_expanded[ - 'start'] + modules_order_expanded['repeating_pattern'] * ( + if repeating_pattern_len > 0: + if ( + config.n_layers - (start_len + end_len) + ) % repeating_pattern_len != 0: + raise ValueError( + 'Number of layers should be divisible by the specified custom modules order.', + ) + modules_order_expanded[ + 'repeating_pattern' + ] = modules_order_expanded['repeating_pattern'] * ( (config.n_layers - (start_len + end_len)) // repeating_pattern_len - ) + modules_order_expanded['end'] + ) + + model_modules_order_expanded = modules_order_expanded[ + 'start'] + modules_order_expanded['repeating_pattern' + ] + modules_order_expanded['end'] self.kv_cache_layers = set() module_list = [] @@ -520,6 +526,7 @@ def construct_blocks_with_overrides( new_block_args = self._override_block_args( block_args, override_config, + config.allowed_block_overrides, ) module_list.append( MPTBlock( @@ -534,6 +541,7 @@ def _override_block_args( self, block_args: Dict[str, Any], override_config: Dict[str, Any], + allowed_block_overrides: set, ) -> Dict[str, Any]: new_block_args = override_config | block_args common_keys = override_config.keys() & block_args.keys() @@ -546,8 +554,11 @@ def _override_block_args( new_block_args[k] = self._override_block_args( block_args[k], override_config[k], + allowed_block_overrides, ) else: + if k not in allowed_block_overrides: + raise KeyError(f'Overriding {k} is not supported.') new_block_args[k] = override_config[k] return new_block_args diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 2f93b1d3ce..cea59866a6 100644 --- a/tests/models/test_model.py +++ 
b/tests/models/test_model.py @@ -44,7 +44,7 @@ is_flash_v2_installed, ) from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM +from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -2617,3 +2617,110 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): output = model(batch) assert not torch.isnan(output.logits).any() + + +@pytest.mark.parametrize( + 'start', + [[], [{ + 'name': 'default', + 'repeat': 1, + }, { + 'name': 'layer_s', + 'repeat': 2, + }]], +) +@pytest.mark.parametrize( + 'repeating_pattern', + [[{ + 'name': 'layer_rp', + 'repeat': 1, + }]], +) +@pytest.mark.parametrize( + 'end', + [[], [{ + 'name': 'layer_e', + 'repeat': 2, + }, { + 'name': 'default', + 'repeat': 1, + }]], +) +def test_construct_blocks(start: list, repeating_pattern: list, end: list): + n_layers = 9 + config = MPTConfig( + d_model=32, + n_heads=16, + n_layers=n_layers, + expansion_ratio=2, + max_seq_len=64, + attn_config={ + 'attn_impl': 'flash', + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 4, + }, + ) + overrides = { + 'layer_s': { + 'attn_config': { + 'sliding_window_size': 1024, + 'reuse_kv_layer_idx': -1, + }, + }, + 'layer_rp': { + 'attn_config': { + 'sliding_window_size': 512, + }, + }, + 'layer_e': { + 'attn_config': { + 'sliding_window_size': 256, + 'reuse_kv_layer_idx': -2, + }, + }, + } + config.block_overrides = {} + + if len(start) > 0: + config.block_overrides['start'] = start + if len(repeating_pattern) > 0: + config.block_overrides['repeating_pattern'] = repeating_pattern + if len(end) > 0: + config.block_overrides['end'] = end + if len(start) > 0 or len(repeating_pattern) > 0 or len(end) > 0: + config.block_overrides['overrides'] = overrides + else: + pytest.skip(f'Skipping test because no overrides.') + + block_list = MPTModel(config).construct_blocks(config) + + assert len(block_list) == n_layers + + if len(start) > 0: + assert block_list[0].attn.sliding_window_size == -1 + assert block_list[0].attn.reuse_kv_layer_idx is None + assert block_list[1].attn.sliding_window_size == 1024 + assert block_list[1].attn.reuse_kv_layer_idx == 0 + assert block_list[2].attn.sliding_window_size == 1024 + assert block_list[2].attn.reuse_kv_layer_idx == 1 + else: + assert block_list[0].attn.sliding_window_size == 512 + assert block_list[0].attn.reuse_kv_layer_idx is None + + if len(end) > 0: + assert block_list[6].attn.sliding_window_size == 256 + assert block_list[6].attn.reuse_kv_layer_idx == 4 + assert block_list[7].attn.sliding_window_size == 256 + assert block_list[7].attn.reuse_kv_layer_idx == 5 + assert block_list[8].attn.sliding_window_size == -1 + assert block_list[8].attn.reuse_kv_layer_idx is None + else: + assert block_list[8].attn.sliding_window_size == 512 + assert block_list[8].attn.reuse_kv_layer_idx is None + + assert block_list[3].attn.sliding_window_size == 512 + assert block_list[3].attn.reuse_kv_layer_idx is None + assert block_list[4].attn.sliding_window_size == 512 + assert block_list[4].attn.reuse_kv_layer_idx is None + assert block_list[5].attn.sliding_window_size == 512 + assert block_list[5].attn.reuse_kv_layer_idx is None From 7bf89f2680f62a59ede6d3aca4f221c3e3fcf762 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 23 Jun 2024 10:20:36 -0700 Subject: [PATCH 21/69] add logging --- llmfoundry/models/mpt/modeling_mpt.py 
| 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 41ea6433ec..208a2fcbf0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -522,7 +522,9 @@ def _construct_blocks_with_overrides( override_attn_config['reuse_kv_layer_idx' ] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) - + log.info( + f'Layer: {i}. Name: {module_name}. Overrides: {self._get_overrides_for_logging(override_config)}', + ) new_block_args = self._override_block_args( block_args, override_config, @@ -537,6 +539,18 @@ def _construct_blocks_with_overrides( return nn.ModuleList(module_list) + def _get_overrides_for_logging( + self, + override_config: Dict[str, Any], + ) -> List[dict[str, str]]: + overrides_list = [] + for k, v in override_config: + if isinstance(v, dict): + overrides_list.extend(self._get_overrides_for_logging(v)) + else: + overrides_list.append({k: v}) + return overrides_list + def _override_block_args( self, block_args: Dict[str, Any], From dcc5cc0e33ea4166504607ddbf7cacc3c1202756 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 23 Jun 2024 10:50:12 -0700 Subject: [PATCH 22/69] adding logging --- llmfoundry/models/mpt/modeling_mpt.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 208a2fcbf0..42bd60d1d6 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -29,6 +29,7 @@ import torch.nn.functional as F from composer.models import HuggingFaceModel from composer.utils import dist +from tabulate import tabulate from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed @@ -499,6 +500,7 @@ def _construct_blocks_with_overrides( self.kv_cache_layers = set() module_list = [] + layer_description_list = [] for i in range(config.n_layers): module_name = model_modules_order_expanded[i] @@ -522,9 +524,11 @@ def _construct_blocks_with_overrides( override_attn_config['reuse_kv_layer_idx' ] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) - log.info( - f'Layer: {i}. Name: {module_name}. 
Overrides: {self._get_overrides_for_logging(override_config)}', - ) + layer_description_list.append([ + i, + module_name, + self._get_overrides_for_logging(override_config), + ],) new_block_args = self._override_block_args( block_args, override_config, @@ -536,7 +540,12 @@ def _construct_blocks_with_overrides( **new_block_args, ), ) - + log.info( + 'The following is a summary of overrides per layer.\n' + tabulate( + layer_description_list, + headers=['idx', 'name', 'overrides'], + ), + ) return nn.ModuleList(module_list) def _get_overrides_for_logging( @@ -544,7 +553,7 @@ def _get_overrides_for_logging( override_config: Dict[str, Any], ) -> List[dict[str, str]]: overrides_list = [] - for k, v in override_config: + for k, v in override_config.items(): if isinstance(v, dict): overrides_list.extend(self._get_overrides_for_logging(v)) else: From 06d03c1e69e6e4eadabd407dc4d7e6ed52009c4e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 23 Jun 2024 11:09:15 -0700 Subject: [PATCH 23/69] minor --- llmfoundry/models/mpt/modeling_mpt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 42bd60d1d6..788a3830de 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -498,7 +498,6 @@ def _construct_blocks_with_overrides( 'start'] + modules_order_expanded['repeating_pattern' ] + modules_order_expanded['end'] - self.kv_cache_layers = set() module_list = [] layer_description_list = [] for i in range(config.n_layers): @@ -523,6 +522,8 @@ def _construct_blocks_with_overrides( ) override_attn_config['reuse_kv_layer_idx' ] = reuse_kv_layer_idx + if self.kv_cache_layers is None: + self.kv_cache_layers = set() self.kv_cache_layers.add(reuse_kv_layer_idx) layer_description_list.append([ i, @@ -843,7 +844,9 @@ def forward( # initialize the past key values cache if it should be used presents = () if use_cache else None - if past_key_values is None: + if ( + use_cache or self.kv_cache_layers is not None + ) and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) ] # type: ignore From ec42e725944f33e77c57d9cd99caf5f1a1c9c066 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 23 Jun 2024 13:06:57 -0700 Subject: [PATCH 24/69] bug fix, adding test --- llmfoundry/models/layers/attention.py | 4 ++ llmfoundry/models/layers/blocks.py | 1 + llmfoundry/models/mpt/modeling_mpt.py | 6 ++- tests/models/test_model.py | 66 ++++++++++++++++++++++++++- 4 files changed, 75 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 3458e99432..bf645881af 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -750,6 +750,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__( d_model=d_model, @@ -766,6 +767,7 @@ def __init__( device=device, bias=bias, sliding_window_size=sliding_window_size, + reuse_kv_layer_idx=reuse_kv_layer_idx, ) @@ -791,6 +793,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__( d_model=d_model, @@ -807,6 +810,7 @@ def __init__( device=device, bias=bias, sliding_window_size=sliding_window_size, + reuse_kv_layer_idx=reuse_kv_layer_idx, ) diff --git a/llmfoundry/models/layers/blocks.py 
b/llmfoundry/models/layers/blocks.py index 706395c02a..71bd3be5c3 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -187,6 +187,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + prev_layer_key_value=prev_layer_key_value, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 788a3830de..d772a7fad4 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -500,9 +500,13 @@ def _construct_blocks_with_overrides( module_list = [] layer_description_list = [] + if len(model_modules_order_expanded) != config.n_layers: + raise ValueError( + f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', + ) + for i in range(config.n_layers): module_name = model_modules_order_expanded[i] - override_config = {} if module_name != 'default': override_config = copy.deepcopy( diff --git a/tests/models/test_model.py b/tests/models/test_model.py index cea59866a6..151dec0c4e 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -72,12 +72,17 @@ def _load_tokenizer_cfg(cfg: Union[Dict[str, Any], DictConfig]) -> Dict: def _get_objs( request: pytest.FixtureRequest, conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', + model_config_overrides: Optional[Dict] = None, + attn_impl: str = 'torch', ): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property', ) test_cfg = get_config(conf_path=conf_path) + if model_config_overrides is not None: + for k, v in model_config_overrides.items(): + test_cfg.model[k] = v # Read FSDP Config as a dict fsdp_config = test_cfg.get('fsdp_config', None) @@ -97,7 +102,7 @@ def _get_objs( device = 'cuda' if is_gpu else 'cpu' test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32' test_cfg.model.attn_config = { - 'attn_impl': 'torch', + 'attn_impl': attn_impl, } test_cfg.model.init_device = device test_cfg.device = device @@ -2724,3 +2729,62 @@ def test_construct_blocks(start: list, repeating_pattern: list, end: list): assert block_list[4].attn.reuse_kv_layer_idx is None assert block_list[5].attn.sliding_window_size == 512 assert block_list[5].attn.reuse_kv_layer_idx is None + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'conf_path', + [ + 'scripts/train/yamls/pretrain/testing.yaml', + ], +) +def test_reuse_prev_layer_kv_cache( + request: pytest.FixtureRequest, + conf_path: str, + batch_size: int = 2, +): + model_config_overrides = { + 'block_overrides': { + 'start': [ + { + 'name': 'default', + 'repeat': 1, + }, + { + 'name': 'kv_reuse_layer', + 'repeat': 1, + }, + ], + 'overrides': { + 'kv_reuse_layer': { + 'attn_config': { + 'reuse_kv_layer_idx': -1, + }, + }, + }, + }, + 'use_cache': True, + } + test_cfg, model, _ = _get_objs( + request=request, + conf_path=conf_path, + model_config_overrides=model_config_overrides, + attn_impl='flash', + ) + + batch = gen_random_batch(batch_size, test_cfg) + + assert batch['input_ids'].shape == torch.Size([ + batch_size, + test_cfg.max_seq_len, + ]) + model.train() + with get_precision_context(test_cfg.precision): + outputs = model(batch) + len(outputs.past_key_values) == 2 + assert torch.all( + outputs.past_key_values[0][0] == outputs.past_key_values[1][0], + ) + assert torch.all( + outputs.past_key_values[0][1] == outputs.past_key_values[1][1], + ) From 
cc1f2f3f9811f81f584bc70a1bf3bd5bdd6d3c3b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 23 Jun 2024 13:20:45 -0700 Subject: [PATCH 25/69] minor --- tests/models/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 151dec0c4e..b5a3e1cf0a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2781,7 +2781,7 @@ def test_reuse_prev_layer_kv_cache( model.train() with get_precision_context(test_cfg.precision): outputs = model(batch) - len(outputs.past_key_values) == 2 + assert len(outputs.past_key_values) == 2 assert torch.all( outputs.past_key_values[0][0] == outputs.past_key_values[1][0], ) From 63f4196167eeece9b6f844a16fdae4c5cfe83193 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 12:20:14 -0700 Subject: [PATCH 26/69] addressing some comments --- tests/models/layers/test_flash_torch.py | 44 +++++++++---------------- tests/models/test_model.py | 30 +++++------------ 2 files changed, 25 insertions(+), 49 deletions(-) diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index dae044f03e..6208919a9d 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -564,30 +564,18 @@ def test_reuse_prev_layer_kv_cache( """Checks reusing previous layer's kv cache.""" alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] - if alibi and not ( - check_alibi_support('flash') and check_alibi_support('flash') - ): - pytest.skip('flash attention below v2.4.2 does not support alibi.') - if rope and (pos_emb_config['rope_impl'] - == 'dail') and (not is_flash_v2_installed()): - pytest.skip('dail implementation of rope requires flash attention 2.') - - if (not is_flash_v2_installed(v2_version='v2.1.2')): - pytest.skip( - 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', - ) - cfg = om.create({ + cfg = { 'attn_impl': 'flash', 'd_model': 64, 'n_heads': 4, 'attn_pdrop': 0, 'clip_qkv': True, - }) + } - n, s, f = 2, 4, cfg.d_model - assert cfg.d_model % cfg.n_heads == 0 - cfg.kv_n_heads = 2 + n, s, f = 2, 4, cfg['d_model'] + assert cfg['d_model'] % cfg['n_heads'] == 0 + cfg['kv_n_heads'] = 2 sequence_id = torch.LongTensor([ [0] * 2 + [1] * (s - 2), @@ -595,21 +583,21 @@ def test_reuse_prev_layer_kv_cache( ]).to(device=device) # Computes its own kv cache - cfg.reuse_kv_layer_idx = None + cfg['reuse_kv_layer_idx'] = None attn0 = build_attention_layer( name='grouped_query_attention', attn_kwargs=om.to_container(cfg), # type: ignore ).to(device) # Reuses layer 0's kv cache - cfg.reuse_kv_layer_idx = 0 + cfg['reuse_kv_layer_idx'] = 0 attn1 = build_attention_layer( name='grouped_query_attention', attn_kwargs=om.to_container(cfg), # type: ignore ).to(device) attn0_sd = attn0.state_dict() - attn0_sd['Wq.weight'] = attn0_sd['Wqkv.weight'][:cfg.d_model] - attn0_sd['Wq.bias'] = attn0_sd['Wqkv.bias'][:cfg.d_model] + attn0_sd['Wq.weight'] = attn0_sd['Wqkv.weight'][:cfg['d_model']] + attn0_sd['Wq.bias'] = attn0_sd['Wqkv.bias'][:cfg['d_model']] del attn0_sd['Wqkv.weight'] del attn0_sd['Wqkv.bias'] attn1.load_state_dict(attn0_sd) @@ -621,7 +609,7 @@ def gen_bias(attn_impl: str): attn_bias = None bs = attention.attn_bias_shape( attn_impl, - cfg.n_heads, + cfg['n_heads'], s, alibi, use_sequence_id=True, @@ -632,7 +620,7 @@ def gen_bias(attn_impl: str): attn_bias = attention.build_attn_bias( attn_impl, attn_bias, - cfg.n_heads, + cfg['n_heads'], s, causal=causal, 
alibi=alibi, @@ -668,7 +656,7 @@ def gen_bias(attn_impl: str): alibi_slopes_0 = None if alibi: alibi_slopes_0 = gen_slopes( - n_heads=cfg.n_heads, + n_heads=cfg['n_heads'], alibi_bias_max=8, device=torch.device(device), return_1d=True, @@ -676,7 +664,7 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, + rope_head_dim=cfg['d_model'] // cfg['n_heads'], rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), @@ -714,7 +702,7 @@ def gen_bias(attn_impl: str): alibi_slopes_1 = None if alibi: alibi_slopes_1 = gen_slopes( - n_heads=cfg.n_heads, + n_heads=cfg['n_heads'], alibi_bias_max=8, device=torch.device(device), return_1d=True, @@ -750,8 +738,8 @@ def gen_bias(attn_impl: str): tp = torch_name_param_map[n.replace('Wqkv', 'Wq')] assert p.grad is not None assert tp.grad is not None - assert allclose_helper(p[:cfg.d_model], tp) - assert allclose_helper(p.grad[:cfg.d_model], tp.grad) + assert allclose_helper(p[:cfg['d_model']], tp) + assert allclose_helper(p.grad[:cfg['d_model']], tp.grad) else: tp = torch_name_param_map[n] assert p.grad is not None diff --git a/tests/models/test_model.py b/tests/models/test_model.py index b5a3e1cf0a..5cd0880565 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2634,13 +2634,6 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): 'repeat': 2, }]], ) -@pytest.mark.parametrize( - 'repeating_pattern', - [[{ - 'name': 'layer_rp', - 'repeat': 1, - }]], -) @pytest.mark.parametrize( 'end', [[], [{ @@ -2651,8 +2644,13 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): 'repeat': 1, }]], ) -def test_construct_blocks(start: list, repeating_pattern: list, end: list): +def test_construct_blocks(start: list, end: list): n_layers = 9 + repeating_pattern = [{ + 'name': 'layer_rp', + 'repeat': 1, + }] + config = MPTConfig( d_model=32, n_heads=16, @@ -2688,14 +2686,10 @@ def test_construct_blocks(start: list, repeating_pattern: list, end: list): if len(start) > 0: config.block_overrides['start'] = start - if len(repeating_pattern) > 0: - config.block_overrides['repeating_pattern'] = repeating_pattern + config.block_overrides['repeating_pattern'] = repeating_pattern if len(end) > 0: config.block_overrides['end'] = end - if len(start) > 0 or len(repeating_pattern) > 0 or len(end) > 0: - config.block_overrides['overrides'] = overrides - else: - pytest.skip(f'Skipping test because no overrides.') + config.block_overrides['overrides'] = overrides block_list = MPTModel(config).construct_blocks(config) @@ -2732,17 +2726,11 @@ def test_construct_blocks(start: list, repeating_pattern: list, end: list): @pytest.mark.gpu -@pytest.mark.parametrize( - 'conf_path', - [ - 'scripts/train/yamls/pretrain/testing.yaml', - ], -) def test_reuse_prev_layer_kv_cache( request: pytest.FixtureRequest, - conf_path: str, batch_size: int = 2, ): + conf_path = 'scripts/train/yamls/pretrain/testing.yaml' model_config_overrides = { 'block_overrides': { 'start': [ From 89bc22ec2cb250382f553bb16aa3d4f6a465c1c3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 12:44:14 -0700 Subject: [PATCH 27/69] addressing some comments --- llmfoundry/models/mpt/configuration_mpt.py | 42 +++++++++++----------- llmfoundry/models/mpt/modeling_mpt.py | 10 +++--- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py 
b/llmfoundry/models/mpt/configuration_mpt.py index 42602bcac9..94417174eb 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,6 +20,7 @@ ffn_config_defaults, init_config_defaults, ) +from llmfoundry.utils.warnings import ExperimentalWarning class MPTConfig(PretrainedConfig): @@ -118,7 +119,7 @@ def __init__( also be a dictionary that specifies the fc layer name and any kwargs for the fc layer. tie_word_embeddings (bool): Whether to tie the input embedding and output layers. use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. - block_overrides: The allows for overriding default block configs for certain layers. This should contain two sub configs: order and overrides. order specifies the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the overrides in the overrides config. + block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and at least one of `start`, `repeating_pattern`, and `end`. `start`, `repeating_pattern`, and `end` specify the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the `overrides` in the overrides config. To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed: block_overrides: start: @@ -183,27 +184,7 @@ def __init__( 'reusing kv cache from a previous layer is not implemented for torch attention.', ) if block_overrides is not None: - warnings.warn( - 'block_overrides is an experimental feature. The YAML design may change in the future.', - ) - if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides: - raise ValueError( - 'either start, repeating_pattern, or end should be defined in block_overrides', - ) - if 'overrides' not in block_overrides: - raise ValueError( - 'overrides should be defined in block_overrides', - ) - for name, override in block_overrides['overrides'].items(): - if name == 'default': - raise ValueError( - 'block overrides cannot be named "default".', - ) - if 'attn_config' in override and 'reuse_kv_layer_idx' in override[ - 'attn_config'] and self.attn_config['attn_impl'] == 'torch': - raise NotImplementedError( - 'reusing kv cache from a previous layer is not implemented for torch attention.', - ) + self._validate_block_overrides(block_overrides) self.block_overrides = block_overrides if isinstance(fc_type, str): @@ -230,6 +211,23 @@ def __init__( self._validate_config() + def _validate_block_overrides(self, block_overrides): + warnings.warn(ExperimentalWarning('block_overrides')) + if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides: + raise ValueError( + 'either start, repeating_pattern, or end should be defined in block_overrides', + ) + if 'overrides' not in block_overrides: + raise ValueError('overrides should be defined in block_overrides',) + for name, override in block_overrides['overrides'].items(): + if name == 'default': + raise ValueError('block overrides cannot be named "default".',) + if 'attn_config' in override and 'reuse_kv_layer_idx' in override[ + 'attn_config'] and self.attn_config['attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) + def 
_set_config_defaults( self, config: Dict[str, Any], diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d69188c684..60bedc0f2d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -464,7 +464,7 @@ def _construct_blocks_with_overrides( 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', ) modules_order_expanded = {} - for type in 'start', 'repeating_pattern', 'end': + for type in ['start', 'repeating_pattern', 'end']: modules_order_expanded[type] = [] if type in config.block_overrides: for block in config.block_overrides[type]: @@ -487,12 +487,12 @@ def _construct_blocks_with_overrides( raise ValueError( 'Number of layers should be divisible by the specified custom modules order.', ) + num_repetitions = ( + config.n_layers - (start_len + end_len) + ) // repeating_pattern_len modules_order_expanded[ 'repeating_pattern' - ] = modules_order_expanded['repeating_pattern'] * ( - (config.n_layers - - (start_len + end_len)) // repeating_pattern_len - ) + ] = modules_order_expanded['repeating_pattern'] * num_repetitions model_modules_order_expanded = modules_order_expanded[ 'start'] + modules_order_expanded['repeating_pattern' From 0a3f1b46b555473179bc9cb811d59091f057273c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 16:18:05 -0700 Subject: [PATCH 28/69] setting absolute absolute value for reuse_kv_layer_idx --- llmfoundry/models/mpt/modeling_mpt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 60bedc0f2d..f7039c8444 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -505,6 +505,7 @@ def _construct_blocks_with_overrides( f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', ) + reuse_kv_layer_idx_dict = {} for i in range(config.n_layers): module_name = model_modules_order_expanded[i] override_config = {} @@ -524,6 +525,10 @@ def _construct_blocks_with_overrides( raise ValueError( f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', ) + if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: + reuse_kv_layer_idx = reuse_kv_layer_idx_dict[ + reuse_kv_layer_idx] + reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx override_attn_config['reuse_kv_layer_idx' ] = reuse_kv_layer_idx if self.kv_cache_layers is None: From b74330eb6afee7bf4d74f084dd447eab12248e9a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 16:19:34 -0700 Subject: [PATCH 29/69] lint --- llmfoundry/models/mpt/configuration_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 94417174eb..cb22c2f6f4 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -211,7 +211,7 @@ def __init__( self._validate_config() - def _validate_block_overrides(self, block_overrides): + def _validate_block_overrides(self, block_overrides: Dict[str, Any]): warnings.warn(ExperimentalWarning('block_overrides')) if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides: raise ValueError( From 5079d2e9d61fd8ae386880571ecdcf867d2a2d76 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 16:50:20 -0700 Subject: [PATCH 30/69] 
adding tests for override_block_args --- llmfoundry/models/mpt/modeling_mpt.py | 6 +++--- tests/models/test_model.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index f7039c8444..d76dcc20ae 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -539,7 +539,7 @@ def _construct_blocks_with_overrides( module_name, self._get_overrides_for_logging(override_config), ],) - new_block_args = self._override_block_args( + new_block_args = MPTModel._override_block_args( block_args, override_config, config.allowed_block_overrides, @@ -570,8 +570,8 @@ def _get_overrides_for_logging( overrides_list.append({k: v}) return overrides_list + @staticmethod def _override_block_args( - self, block_args: Dict[str, Any], override_config: Dict[str, Any], allowed_block_overrides: set, @@ -584,7 +584,7 @@ def _override_block_args( f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', ) if isinstance(override_config[k], dict): - new_block_args[k] = self._override_block_args( + new_block_args[k] = MPTModel._override_block_args( block_args[k], override_config[k], allowed_block_overrides, diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 5cd0880565..63ce070af7 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2776,3 +2776,18 @@ def test_reuse_prev_layer_kv_cache( assert torch.all( outputs.past_key_values[0][1] == outputs.past_key_values[1][1], ) + + +def test_override_block_args(): + block_args = {'a': 1, 'b': {'c': 3}, 'd': 4} + override_config = {'a': 2, 'b': {'c': 5}, 'e': 6} + allowed_block_overrides = {'a', 'c', 'e'} + new_config = MPTModel._override_block_args( + block_args, + override_config, + allowed_block_overrides, + ) + assert new_config['a'] == 2 + assert new_config['d'] == 4 + assert new_config['e'] == 6 + assert new_config['b']['c'] == 5 From 2a7e94875b34bafb65fdb113ccc961adc6197fd2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 17:46:32 -0700 Subject: [PATCH 31/69] adding error if reusing kv cache from a mismatch layer --- llmfoundry/models/mpt/modeling_mpt.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d76dcc20ae..a480dc5c29 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -529,6 +529,27 @@ def _construct_blocks_with_overrides( reuse_kv_layer_idx = reuse_kv_layer_idx_dict[ reuse_kv_layer_idx] reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx + + def _get_keys(nested_config: Dict[str, Any]) -> set: + keys_set = set() + for k in nested_config.keys(): + keys_set.add(k) + if isinstance(nested_config[k], dict): + keys_set.update(_get_keys(nested_config[k])) + return keys_set + + parent_layer_name = model_modules_order_expanded[ + reuse_kv_layer_idx] + parent_config = {} if parent_layer_name == 'default' else config.block_overrides[ + 'overrides'][parent_layer_name] + parent_keys = _get_keys(parent_config) + child_keys = _get_keys(override_config) + parent_keys = parent_keys.add('reuse_kv_layer_idx') + if child_keys != parent_keys: + raise ValueError( + 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', + ) + override_attn_config['reuse_kv_layer_idx' ] 
= reuse_kv_layer_idx if self.kv_cache_layers is None: From 1daf53426a4f639e1eaecb653b4fa905927fbe1b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 17:51:19 -0700 Subject: [PATCH 32/69] fixing test --- tests/models/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 63ce070af7..3066aae1db 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2701,7 +2701,7 @@ def test_construct_blocks(start: list, end: list): assert block_list[1].attn.sliding_window_size == 1024 assert block_list[1].attn.reuse_kv_layer_idx == 0 assert block_list[2].attn.sliding_window_size == 1024 - assert block_list[2].attn.reuse_kv_layer_idx == 1 + assert block_list[2].attn.reuse_kv_layer_idx == 0 else: assert block_list[0].attn.sliding_window_size == 512 assert block_list[0].attn.reuse_kv_layer_idx is None From afa4c09c7e9d4612dca0fc09daf469e7f6a6d44c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 18:38:08 -0700 Subject: [PATCH 33/69] fixing code, test --- llmfoundry/models/mpt/modeling_mpt.py | 30 +++++++++++---------------- tests/models/test_model.py | 11 +++++----- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a480dc5c29..5729b7664a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -464,17 +464,17 @@ def _construct_blocks_with_overrides( 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', ) modules_order_expanded = {} - for type in ['start', 'repeating_pattern', 'end']: - modules_order_expanded[type] = [] - if type in config.block_overrides: - for block in config.block_overrides[type]: + for location in ['start', 'repeating_pattern', 'end']: + modules_order_expanded[location] = [] + if location in config.block_overrides: + for block in config.block_overrides[location]: if not isinstance(block['repeat'], int) or block['repeat'] < 0: raise ValueError( 'repeat should be a non-negative integer.', ) - modules_order_expanded[type].extend([block['name']] * - block['repeat']) + modules_order_expanded[location].extend([block['name']] * + block['repeat']) start_len = len(modules_order_expanded['start']) repeating_pattern_len = len(modules_order_expanded['repeating_pattern']) @@ -530,22 +530,16 @@ def _construct_blocks_with_overrides( reuse_kv_layer_idx] reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx - def _get_keys(nested_config: Dict[str, Any]) -> set: - keys_set = set() - for k in nested_config.keys(): - keys_set.add(k) - if isinstance(nested_config[k], dict): - keys_set.update(_get_keys(nested_config[k])) - return keys_set - parent_layer_name = model_modules_order_expanded[ reuse_kv_layer_idx] parent_config = {} if parent_layer_name == 'default' else config.block_overrides[ 'overrides'][parent_layer_name] - parent_keys = _get_keys(parent_config) - child_keys = _get_keys(override_config) - parent_keys = parent_keys.add('reuse_kv_layer_idx') - if child_keys != parent_keys: + if 'attn_config' not in parent_config: + parent_config['attn_config'] = {} + parent_config['attn_config'][ + 'reuse_kv_layer_idx'] = override_config['attn_config'][ + 'reuse_kv_layer_idx'] + if override_config != parent_config: raise ValueError( 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', ) diff --git a/tests/models/test_model.py 
b/tests/models/test_model.py index 3066aae1db..cd9e4fecdb 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2666,7 +2666,6 @@ def test_construct_blocks(start: list, end: list): overrides = { 'layer_s': { 'attn_config': { - 'sliding_window_size': 1024, 'reuse_kv_layer_idx': -1, }, }, @@ -2677,7 +2676,7 @@ def test_construct_blocks(start: list, end: list): }, 'layer_e': { 'attn_config': { - 'sliding_window_size': 256, + 'sliding_window_size': 512, 'reuse_kv_layer_idx': -2, }, }, @@ -2698,18 +2697,18 @@ def test_construct_blocks(start: list, end: list): if len(start) > 0: assert block_list[0].attn.sliding_window_size == -1 assert block_list[0].attn.reuse_kv_layer_idx is None - assert block_list[1].attn.sliding_window_size == 1024 + assert block_list[1].attn.sliding_window_size == -1 assert block_list[1].attn.reuse_kv_layer_idx == 0 - assert block_list[2].attn.sliding_window_size == 1024 + assert block_list[2].attn.sliding_window_size == -1 assert block_list[2].attn.reuse_kv_layer_idx == 0 else: assert block_list[0].attn.sliding_window_size == 512 assert block_list[0].attn.reuse_kv_layer_idx is None if len(end) > 0: - assert block_list[6].attn.sliding_window_size == 256 + assert block_list[6].attn.sliding_window_size == 512 assert block_list[6].attn.reuse_kv_layer_idx == 4 - assert block_list[7].attn.sliding_window_size == 256 + assert block_list[7].attn.sliding_window_size == 512 assert block_list[7].attn.reuse_kv_layer_idx == 5 assert block_list[8].attn.sliding_window_size == -1 assert block_list[8].attn.reuse_kv_layer_idx is None From a8a8a8baa4cb54370f73c0f707593ccb13d5716c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 18:54:49 -0700 Subject: [PATCH 34/69] fix --- llmfoundry/models/mpt/modeling_mpt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 5729b7664a..6f5bec18aa 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -532,8 +532,9 @@ def _construct_blocks_with_overrides( parent_layer_name = model_modules_order_expanded[ reuse_kv_layer_idx] - parent_config = {} if parent_layer_name == 'default' else config.block_overrides[ - 'overrides'][parent_layer_name] + parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( + config.block_overrides['overrides'][parent_layer_name], + ) if 'attn_config' not in parent_config: parent_config['attn_config'] = {} parent_config['attn_config'][ From 6f468ef3a2358e9d5ffe52e3b4b477b1ba092ea5 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 19:35:15 -0700 Subject: [PATCH 35/69] .. 
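
The relative-to-absolute resolution of reuse_kv_layer_idx that the preceding changes converge on can be summarized with a small standalone sketch; the helper name and dict argument below are illustrative assumptions, not the module's actual API:

    # A layer may only reuse a kv cache from an earlier layer, so the
    # configured index is negative (relative). It is resolved to an absolute
    # layer index, and chained reuse is collapsed onto the root layer that
    # actually computes the kv cache.
    from typing import Dict

    def resolve_reuse_idx(
        layer_idx: int,
        relative_idx: int,
        resolved: Dict[int, int],
    ) -> int:
        if relative_idx >= 0:
            raise ValueError('The relative index of the kv layer to reuse should be negative.')
        absolute_idx = layer_idx + relative_idx
        if absolute_idx < 0:
            raise ValueError('The absolute index of the kv layer to reuse should be non-negative.')
        # If the target layer itself reuses another layer's kv cache, point
        # directly at that root layer.
        absolute_idx = resolved.get(absolute_idx, absolute_idx)
        resolved[layer_idx] = absolute_idx
        return absolute_idx

    # With reuse_kv_layer_idx = -1 on layers 1 and 2, both resolve to layer 0,
    # matching the expectations in test_construct_blocks.
    resolved: Dict[int, int] = {}
    assert resolve_reuse_idx(1, -1, resolved) == 0
    assert resolve_reuse_idx(2, -1, resolved) == 0
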
--- llmfoundry/models/mpt/modeling_mpt.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 6f5bec18aa..28221000d1 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -466,15 +466,13 @@ def _construct_blocks_with_overrides( modules_order_expanded = {} for location in ['start', 'repeating_pattern', 'end']: modules_order_expanded[location] = [] - if location in config.block_overrides: - for block in config.block_overrides[location]: - if not isinstance(block['repeat'], - int) or block['repeat'] < 0: - raise ValueError( - 'repeat should be a non-negative integer.', - ) - modules_order_expanded[location].extend([block['name']] * - block['repeat']) + for block in config.block_overrides.get(location, []): + if not isinstance(block['repeat'], int) or block['repeat'] < 0: + raise ValueError( + 'repeat should be a non-negative integer.', + ) + modules_order_expanded[location].extend([block['name']] * + block['repeat']) start_len = len(modules_order_expanded['start']) repeating_pattern_len = len(modules_order_expanded['repeating_pattern']) From 0f96a876814edc8955ed252f54979ee24b2664a8 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 20:53:32 -0700 Subject: [PATCH 36/69] refactoring --- llmfoundry/models/mpt/modeling_mpt.py | 164 +++++++++++++++----------- 1 file changed, 94 insertions(+), 70 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 28221000d1..d203969ca5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -463,6 +463,99 @@ def _construct_blocks_with_overrides( raise ValueError( 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', ) + model_modules_order_expanded = self._get_modules_order_expanded(config) + module_list = [] + layer_description_list = [] + + reuse_kv_layer_idx_dict = {} + for i in range(config.n_layers): + module_name = model_modules_order_expanded[i] + override_config = {} + if module_name != 'default': + if 'attn_config' in config.block_overrides['overrides'][ + module_name + ] and 'reuse_kv_layer_idx' in config.block_overrides[ + 'overrides'][module_name]['attn_config']: + override_config = self._validate_reuse_kv_layer_config( + config, + model_modules_order_expanded, + reuse_kv_layer_idx_dict, + i, + module_name, + ) + layer_description_list.append([ + i, + module_name, + self._get_overrides_for_logging(override_config), + ],) + new_block_args = MPTModel._override_block_args( + block_args, + override_config, + config.allowed_block_overrides, + ) + module_list.append( + MPTBlock( + device=config.init_device, + **new_block_args, + ), + ) + log.info( + 'The following is a summary of overrides per layer.\n' + tabulate( + layer_description_list, + headers=['idx', 'name', 'overrides'], + ), + ) + return nn.ModuleList(module_list) + + def _validate_reuse_kv_layer_config( + self, + config, + model_modules_order_expanded, + reuse_kv_layer_idx_dict, + i, + module_name, + ): + override_config = copy.deepcopy( + config.block_overrides['overrides'][module_name], + ) + override_attn_config = override_config.get('attn_config', None) + if override_attn_config['reuse_kv_layer_idx'] >= 0: + raise ValueError( + f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', + ) + reuse_kv_layer_idx = i + 
override_attn_config['reuse_kv_layer_idx'] + if reuse_kv_layer_idx < 0: + raise ValueError( + f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', + ) + if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: + reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx] + reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx + + parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] + parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( + config.block_overrides['overrides'][parent_layer_name], + ) + if 'attn_config' not in parent_config: + parent_config['attn_config'] = {} + parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[ + 'attn_config']['reuse_kv_layer_idx'] + if override_config != parent_config: + raise ValueError( + 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', + ) + + override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx + if self.kv_cache_layers is None: + self.kv_cache_layers = set() + self.kv_cache_layers.add(reuse_kv_layer_idx) + return override_config + + def _get_modules_order_expanded(self, config: MPTConfig): + if config.block_overrides is None: + raise ValueError( + 'config.block_overrides should not be None when calling _get_modules_order_expanded.', + ) modules_order_expanded = {} for location in ['start', 'repeating_pattern', 'end']: modules_order_expanded[location] = [] @@ -496,81 +589,12 @@ def _construct_blocks_with_overrides( 'start'] + modules_order_expanded['repeating_pattern' ] + modules_order_expanded['end'] - module_list = [] - layer_description_list = [] if len(model_modules_order_expanded) != config.n_layers: raise ValueError( f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', ) - reuse_kv_layer_idx_dict = {} - for i in range(config.n_layers): - module_name = model_modules_order_expanded[i] - override_config = {} - if module_name != 'default': - override_config = copy.deepcopy( - config.block_overrides['overrides'][module_name], - ) - override_attn_config = override_config.get('attn_config', None) - if override_attn_config is not None and 'reuse_kv_layer_idx' in override_attn_config: - if override_attn_config['reuse_kv_layer_idx'] >= 0: - raise ValueError( - f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', - ) - reuse_kv_layer_idx = i + override_attn_config[ - 'reuse_kv_layer_idx'] - if reuse_kv_layer_idx < 0: - raise ValueError( - f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', - ) - if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: - reuse_kv_layer_idx = reuse_kv_layer_idx_dict[ - reuse_kv_layer_idx] - reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx - - parent_layer_name = model_modules_order_expanded[ - reuse_kv_layer_idx] - parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( - config.block_overrides['overrides'][parent_layer_name], - ) - if 'attn_config' not in parent_config: - parent_config['attn_config'] = {} - parent_config['attn_config'][ - 'reuse_kv_layer_idx'] = override_config['attn_config'][ - 'reuse_kv_layer_idx'] - if override_config != parent_config: - raise ValueError( - 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', - ) - - override_attn_config['reuse_kv_layer_idx' - ] = reuse_kv_layer_idx - if 
self.kv_cache_layers is None: - self.kv_cache_layers = set() - self.kv_cache_layers.add(reuse_kv_layer_idx) - layer_description_list.append([ - i, - module_name, - self._get_overrides_for_logging(override_config), - ],) - new_block_args = MPTModel._override_block_args( - block_args, - override_config, - config.allowed_block_overrides, - ) - module_list.append( - MPTBlock( - device=config.init_device, - **new_block_args, - ), - ) - log.info( - 'The following is a summary of overrides per layer.\n' + tabulate( - layer_description_list, - headers=['idx', 'name', 'overrides'], - ), - ) - return nn.ModuleList(module_list) + return model_modules_order_expanded def _get_overrides_for_logging( self, From 085de4c14abd96c6713fc2e482137089269b5172 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 21:11:39 -0700 Subject: [PATCH 37/69] fix --- llmfoundry/models/mpt/modeling_mpt.py | 37 +++++++++++++-------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d203969ca5..b811630308 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -467,11 +467,14 @@ def _construct_blocks_with_overrides( module_list = [] layer_description_list = [] - reuse_kv_layer_idx_dict = {} - for i in range(config.n_layers): - module_name = model_modules_order_expanded[i] + self.reuse_kv_layer_idx_dict = {} + for b_idx in range(config.n_layers): + module_name = model_modules_order_expanded[b_idx] override_config = {} if module_name != 'default': + override_config = copy.deepcopy( + config.block_overrides['overrides'][module_name], + ) if 'attn_config' in config.block_overrides['overrides'][ module_name ] and 'reuse_kv_layer_idx' in config.block_overrides[ @@ -479,12 +482,11 @@ def _construct_blocks_with_overrides( override_config = self._validate_reuse_kv_layer_config( config, model_modules_order_expanded, - reuse_kv_layer_idx_dict, - i, - module_name, + b_idx, + override_config, ) layer_description_list.append([ - i, + b_idx, module_name, self._get_overrides_for_logging(override_config), ],) @@ -509,28 +511,25 @@ def _construct_blocks_with_overrides( def _validate_reuse_kv_layer_config( self, - config, - model_modules_order_expanded, - reuse_kv_layer_idx_dict, - i, - module_name, + config: MPTConfig, + model_modules_order_expanded: List[str], + b_idx: int, + override_config: Dict[str, Any], ): - override_config = copy.deepcopy( - config.block_overrides['overrides'][module_name], - ) override_attn_config = override_config.get('attn_config', None) if override_attn_config['reuse_kv_layer_idx'] >= 0: raise ValueError( f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', ) - reuse_kv_layer_idx = i + override_attn_config['reuse_kv_layer_idx'] + reuse_kv_layer_idx = b_idx + override_attn_config['reuse_kv_layer_idx'] if reuse_kv_layer_idx < 0: raise ValueError( f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', ) - if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: - reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx] - reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx + if reuse_kv_layer_idx in self.reuse_kv_layer_idx_dict: + reuse_kv_layer_idx = self.reuse_kv_layer_idx_dict[reuse_kv_layer_idx + ] + self.reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] parent_config = {} if parent_layer_name == 
'default' else copy.deepcopy( From 65b816eaefa08bd37f4f7580ade34579616372e0 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 21:48:55 -0700 Subject: [PATCH 38/69] .. --- llmfoundry/models/mpt/modeling_mpt.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b811630308..10d35347cf 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -475,11 +475,9 @@ def _construct_blocks_with_overrides( override_config = copy.deepcopy( config.block_overrides['overrides'][module_name], ) - if 'attn_config' in config.block_overrides['overrides'][ - module_name - ] and 'reuse_kv_layer_idx' in config.block_overrides[ - 'overrides'][module_name]['attn_config']: - override_config = self._validate_reuse_kv_layer_config( + if 'reuse_kv_layer_idx' in config.block_overrides['overrides'][ + module_name].get('attn_config', {}): + self._validate_reuse_kv_layer_config( config, model_modules_order_expanded, b_idx, @@ -516,7 +514,11 @@ def _validate_reuse_kv_layer_config( b_idx: int, override_config: Dict[str, Any], ): - override_attn_config = override_config.get('attn_config', None) + if config.block_overrides is None: + raise ValueError( + 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', + ) + override_attn_config = override_config['attn_config'] if override_attn_config['reuse_kv_layer_idx'] >= 0: raise ValueError( f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', From e919422b6a62484269ff723f240a83ad00819fc1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 21:50:13 -0700 Subject: [PATCH 39/69] .. --- llmfoundry/models/mpt/modeling_mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 10d35347cf..4654800f00 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -478,7 +478,7 @@ def _construct_blocks_with_overrides( if 'reuse_kv_layer_idx' in config.block_overrides['overrides'][ module_name].get('attn_config', {}): self._validate_reuse_kv_layer_config( - config, + config.block_overrides, model_modules_order_expanded, b_idx, override_config, @@ -509,12 +509,12 @@ def _construct_blocks_with_overrides( def _validate_reuse_kv_layer_config( self, - config: MPTConfig, + block_overrides: Dict[str, Any], model_modules_order_expanded: List[str], b_idx: int, override_config: Dict[str, Any], ): - if config.block_overrides is None: + if block_overrides is None: raise ValueError( 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', ) @@ -535,7 +535,7 @@ def _validate_reuse_kv_layer_config( parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( - config.block_overrides['overrides'][parent_layer_name], + block_overrides['overrides'][parent_layer_name], ) if 'attn_config' not in parent_config: parent_config['attn_config'] = {} From b2c9ee9f9e798b7b27714596d457690ba5005edd Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 25 Jun 2024 21:52:17 -0700 Subject: [PATCH 40/69] .. 
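
A standalone sketch of how the start / repeating_pattern / end sections expand into one module name per layer, as done by _get_modules_order_expanded above (the helper below takes a plain dict and is an illustration, not the module's actual method):

    from typing import Any, Dict, List

    def expand_modules_order(
        block_overrides: Dict[str, Any],
        n_layers: int,
    ) -> List[str]:
        expanded: Dict[str, List[str]] = {}
        for location in ['start', 'repeating_pattern', 'end']:
            expanded[location] = []
            for block in block_overrides.get(location, []):
                expanded[location].extend([block['name']] * block['repeat'])

        start_len = len(expanded['start'])
        end_len = len(expanded['end'])
        pattern_len = len(expanded['repeating_pattern'])
        if pattern_len > 0:
            middle = n_layers - (start_len + end_len)
            if middle % pattern_len != 0:
                raise ValueError(
                    'Number of layers should be divisible by the specified custom modules order.',
                )
            expanded['repeating_pattern'] *= middle // pattern_len

        order = expanded['start'] + expanded['repeating_pattern'] + expanded['end']
        if len(order) != n_layers:
            raise ValueError('The specified block overrides do not match the number of layers.')
        return order

    # 1 default + a 1-name pattern repeated across 4 middle layers + 1 default
    order = expand_modules_order(
        {
            'start': [{'name': 'default', 'repeat': 1}],
            'repeating_pattern': [{'name': 'sliding_window_layer', 'repeat': 1}],
            'end': [{'name': 'default', 'repeat': 1}],
        },
        n_layers=6,
    )
    assert order == ['default'] + ['sliding_window_layer'] * 4 + ['default']
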
--- llmfoundry/models/mpt/modeling_mpt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4654800f00..bdbeb0e2f1 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -475,8 +475,10 @@ def _construct_blocks_with_overrides( override_config = copy.deepcopy( config.block_overrides['overrides'][module_name], ) - if 'reuse_kv_layer_idx' in config.block_overrides['overrides'][ - module_name].get('attn_config', {}): + if 'reuse_kv_layer_idx' in override_config.get( + 'attn_config', + {}, + ): self._validate_reuse_kv_layer_config( config.block_overrides, model_modules_order_expanded, From 7e7797910b8a70bad0446069ec92faea1210ca22 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 13:21:01 -0700 Subject: [PATCH 41/69] refactoring --- llmfoundry/models/mpt/modeling_mpt.py | 28 ++++++++++++--------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bdbeb0e2f1..dc4587a924 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -445,26 +445,31 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: self.kv_cache_layers = None if config.block_overrides is not None: - return self._construct_blocks_with_overrides(config, block_args) + block_args_list = self._construct_blocks_with_overrides( + config, + block_args, + ) + else: + block_args_list = [block_args for _ in range(config.n_layers)] return nn.ModuleList([ self.block_class( device=config.init_device, - **block_args, - ) for _ in range(config.n_layers) + **block_args_i, + ) for block_args_i in block_args_list ]) def _construct_blocks_with_overrides( self, config: MPTConfig, block_args: Dict[str, Any], - ) -> nn.ModuleList: + ) -> List[Dict[str, Any]]: if config.block_overrides is None: raise ValueError( 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', ) model_modules_order_expanded = self._get_modules_order_expanded(config) - module_list = [] + new_block_args_list = [] layer_description_list = [] self.reuse_kv_layer_idx_dict = {} @@ -495,19 +500,14 @@ def _construct_blocks_with_overrides( override_config, config.allowed_block_overrides, ) - module_list.append( - MPTBlock( - device=config.init_device, - **new_block_args, - ), - ) + new_block_args_list.append(new_block_args) log.info( 'The following is a summary of overrides per layer.\n' + tabulate( layer_description_list, headers=['idx', 'name', 'overrides'], ), ) - return nn.ModuleList(module_list) + return new_block_args_list def _validate_reuse_kv_layer_config( self, @@ -516,10 +516,6 @@ def _validate_reuse_kv_layer_config( b_idx: int, override_config: Dict[str, Any], ): - if block_overrides is None: - raise ValueError( - 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', - ) override_attn_config = override_config['attn_config'] if override_attn_config['reuse_kv_layer_idx'] >= 0: raise ValueError( From 0ff081a00d1c7330e02f02b6859f4ede29baae84 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 13:22:40 -0700 Subject: [PATCH 42/69] .. 
--- llmfoundry/models/mpt/modeling_mpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index dc4587a924..cc854f8a1d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -445,7 +445,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: self.kv_cache_layers = None if config.block_overrides is not None: - block_args_list = self._construct_blocks_with_overrides( + block_args_list = self._get_override_block_args_list( config, block_args, ) @@ -459,14 +459,14 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: ) for block_args_i in block_args_list ]) - def _construct_blocks_with_overrides( + def _get_override_block_args_list( self, config: MPTConfig, block_args: Dict[str, Any], ) -> List[Dict[str, Any]]: if config.block_overrides is None: raise ValueError( - 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', + 'config.block_overrides should not be None when calling _get_override_block_args_list.', ) model_modules_order_expanded = self._get_modules_order_expanded(config) new_block_args_list = [] From d15d34b18611d98d75fe7510e66eedd9bfb35466 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 13:24:53 -0700 Subject: [PATCH 43/69] .. --- llmfoundry/models/mpt/modeling_mpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cc854f8a1d..4b8a15e60f 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -599,6 +599,7 @@ def _get_overrides_for_logging( self, override_config: Dict[str, Any], ) -> List[dict[str, str]]: + # Flattens the override config for logging. overrides_list = [] for k, v in override_config.items(): if isinstance(v, dict): From 114b42f34c4ab6869d39438c570604a6cbd62696 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 13:56:43 -0700 Subject: [PATCH 44/69] .. 
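Pass the reuse-kv helper its arguments as keywords and track resolved indices in a
local reuse_kv_layer_idx_dict rather than an instance attribute.

A minimal sketch of how a relative reuse_kv_layer_idx gets resolved through that
dict (illustrative only, not part of this diff; the concrete indices are assumed
for the example):

    # Block 2 already reuses block 1's kv cache.
    reuse_kv_layer_idx_dict = {2: 1}
    b_idx = 3
    reuse_kv_layer_idx = b_idx + (-1)                    # override says "previous layer" -> 2
    if reuse_kv_layer_idx in reuse_kv_layer_idx_dict:    # chase to the block that computes the kv cache
        reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx]
    reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx  # block 3 now also points at block 1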
--- llmfoundry/models/mpt/modeling_mpt.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4b8a15e60f..2c136d461a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -472,7 +472,7 @@ def _get_override_block_args_list( new_block_args_list = [] layer_description_list = [] - self.reuse_kv_layer_idx_dict = {} + reuse_kv_layer_idx_dict = {} for b_idx in range(config.n_layers): module_name = model_modules_order_expanded[b_idx] override_config = {} @@ -485,10 +485,12 @@ def _get_override_block_args_list( {}, ): self._validate_reuse_kv_layer_config( - config.block_overrides, + block_overrides=config.block_overrides, + model_modules_order_expanded= model_modules_order_expanded, - b_idx, - override_config, + b_idx=b_idx, + override_config=override_config, + reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, ) layer_description_list.append([ b_idx, @@ -515,6 +517,7 @@ def _validate_reuse_kv_layer_config( model_modules_order_expanded: List[str], b_idx: int, override_config: Dict[str, Any], + reuse_kv_layer_idx_dict: Dict[int, int], ): override_attn_config = override_config['attn_config'] if override_attn_config['reuse_kv_layer_idx'] >= 0: @@ -526,10 +529,9 @@ def _validate_reuse_kv_layer_config( raise ValueError( f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', ) - if reuse_kv_layer_idx in self.reuse_kv_layer_idx_dict: - reuse_kv_layer_idx = self.reuse_kv_layer_idx_dict[reuse_kv_layer_idx - ] - self.reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx + if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: + reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx] + reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( From 380e954531dadbe8f430132bde62e9287219e5a8 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 14:12:09 -0700 Subject: [PATCH 45/69] adding test for _get_modules_order_expanded --- llmfoundry/models/mpt/modeling_mpt.py | 7 ++- tests/models/test_model.py | 66 +++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 2c136d461a..895258c2e7 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -468,7 +468,9 @@ def _get_override_block_args_list( raise ValueError( 'config.block_overrides should not be None when calling _get_override_block_args_list.', ) - model_modules_order_expanded = self._get_modules_order_expanded(config) + model_modules_order_expanded = MPTModel._get_modules_order_expanded( + config, + ) new_block_args_list = [] layer_description_list = [] @@ -552,7 +554,8 @@ def _validate_reuse_kv_layer_config( self.kv_cache_layers.add(reuse_kv_layer_idx) return override_config - def _get_modules_order_expanded(self, config: MPTConfig): + @staticmethod + def _get_modules_order_expanded(config: MPTConfig) -> List[str]: if config.block_overrides is None: raise ValueError( 'config.block_overrides should not be None when calling _get_modules_order_expanded.', diff --git a/tests/models/test_model.py b/tests/models/test_model.py index cd9e4fecdb..ceff104b0b 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2790,3 +2790,69 @@ def 
test_override_block_args(): assert new_config['d'] == 4 assert new_config['e'] == 6 assert new_config['b']['c'] == 5 + + +@pytest.mark.parametrize('start', [True, False]) +@pytest.mark.parametrize('repeating_pattern', [True, False]) +@pytest.mark.parametrize('end', [True, False]) +def test_get_modules_order_expanded( + start: bool, + repeating_pattern: bool, + end: bool, +): + n_layers = 0 + block_overrides = {} + expected_list = [] + + if start: + block_overrides['start'] = [ + { + 'name': 'default', + 'repeat': 1, + }, + { + 'name': 'layer_a', + 'repeat': 2, + }, + ] + n_layers += 3 + expected_list.extend(['default', 'layer_a', 'layer_a']) + + if start: + block_overrides['start'] = [ + { + 'name': 'layer_b', + 'repeat': 3, + }, + ] + n_layers += 6 + expected_list.extend(['layer_b'] * 6) + + if end: + block_overrides['start'] = [ + { + 'name': 'layer_c', + 'repeat': 1, + }, + { + 'name': 'default', + 'repeat': 2, + }, + ] + n_layers += 3 + expected_list.extend(['layer_c', 'default', 'default']) + + config = MPTConfig( + d_model=32, + n_heads=16, + n_layers=n_layers, + block_overrides=block_overrides, + expansion_ratio=2, + max_seq_len=64, + attn_config={ + 'attn_impl': 'flash', + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 4, + }, + ) + assert expected_list == MPTModel._get_modules_order_expanded(config) From 3e3c974c4778baf1413c48151d6b759f6987330a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 14:19:17 -0700 Subject: [PATCH 46/69] fixing test --- tests/models/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index ceff104b0b..096524b6ec 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2818,7 +2818,7 @@ def test_get_modules_order_expanded( n_layers += 3 expected_list.extend(['default', 'layer_a', 'layer_a']) - if start: + if repeating_pattern: block_overrides['start'] = [ { 'name': 'layer_b', From 7f9ef5c3db879552599c570afa67270d7966b2df Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 14:19:45 -0700 Subject: [PATCH 47/69] fixing test --- tests/models/test_model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 096524b6ec..3d5142ed1e 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2801,7 +2801,7 @@ def test_get_modules_order_expanded( end: bool, ): n_layers = 0 - block_overrides = {} + block_overrides = {'overrides': {'a': 'b'}} expected_list = [] if start: @@ -2819,7 +2819,7 @@ def test_get_modules_order_expanded( expected_list.extend(['default', 'layer_a', 'layer_a']) if repeating_pattern: - block_overrides['start'] = [ + block_overrides['repeating_pattern'] = [ { 'name': 'layer_b', 'repeat': 3, @@ -2829,7 +2829,7 @@ def test_get_modules_order_expanded( expected_list.extend(['layer_b'] * 6) if end: - block_overrides['start'] = [ + block_overrides['end'] = [ { 'name': 'layer_c', 'repeat': 1, @@ -2841,6 +2841,9 @@ def test_get_modules_order_expanded( ] n_layers += 3 expected_list.extend(['layer_c', 'default', 'default']) + + if n_layers == 0: + pytest.skip('Skipping because no overrides.') config = MPTConfig( d_model=32, From 18b77afede90e122b35f5319d954e659a405394b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 14:21:35 -0700 Subject: [PATCH 48/69] lint --- tests/models/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_model.py 
b/tests/models/test_model.py index 3d5142ed1e..707c5c062a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2841,7 +2841,7 @@ def test_get_modules_order_expanded( ] n_layers += 3 expected_list.extend(['layer_c', 'default', 'default']) - + if n_layers == 0: pytest.skip('Skipping because no overrides.') From f417e954f9b23155f7d637629b6d8ce296188718 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 14:23:46 -0700 Subject: [PATCH 49/69] lint --- tests/models/test_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 707c5c062a..4c24495829 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2805,7 +2805,7 @@ def test_get_modules_order_expanded( expected_list = [] if start: - block_overrides['start'] = [ + block_overrides['start'] = [ # type: ignore { 'name': 'default', 'repeat': 1, @@ -2819,7 +2819,9 @@ def test_get_modules_order_expanded( expected_list.extend(['default', 'layer_a', 'layer_a']) if repeating_pattern: - block_overrides['repeating_pattern'] = [ + block_overrides[ + 'repeating_pattern' + ] = [ # type: ignore { 'name': 'layer_b', 'repeat': 3, @@ -2829,7 +2831,7 @@ def test_get_modules_order_expanded( expected_list.extend(['layer_b'] * 6) if end: - block_overrides['end'] = [ + block_overrides['end'] = [ # type: ignore { 'name': 'layer_c', 'repeat': 1, From 72eb7e189f4a779f1a4ebf89931971d58ec88b5a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 15:24:00 -0700 Subject: [PATCH 50/69] adding test --- llmfoundry/models/mpt/modeling_mpt.py | 13 ++--- tests/models/test_model.py | 75 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 895258c2e7..19c375abd2 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -486,7 +486,7 @@ def _get_override_block_args_list( 'attn_config', {}, ): - self._validate_reuse_kv_layer_config( + MPTModel._validate_reuse_kv_layer_config( block_overrides=config.block_overrides, model_modules_order_expanded= model_modules_order_expanded, @@ -494,6 +494,11 @@ def _get_override_block_args_list( override_config=override_config, reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, ) + if self.kv_cache_layers is None: + self.kv_cache_layers = set() + self.kv_cache_layers.add( + override_config['attn_config']['reuse_kv_layer_idx'], + ) layer_description_list.append([ b_idx, module_name, @@ -513,8 +518,8 @@ def _get_override_block_args_list( ) return new_block_args_list + @staticmethod def _validate_reuse_kv_layer_config( - self, block_overrides: Dict[str, Any], model_modules_order_expanded: List[str], b_idx: int, @@ -549,10 +554,6 @@ def _validate_reuse_kv_layer_config( ) override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx - if self.kv_cache_layers is None: - self.kv_cache_layers = set() - self.kv_cache_layers.add(reuse_kv_layer_idx) - return override_config @staticmethod def _get_modules_order_expanded(config: MPTConfig) -> List[str]: diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 4c24495829..95eb3ed221 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2861,3 +2861,78 @@ def test_get_modules_order_expanded( }, ) assert expected_list == MPTModel._get_modules_order_expanded(config) + + +@pytest.mark.parametrize('reuse_kv_layer_idx', [-2, -1, 0]) +def 
test_validate_reuse_kv_layer_config(reuse_kv_layer_idx: int): + layer_a_override = { + 'key_1': 'value_a', + 'attn_config': { + 'key_2': 'value_b', + }, + } + layer_b_override = { + 'key_1': 'value_c', + 'attn_config': { + 'key_2': 'value_d', + }, + } + layer_c_override = { + 'key_1': 'value_c' if reuse_kv_layer_idx == -1 else 'value_a', + 'attn_config': { + 'key_2': 'value_d' if reuse_kv_layer_idx == -1 else 'value_b', + 'reuse_kv_layer_idx': reuse_kv_layer_idx, + }, + } + block_overrides = { + 'overrides': { + 'layer_a': layer_a_override, + 'layer_b': layer_b_override, + 'layer_c': layer_c_override, + }, + } + model_modules_order_expanded = ['layer_a', 'layer_b', 'layer_c'] + if reuse_kv_layer_idx == -1: + model_modules_order_expanded = [ + 'layer_a', + 'layer_b', + 'layer_c', + 'layer_c', + 'layer_c', + 'layer_a', + 'layer_c', + ] + reuse_kv_layer_idx_dict = {} + + def _validate_helper(b_idx: int): + MPTModel._validate_reuse_kv_layer_config( + block_overrides=block_overrides, + model_modules_order_expanded=model_modules_order_expanded, + b_idx=b_idx, + override_config=copy.deepcopy( + block_overrides['overrides'][model_modules_order_expanded[b_idx] + ], + ), + reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, + ) + + if reuse_kv_layer_idx == -1: + _validate_helper(b_idx=2) + _validate_helper(b_idx=3) + _validate_helper(b_idx=4) + with pytest.raises( + expected_exception=ValueError, + match= + 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer\.', # type: ignore + ): + _validate_helper(b_idx=6) + + elif reuse_kv_layer_idx == -2: + _validate_helper(b_idx=2) + else: + with pytest.raises( + expected_exception=ValueError, + match= + 'The relative index of kv layer to reuse, override_attn_config\[\"reuse_kv_layer_idx\"\]=0, should be negative\.', # type: ignore + ): + _validate_helper(b_idx=2) From 391978cfc065322592b6a8db82170842a9ec9bbe Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 16:43:36 -0700 Subject: [PATCH 51/69] addressing comment --- tests/models/test_model.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 95eb3ed221..0e574f9f36 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -5,6 +5,7 @@ import os import pathlib import warnings +from functools import partial from typing import Any, Dict, List, Optional, Union, cast from unittest import mock @@ -2766,6 +2767,16 @@ def test_reuse_prev_layer_kv_cache( test_cfg.max_seq_len, ]) model.train() + + prev_layer_key_value_dict = {} + + def mock_forward(b_forward, b_idx, *args, **kwargs): + prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value'] + return b_forward(*args, **kwargs) + + for b_idx, block in enumerate(model.model.transformer.blocks): + block.forward = partial(mock_forward, block.forward, b_idx) + with get_precision_context(test_cfg.precision): outputs = model(batch) assert len(outputs.past_key_values) == 2 @@ -2775,6 +2786,13 @@ def test_reuse_prev_layer_kv_cache( assert torch.all( outputs.past_key_values[0][1] == outputs.past_key_values[1][1], ) + assert prev_layer_key_value_dict[0] is None + assert torch.all( + prev_layer_key_value_dict[1][0] == outputs.past_key_values[0][0], + ) + assert torch.all( + prev_layer_key_value_dict[1][1] == outputs.past_key_values[0][1], + ) def test_override_block_args(): From 0659e322ed53fbfd78b2276998280ef90c42a22a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 
17:36:09 -0700 Subject: [PATCH 52/69] .. --- tests/models/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 0e574f9f36..de461a8aa9 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2770,7 +2770,7 @@ def test_reuse_prev_layer_kv_cache( prev_layer_key_value_dict = {} - def mock_forward(b_forward, b_idx, *args, **kwargs): + def mock_forward(b_forward, b_idx, *args, **kwargs): # type: ignore prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value'] return b_forward(*args, **kwargs) From 180c004203cbad9f2185475154fdfe1d7947a2c1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 26 Jun 2024 18:16:55 -0700 Subject: [PATCH 53/69] fixing test --- tests/models/layers/test_flash_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 6208919a9d..01d982052f 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -586,14 +586,14 @@ def test_reuse_prev_layer_kv_cache( cfg['reuse_kv_layer_idx'] = None attn0 = build_attention_layer( name='grouped_query_attention', - attn_kwargs=om.to_container(cfg), # type: ignore + attn_kwargs=cfg, # type: ignore ).to(device) # Reuses layer 0's kv cache cfg['reuse_kv_layer_idx'] = 0 attn1 = build_attention_layer( name='grouped_query_attention', - attn_kwargs=om.to_container(cfg), # type: ignore + attn_kwargs=cfg, # type: ignore ).to(device) attn0_sd = attn0.state_dict() attn0_sd['Wq.weight'] = attn0_sd['Wqkv.weight'][:cfg['d_model']] From 023070f2dfcb3010240216e7a2411181d0ead866 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 27 Jun 2024 14:19:51 -0700 Subject: [PATCH 54/69] changing yaml format --- llmfoundry/models/mpt/configuration_mpt.py | 29 +-- llmfoundry/models/mpt/modeling_mpt.py | 59 ++---- tests/models/test_model.py | 227 +++++++++------------ 3 files changed, 121 insertions(+), 194 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index cb22c2f6f4..e29532ece9 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -119,26 +119,19 @@ def __init__( also be a dictionary that specifies the fc layer name and any kwargs for the fc layer. tie_word_embeddings (bool): Whether to tie the input embedding and output layers. use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. - block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and at least one of `start`, `repeating_pattern`, and `end`. `start`, `repeating_pattern`, and `end` specify the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the `overrides` in the overrides config. + block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and `order`. `order` is a nested list which describes the order of the layers. For each kind of layer, specify the `overrides` in the overrides config (default refers to a layer that does not apply any overrides). 
To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed: block_overrides: - start: - - name: default - repeat: 1 - repeating_pattern: - - name: sliding_window_layer - repeat: 1 - - name: sliding_window_layer_reuse - repeat: 1 - - name: sliding_window_layer - repeat: 1 - - name: sliding_window_layer_reuse - repeat: 2 - - name: reuse_kv_layer - repeat: 1 - end: - - name: default - repeat: 0 + order: + - name: default + - repeat: 2 + order: + - name: sliding_window_layer + - name: sliding_window_layer_reuse + - name: sliding_window_layer + - repeat: 2 + name: sliding_window_layer_reuse + - name: reuse_kv_layer overrides: sliding_window_layer: attn_config: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 19c375abd2..6b837d3bee 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -469,8 +469,13 @@ def _get_override_block_args_list( 'config.block_overrides should not be None when calling _get_override_block_args_list.', ) model_modules_order_expanded = MPTModel._get_modules_order_expanded( - config, + config.block_overrides['order'], ) + if len(model_modules_order_expanded) != config.n_layers: + raise ValueError( + f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', + ) + new_block_args_list = [] layer_description_list = [] @@ -556,48 +561,22 @@ def _validate_reuse_kv_layer_config( override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx @staticmethod - def _get_modules_order_expanded(config: MPTConfig) -> List[str]: - if config.block_overrides is None: - raise ValueError( - 'config.block_overrides should not be None when calling _get_modules_order_expanded.', - ) - modules_order_expanded = {} - for location in ['start', 'repeating_pattern', 'end']: - modules_order_expanded[location] = [] - for block in config.block_overrides.get(location, []): - if not isinstance(block['repeat'], int) or block['repeat'] < 0: - raise ValueError( - 'repeat should be a non-negative integer.', - ) - modules_order_expanded[location].extend([block['name']] * - block['repeat']) - - start_len = len(modules_order_expanded['start']) - repeating_pattern_len = len(modules_order_expanded['repeating_pattern']) - end_len = len(modules_order_expanded['end']) - - if repeating_pattern_len > 0: - if ( - config.n_layers - (start_len + end_len) - ) % repeating_pattern_len != 0: + def _get_modules_order_expanded(order: List[Dict[str, Any]]) -> List[str]: + model_modules_order_expanded = [] + for item in order: + repeat = item['repeat'] if 'repeat' in item else 1 + if ('name' in item) == ('order' in item): raise ValueError( - 'Number of layers should be divisible by the specified custom modules order.', + 'Exactly one of `order` or `name` must be specified for each block override.', ) - num_repetitions = ( - config.n_layers - (start_len + end_len) - ) // repeating_pattern_len - modules_order_expanded[ - 'repeating_pattern' - ] = modules_order_expanded['repeating_pattern'] * num_repetitions - - model_modules_order_expanded = modules_order_expanded[ - 'start'] + modules_order_expanded['repeating_pattern' - ] + modules_order_expanded['end'] - if len(model_modules_order_expanded) != config.n_layers: - raise ValueError( - f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', - ) + if 'name' in item: + 
model_modules_order_expanded.extend([item['name']] * repeat) + else: + model_modules_order_expanded.extend( + MPTModel._get_modules_order_expanded(item['order']) * + repeat, + ) return model_modules_order_expanded diff --git a/tests/models/test_model.py b/tests/models/test_model.py index de461a8aa9..7071adc888 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2625,32 +2625,8 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): assert not torch.isnan(output.logits).any() -@pytest.mark.parametrize( - 'start', - [[], [{ - 'name': 'default', - 'repeat': 1, - }, { - 'name': 'layer_s', - 'repeat': 2, - }]], -) -@pytest.mark.parametrize( - 'end', - [[], [{ - 'name': 'layer_e', - 'repeat': 2, - }, { - 'name': 'default', - 'repeat': 1, - }]], -) -def test_construct_blocks(start: list, end: list): - n_layers = 9 - repeating_pattern = [{ - 'name': 'layer_rp', - 'repeat': 1, - }] +def test_construct_blocks(): + n_layers = 13 config = MPTConfig( d_model=32, @@ -2664,65 +2640,78 @@ def test_construct_blocks(start: list, end: list): 'kv_n_heads': 4, }, ) - overrides = { - 'layer_s': { + + # override architecture taken from https://research.character.ai/optimizing-inference/ + config.block_overrides = {} + config.block_overrides['overrides'] = { + 'reuse_kv_layer': { 'attn_config': { - 'reuse_kv_layer_idx': -1, + 'reuse_kv_layer_idx': -6, }, }, - 'layer_rp': { + 'sliding_window_layer': { 'attn_config': { - 'sliding_window_size': 512, + 'sliding_window_size': 1024, }, }, - 'layer_e': { + 'sliding_window_layer_reuse': { 'attn_config': { - 'sliding_window_size': 512, - 'reuse_kv_layer_idx': -2, + 'sliding_window_size': 1024, + 'reuse_kv_layer_idx': -1, }, }, } - config.block_overrides = {} - - if len(start) > 0: - config.block_overrides['start'] = start - config.block_overrides['repeating_pattern'] = repeating_pattern - if len(end) > 0: - config.block_overrides['end'] = end - config.block_overrides['overrides'] = overrides + config.block_overrides['order'] = [ + { + 'name': 'default', + }, + { + 'order': [ + { + 'name': 'sliding_window_layer', + }, + { + 'name': 'sliding_window_layer_reuse', + }, + { + 'name': 'sliding_window_layer', + }, + { + 'name': 'sliding_window_layer_reuse', + 'repeat': 2, + }, + { + 'name': 'reuse_kv_layer', + }, + ], + 'repeat': 2, + }, + ] block_list = MPTModel(config).construct_blocks(config) assert len(block_list) == n_layers + assert block_list[0].attn.sliding_window_size == -1 + assert block_list[0].attn.reuse_kv_layer_idx is None - if len(start) > 0: - assert block_list[0].attn.sliding_window_size == -1 - assert block_list[0].attn.reuse_kv_layer_idx is None - assert block_list[1].attn.sliding_window_size == -1 - assert block_list[1].attn.reuse_kv_layer_idx == 0 - assert block_list[2].attn.sliding_window_size == -1 - assert block_list[2].attn.reuse_kv_layer_idx == 0 - else: - assert block_list[0].attn.sliding_window_size == 512 - assert block_list[0].attn.reuse_kv_layer_idx is None - - if len(end) > 0: - assert block_list[6].attn.sliding_window_size == 512 - assert block_list[6].attn.reuse_kv_layer_idx == 4 - assert block_list[7].attn.sliding_window_size == 512 - assert block_list[7].attn.reuse_kv_layer_idx == 5 - assert block_list[8].attn.sliding_window_size == -1 - assert block_list[8].attn.reuse_kv_layer_idx is None - else: - assert block_list[8].attn.sliding_window_size == 512 - assert block_list[8].attn.reuse_kv_layer_idx is None + for layer_offset in [1, 7]: + assert block_list[layer_offset].attn.sliding_window_size == 1024 + assert 
block_list[layer_offset].attn.reuse_kv_layer_idx is None + assert block_list[layer_offset + 1].attn.sliding_window_size == 1024 + assert block_list[layer_offset + + 1].attn.reuse_kv_layer_idx == layer_offset + + assert block_list[layer_offset + 2].attn.sliding_window_size == 1024 + assert block_list[layer_offset + 2].attn.reuse_kv_layer_idx is None + assert block_list[layer_offset + 3].attn.sliding_window_size == 1024 + assert block_list[layer_offset + + 3].attn.reuse_kv_layer_idx == layer_offset + 2 + assert block_list[layer_offset + 4].attn.sliding_window_size == 1024 + assert block_list[layer_offset + + 4].attn.reuse_kv_layer_idx == layer_offset + 2 - assert block_list[3].attn.sliding_window_size == 512 - assert block_list[3].attn.reuse_kv_layer_idx is None - assert block_list[4].attn.sliding_window_size == 512 - assert block_list[4].attn.reuse_kv_layer_idx is None - assert block_list[5].attn.sliding_window_size == 512 - assert block_list[5].attn.reuse_kv_layer_idx is None + assert block_list[layer_offset + 5].attn.sliding_window_size == -1 + assert block_list[layer_offset + 5].attn.reuse_kv_layer_idx == 0 @pytest.mark.gpu @@ -2810,75 +2799,41 @@ def test_override_block_args(): assert new_config['b']['c'] == 5 -@pytest.mark.parametrize('start', [True, False]) -@pytest.mark.parametrize('repeating_pattern', [True, False]) -@pytest.mark.parametrize('end', [True, False]) -def test_get_modules_order_expanded( - start: bool, - repeating_pattern: bool, - end: bool, -): - n_layers = 0 - block_overrides = {'overrides': {'a': 'b'}} - expected_list = [] - - if start: - block_overrides['start'] = [ # type: ignore - { - 'name': 'default', - 'repeat': 1, - }, - { - 'name': 'layer_a', - 'repeat': 2, - }, - ] - n_layers += 3 - expected_list.extend(['default', 'layer_a', 'layer_a']) - - if repeating_pattern: - block_overrides[ - 'repeating_pattern' - ] = [ # type: ignore - { +def test_get_modules_order_expanded(): + order = [ + { + 'name': 'default', + }, + { + 'name': 'layer_a', + 'repeat': 2, + }, + { + 'order': [{ 'name': 'layer_b', - 'repeat': 3, - }, - ] - n_layers += 6 - expected_list.extend(['layer_b'] * 6) - - if end: - block_overrides['end'] = [ # type: ignore - { - 'name': 'layer_c', - 'repeat': 1, - }, - { - 'name': 'default', - 'repeat': 2, - }, - ] - n_layers += 3 - expected_list.extend(['layer_c', 'default', 'default']) - - if n_layers == 0: - pytest.skip('Skipping because no overrides.') - - config = MPTConfig( - d_model=32, - n_heads=16, - n_layers=n_layers, - block_overrides=block_overrides, - expansion_ratio=2, - max_seq_len=64, - attn_config={ - 'attn_impl': 'flash', - 'attn_type': 'grouped_query_attention', - 'kv_n_heads': 4, + },], + 'repeat': 3, }, - ) - assert expected_list == MPTModel._get_modules_order_expanded(config) + { + 'name': 'layer_c', + 'repeat': 2, + }, + { + 'name': 'default', + }, + ] + expected_list = [ + 'default', + 'layer_a', + 'layer_a', + 'layer_b', + 'layer_b', + 'layer_b', + 'layer_c', + 'layer_c', + 'default', + ] + assert expected_list == MPTModel._get_modules_order_expanded(order) @pytest.mark.parametrize('reuse_kv_layer_idx', [-2, -1, 0]) From dc9890d171879ab6248d3064eb9a543a7244cee7 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 27 Jun 2024 14:30:18 -0700 Subject: [PATCH 55/69] fix configuation --- llmfoundry/models/mpt/configuration_mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index e29532ece9..b5af1f89f3 100644 --- 
a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -206,12 +206,12 @@ def __init__( def _validate_block_overrides(self, block_overrides: Dict[str, Any]): warnings.warn(ExperimentalWarning('block_overrides')) - if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides: + if 'order' not in block_overrides: + raise ValueError('`order` should be defined in block_overrides',) + if 'overrides' not in block_overrides: raise ValueError( - 'either start, repeating_pattern, or end should be defined in block_overrides', + '`overrides` should be defined in block_overrides', ) - if 'overrides' not in block_overrides: - raise ValueError('overrides should be defined in block_overrides',) for name, override in block_overrides['overrides'].items(): if name == 'default': raise ValueError('block overrides cannot be named "default".',) From 4be19c68b63f9e92deb08da08f1948bf65ead0f7 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 27 Jun 2024 15:04:01 -0700 Subject: [PATCH 56/69] fixing test --- llmfoundry/models/mpt/configuration_mpt.py | 4 ++-- tests/models/test_model.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index b5af1f89f3..2656ccd428 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -139,10 +139,10 @@ def __init__( sliding_window_layer_reuse: attn_config: sliding_window_size: 1024 - reuse_kv_layer_idx: -1 + reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse reuse_kv_layer: attn_config: - reuse_kv_layer_idx: -6 + reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse """ self.d_model = d_model self.n_heads = n_heads diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 7071adc888..986cf3ea23 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2722,14 +2722,12 @@ def test_reuse_prev_layer_kv_cache( conf_path = 'scripts/train/yamls/pretrain/testing.yaml' model_config_overrides = { 'block_overrides': { - 'start': [ + 'order': [ { 'name': 'default', - 'repeat': 1, }, { 'name': 'kv_reuse_layer', - 'repeat': 1, }, ], 'overrides': { From a5298c3947fe5074e64aa517fbadf6f66562f25e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 27 Jun 2024 16:28:23 -0700 Subject: [PATCH 57/69] allowing repeat at top level --- llmfoundry/models/mpt/modeling_mpt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 6b837d3bee..04a70287b8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -468,9 +468,11 @@ def _get_override_block_args_list( raise ValueError( 'config.block_overrides should not be None when calling _get_override_block_args_list.', ) + repeat = config.block_overrides[ + 'repeat'] if 'repeat' in config.block_overrides else 1 model_modules_order_expanded = MPTModel._get_modules_order_expanded( config.block_overrides['order'], - ) + ) * repeat if len(model_modules_order_expanded) != config.n_layers: raise ValueError( f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', From 700ede161588360e92bd0425322dce16cb4c363c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 27 Jun 2024 19:58:31 -0700 Subject: [PATCH 58/69] 
allowing overriding error --- llmfoundry/models/mpt/modeling_mpt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 04a70287b8..5a315a19f0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -555,7 +555,10 @@ def _validate_reuse_kv_layer_config( parent_config['attn_config'] = {} parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[ 'attn_config']['reuse_kv_layer_idx'] - if override_config != parent_config: + if override_config != parent_config and not ( + 'allow_mismatch' in override_config and + override_config['allow_mismatch'] + ): raise ValueError( 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', ) From ca047fa8130185f19d3cff248ddcdfdb85483196 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 12:02:49 -0700 Subject: [PATCH 59/69] addressing comments --- llmfoundry/models/mpt/configuration_mpt.py | 6 ++-- llmfoundry/models/mpt/modeling_mpt.py | 38 +++++++--------------- tests/models/test_model.py | 4 +-- 3 files changed, 18 insertions(+), 30 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 2656ccd428..0eb56d8a92 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -391,6 +391,8 @@ def _validate_config(self) -> None: @property def allowed_block_overrides(self): return { - 'sliding_window_size', - 'reuse_kv_layer_idx', + 'attn_config': { + 'sliding_window_size': None, + 'reuse_kv_layer_idx': None, + } } diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 5a315a19f0..ccb9110f23 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -468,8 +468,7 @@ def _get_override_block_args_list( raise ValueError( 'config.block_overrides should not be None when calling _get_override_block_args_list.', ) - repeat = config.block_overrides[ - 'repeat'] if 'repeat' in config.block_overrides else 1 + repeat = config.block_overrides.get('repeat', 1) model_modules_order_expanded = MPTModel._get_modules_order_expanded( config.block_overrides['order'], ) * repeat @@ -494,7 +493,7 @@ def _get_override_block_args_list( {}, ): MPTModel._validate_reuse_kv_layer_config( - block_overrides=config.block_overrides, + overrides_definition=config.block_overrides['overrides'], model_modules_order_expanded= model_modules_order_expanded, b_idx=b_idx, @@ -509,7 +508,7 @@ def _get_override_block_args_list( layer_description_list.append([ b_idx, module_name, - self._get_overrides_for_logging(override_config), + override_config, ],) new_block_args = MPTModel._override_block_args( block_args, @@ -527,7 +526,7 @@ def _get_override_block_args_list( @staticmethod def _validate_reuse_kv_layer_config( - block_overrides: Dict[str, Any], + overrides_definition: Dict[str, Any], model_modules_order_expanded: List[str], b_idx: int, override_config: Dict[str, Any], @@ -549,7 +548,7 @@ def _validate_reuse_kv_layer_config( parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( - block_overrides['overrides'][parent_layer_name], + overrides_definition[parent_layer_name], ) if 'attn_config' not in parent_config: parent_config['attn_config'] = {} @@ -585,25 +584,17 @@ def 
_get_modules_order_expanded(order: List[Dict[str, Any]]) -> List[str]: return model_modules_order_expanded - def _get_overrides_for_logging( - self, - override_config: Dict[str, Any], - ) -> List[dict[str, str]]: - # Flattens the override config for logging. - overrides_list = [] - for k, v in override_config.items(): - if isinstance(v, dict): - overrides_list.extend(self._get_overrides_for_logging(v)) - else: - overrides_list.append({k: v}) - return overrides_list @staticmethod def _override_block_args( block_args: Dict[str, Any], override_config: Dict[str, Any], - allowed_block_overrides: set, + allowed_block_overrides: Dict[str, Any], ) -> Dict[str, Any]: + unpermitted_keys = override_config.keys() - allowed_block_overrides.keys() + if len(unpermitted_keys): + raise KeyError(f'Overriding {unpermitted_keys} is not supported.') + new_block_args = override_config | block_args common_keys = override_config.keys() & block_args.keys() for k in common_keys: @@ -615,11 +606,9 @@ def _override_block_args( new_block_args[k] = MPTModel._override_block_args( block_args[k], override_config[k], - allowed_block_overrides, + allowed_block_overrides[k], ) else: - if k not in allowed_block_overrides: - raise KeyError(f'Overriding {k} is not supported.') new_block_args[k] = override_config[k] return new_block_args @@ -903,10 +892,7 @@ def forward( layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): - attn_block = block.attn if hasattr( - block, - 'attn', - ) else block.norm_attn_norm.attn + attn_block = block.norm_attn_norm.attn if self.config.fuse_norm_attn_norm else block.attn if attn_block.reuse_kv_layer_idx is not None: if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict: raise KeyError( diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 986cf3ea23..a3ae9b3c8f 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2785,7 +2785,7 @@ def mock_forward(b_forward, b_idx, *args, **kwargs): # type: ignore def test_override_block_args(): block_args = {'a': 1, 'b': {'c': 3}, 'd': 4} override_config = {'a': 2, 'b': {'c': 5}, 'e': 6} - allowed_block_overrides = {'a', 'c', 'e'} + allowed_block_overrides = {'a': None, 'b': {'c': None}, 'e': None} new_config = MPTModel._override_block_args( block_args, override_config, @@ -2877,7 +2877,7 @@ def test_validate_reuse_kv_layer_config(reuse_kv_layer_idx: int): def _validate_helper(b_idx: int): MPTModel._validate_reuse_kv_layer_config( - block_overrides=block_overrides, + overrides_definition=block_overrides['overrides'], model_modules_order_expanded=model_modules_order_expanded, b_idx=b_idx, override_config=copy.deepcopy( From 3ca3d55382c914e6d05643f12dbeb331e354147e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 12:05:08 -0700 Subject: [PATCH 60/69] lint --- llmfoundry/models/mpt/configuration_mpt.py | 2 +- llmfoundry/models/mpt/modeling_mpt.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 0eb56d8a92..a1fdc25f50 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -394,5 +394,5 @@ def allowed_block_overrides(self): 'attn_config': { 'sliding_window_size': None, 'reuse_kv_layer_idx': None, - } + }, } diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index ccb9110f23..a93f57068e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ 
b/llmfoundry/models/mpt/modeling_mpt.py @@ -493,7 +493,8 @@ def _get_override_block_args_list( {}, ): MPTModel._validate_reuse_kv_layer_config( - overrides_definition=config.block_overrides['overrides'], + overrides_definition=config. + block_overrides['overrides'], model_modules_order_expanded= model_modules_order_expanded, b_idx=b_idx, @@ -584,14 +585,14 @@ def _get_modules_order_expanded(order: List[Dict[str, Any]]) -> List[str]: return model_modules_order_expanded - @staticmethod def _override_block_args( block_args: Dict[str, Any], override_config: Dict[str, Any], allowed_block_overrides: Dict[str, Any], ) -> Dict[str, Any]: - unpermitted_keys = override_config.keys() - allowed_block_overrides.keys() + unpermitted_keys = override_config.keys( + ) - allowed_block_overrides.keys() if len(unpermitted_keys): raise KeyError(f'Overriding {unpermitted_keys} is not supported.') From 3fbec2e7a419557dcab04e3272a009dcbd6ef26c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 12:11:14 -0700 Subject: [PATCH 61/69] addressing comments --- llmfoundry/models/mpt/modeling_mpt.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a93f57068e..0a08b8cdba 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -442,7 +442,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. """ block_args = self.extract_block_args(config.to_dict()) - self.kv_cache_layers = None + self.kv_cache_layers = set() if config.block_overrides is not None: block_args_list = self._get_override_block_args_list( @@ -501,8 +501,6 @@ def _get_override_block_args_list( override_config=override_config, reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, ) - if self.kv_cache_layers is None: - self.kv_cache_layers = set() self.kv_cache_layers.add( override_config['attn_config']['reuse_kv_layer_idx'], ) @@ -873,7 +871,7 @@ def forward( # initialize the past key values cache if it should be used presents = () if use_cache else None if ( - use_cache or self.kv_cache_layers is not None + use_cache or len(self.kv_cache_layers) > 0 ) and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) ] # type: ignore @@ -923,7 +921,7 @@ def forward( ) if presents is not None: presents += (present,) - if self.kv_cache_layers is not None and b_idx in self.kv_cache_layers: + if b_idx in self.kv_cache_layers: layer_kv_cache_dict[b_idx] = [ present[0][:, past_position:], present[1][:, past_position:], From e7d85bb9aa6b1c1ca34fce618ce5438be717e5e1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 12:34:40 -0700 Subject: [PATCH 62/69] fix --- llmfoundry/models/mpt/modeling_mpt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0a08b8cdba..023d3ae4f5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -891,7 +891,10 @@ def forward( layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): - attn_block = block.norm_attn_norm.attn if self.config.fuse_norm_attn_norm else block.attn + attn_block = block.norm_attn_norm.attn if self.config.get( + 'fuse_norm_attn_norm', + False, + ) else block.attn if attn_block.reuse_kv_layer_idx is not None: if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict: raise KeyError( From 
8d97354560225331d60472ee6c06445dc85afed2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 12:45:54 -0700 Subject: [PATCH 63/69] .. --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 023d3ae4f5..b29b82e068 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -891,7 +891,7 @@ def forward( layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): - attn_block = block.norm_attn_norm.attn if self.config.get( + attn_block = block.norm_attn_norm.attn if self.config.to_dict().get( 'fuse_norm_attn_norm', False, ) else block.attn From 54abbd26a8a4a9d3360cc4ccb7d5cb47af1b41a4 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 12:50:34 -0700 Subject: [PATCH 64/69] .. --- llmfoundry/models/mpt/modeling_mpt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b29b82e068..d2d666daee 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -443,6 +443,10 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: """ block_args = self.extract_block_args(config.to_dict()) self.kv_cache_layers = set() + self.blocks_fuse_norm_attn_norm = block_args.get( + 'fuse_norm_attn_norm', + False, + ) if config.block_overrides is not None: block_args_list = self._get_override_block_args_list( @@ -891,10 +895,7 @@ def forward( layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): - attn_block = block.norm_attn_norm.attn if self.config.to_dict().get( - 'fuse_norm_attn_norm', - False, - ) else block.attn + attn_block = block.norm_attn_norm.attn if self.blocks_fuse_norm_attn_norm else block.attn if attn_block.reuse_kv_layer_idx is not None: if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict: raise KeyError( From 52536cd2d5f879b9445da17edb5e7128574e21ef Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 13:31:35 -0700 Subject: [PATCH 65/69] .. --- llmfoundry/models/mpt/modeling_mpt.py | 26 ++++++++++++++------------ tests/models/test_model.py | 4 ++-- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d2d666daee..f9089674e7 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -496,7 +496,7 @@ def _get_override_block_args_list( 'attn_config', {}, ): - MPTModel._validate_reuse_kv_layer_config( + reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_config( overrides_definition=config. 
block_overrides['overrides'], model_modules_order_expanded= @@ -505,20 +505,21 @@ def _get_override_block_args_list( override_config=override_config, reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, ) - self.kv_cache_layers.add( - override_config['attn_config']['reuse_kv_layer_idx'], - ) + override_config['attn_config']['reuse_kv_layer_idx' + ] = reuse_kv_layer_idx + self.kv_cache_layers.add(reuse_kv_layer_idx) layer_description_list.append([ b_idx, module_name, override_config, ],) - new_block_args = MPTModel._override_block_args( - block_args, - override_config, - config.allowed_block_overrides, + new_block_args_list.append( + MPTModel._override_block_args( + block_args, + override_config, + config.allowed_block_overrides, + ), ) - new_block_args_list.append(new_block_args) log.info( 'The following is a summary of overrides per layer.\n' + tabulate( layer_description_list, @@ -528,13 +529,13 @@ def _get_override_block_args_list( return new_block_args_list @staticmethod - def _validate_reuse_kv_layer_config( + def _resolve_reuse_kv_layer_config( overrides_definition: Dict[str, Any], model_modules_order_expanded: List[str], b_idx: int, override_config: Dict[str, Any], reuse_kv_layer_idx_dict: Dict[int, int], - ): + ) -> int: override_attn_config = override_config['attn_config'] if override_attn_config['reuse_kv_layer_idx'] >= 0: raise ValueError( @@ -557,6 +558,7 @@ def _validate_reuse_kv_layer_config( parent_config['attn_config'] = {} parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[ 'attn_config']['reuse_kv_layer_idx'] + if override_config != parent_config and not ( 'allow_mismatch' in override_config and override_config['allow_mismatch'] @@ -565,7 +567,7 @@ def _validate_reuse_kv_layer_config( 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', ) - override_attn_config['reuse_kv_layer_idx'] = reuse_kv_layer_idx + return reuse_kv_layer_idx @staticmethod def _get_modules_order_expanded(order: List[Dict[str, Any]]) -> List[str]: diff --git a/tests/models/test_model.py b/tests/models/test_model.py index a3ae9b3c8f..e5ecae1090 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2835,7 +2835,7 @@ def test_get_modules_order_expanded(): @pytest.mark.parametrize('reuse_kv_layer_idx', [-2, -1, 0]) -def test_validate_reuse_kv_layer_config(reuse_kv_layer_idx: int): +def test_resolve_reuse_kv_layer_config(reuse_kv_layer_idx: int): layer_a_override = { 'key_1': 'value_a', 'attn_config': { @@ -2876,7 +2876,7 @@ def test_validate_reuse_kv_layer_config(reuse_kv_layer_idx: int): reuse_kv_layer_idx_dict = {} def _validate_helper(b_idx: int): - MPTModel._validate_reuse_kv_layer_config( + MPTModel._resolve_reuse_kv_layer_config( overrides_definition=block_overrides['overrides'], model_modules_order_expanded=model_modules_order_expanded, b_idx=b_idx, From 6d62c2c24a4eea41f6cccbd83c360c9e773b068c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 13:33:28 -0700 Subject: [PATCH 66/69] .. 
--- llmfoundry/models/mpt/modeling_mpt.py | 4 ++-- tests/models/test_model.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index f9089674e7..d1a4cb37be 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -496,7 +496,7 @@ def _get_override_block_args_list( 'attn_config', {}, ): - reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_config( + reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_idx( overrides_definition=config. block_overrides['overrides'], model_modules_order_expanded= @@ -529,7 +529,7 @@ def _get_override_block_args_list( return new_block_args_list @staticmethod - def _resolve_reuse_kv_layer_config( + def _resolve_reuse_kv_layer_idx( overrides_definition: Dict[str, Any], model_modules_order_expanded: List[str], b_idx: int, diff --git a/tests/models/test_model.py b/tests/models/test_model.py index e5ecae1090..91c6b55899 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2835,7 +2835,7 @@ def test_get_modules_order_expanded(): @pytest.mark.parametrize('reuse_kv_layer_idx', [-2, -1, 0]) -def test_resolve_reuse_kv_layer_config(reuse_kv_layer_idx: int): +def test_resolve_reuse_kv_layer_idx(reuse_kv_layer_idx: int): layer_a_override = { 'key_1': 'value_a', 'attn_config': { @@ -2876,7 +2876,7 @@ def test_resolve_reuse_kv_layer_config(reuse_kv_layer_idx: int): reuse_kv_layer_idx_dict = {} def _validate_helper(b_idx: int): - MPTModel._resolve_reuse_kv_layer_config( + MPTModel._resolve_reuse_kv_layer_idx( overrides_definition=block_overrides['overrides'], model_modules_order_expanded=model_modules_order_expanded, b_idx=b_idx, From 2b53237cab54a38ed7e1069f8dd38b1d9c082eef Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 29 Jun 2024 13:41:28 -0700 Subject: [PATCH 67/69] .. 
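Return the resolved absolute index from the test helper so the chained resolution
can be asserted directly. A short summary of the chain the new assertions cover,
using the test's own expanded order (comments only, for orientation):

    # order: ['layer_a', 'layer_b', 'layer_c', 'layer_c', 'layer_c', 'layer_a', 'layer_c'],
    # with layer_c reusing the previous layer (reuse_kv_layer_idx == -1):
    #   b_idx=2 -> 2 - 1 = 1                 (layer_b computes the kv cache)
    #   b_idx=3 -> 3 - 1 = 2, chased to 1    (via reuse_kv_layer_idx_dict)
    #   b_idx=4 -> 4 - 1 = 3, chased to 1
    #   b_idx=6 -> parent is block 5 (layer_a), whose config differs -> ValueError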
--- tests/models/test_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 91c6b55899..52a0c5bfef 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -2875,8 +2875,8 @@ def test_resolve_reuse_kv_layer_idx(reuse_kv_layer_idx: int): ] reuse_kv_layer_idx_dict = {} - def _validate_helper(b_idx: int): - MPTModel._resolve_reuse_kv_layer_idx( + def _validate_helper(b_idx: int) -> int: + return MPTModel._resolve_reuse_kv_layer_idx( overrides_definition=block_overrides['overrides'], model_modules_order_expanded=model_modules_order_expanded, b_idx=b_idx, @@ -2888,9 +2888,9 @@ def _validate_helper(b_idx: int): ) if reuse_kv_layer_idx == -1: - _validate_helper(b_idx=2) - _validate_helper(b_idx=3) - _validate_helper(b_idx=4) + assert _validate_helper(b_idx=2) == 1 + assert _validate_helper(b_idx=3) == 1 + assert _validate_helper(b_idx=4) == 1 with pytest.raises( expected_exception=ValueError, match= @@ -2899,7 +2899,7 @@ def _validate_helper(b_idx: int): _validate_helper(b_idx=6) elif reuse_kv_layer_idx == -2: - _validate_helper(b_idx=2) + assert _validate_helper(b_idx=2) == 0 else: with pytest.raises( expected_exception=ValueError, From d53a1e634da0689472b3e3fe2444e866b2ba74b3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 30 Jun 2024 13:54:26 -0700 Subject: [PATCH 68/69] addressing comment --- llmfoundry/models/layers/attention.py | 8 ++++++-- llmfoundry/models/layers/blocks.py | 12 +++++++++--- llmfoundry/models/mpt/modeling_mpt.py | 5 ++++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index bf645881af..a10a9dd7b0 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -525,7 +525,10 @@ def forward( torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: - query, key, value = self.get_qkv(x, prev_layer_key_value) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + query, key, value = self.get_qkv(x, **extra_kwargs) if rotary_emb_w_meta_info is not None: query, key, value = self._apply_rotary_embeddings( @@ -562,7 +565,8 @@ def forward( def get_qkv( self, x: torch.Tensor, - prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]], + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Computes and returns the query, key, and value tensors. 
From d53a1e634da0689472b3e3fe2444e866b2ba74b3 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Sun, 30 Jun 2024 13:54:26 -0700
Subject: [PATCH 68/69] addressing comment

---
 llmfoundry/models/layers/attention.py |  8 ++++++--
 llmfoundry/models/layers/blocks.py    | 12 +++++++++---
 llmfoundry/models/mpt/modeling_mpt.py |  5 ++++-
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index bf645881af..a10a9dd7b0 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -525,7 +525,10 @@ def forward(
                                              torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
         torch.Tensor, torch.Tensor]]]:
-        query, key, value = self.get_qkv(x, prev_layer_key_value)
+        extra_kwargs = {}
+        if prev_layer_key_value is not None:
+            extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        query, key, value = self.get_qkv(x, **extra_kwargs)

         if rotary_emb_w_meta_info is not None:
             query, key, value = self._apply_rotary_embeddings(
@@ -562,7 +565,8 @@ def forward(
     def get_qkv(
         self,
         x: torch.Tensor,
-        prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        prev_layer_key_value: Optional[Tuple[torch.Tensor,
+                                             torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Computes and returns the query, key, and value tensors.

diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 71bd3be5c3..401acacfc6 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -162,6 +162,9 @@ def forward(
                                              torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
         torch.Tensor, torch.Tensor]]]:
+        extra_kwargs = {}
+        if prev_layer_key_value is not None:
+            extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
         if self.fuse_norm_attn_norm:
             x, m, attn_weights, past_key_value = self.norm_attn_norm(
                 x,
@@ -173,7 +176,7 @@ def forward(
                 output_attentions=output_attentions,
                 alibi_slopes=alibi_slopes,
                 flash_attn_padding_info=flash_attn_padding_info,
-                prev_layer_key_value=prev_layer_key_value,
+                **extra_kwargs,
             )
         else:
             a = self.norm_1(x)
@@ -187,7 +190,7 @@ def forward(
                 needs_weights=output_attentions,
                 alibi_slopes=alibi_slopes,
                 flash_attn_padding_info=flash_attn_padding_info,
-                prev_layer_key_value=prev_layer_key_value,
+                **extra_kwargs,
             )
             x = x + self.resid_attn_dropout(b)
             m = x
@@ -317,6 +320,9 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
+        extra_kwargs = {}
+        if prev_layer_key_value is not None:
+            extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
         b, attn_weights, past_key_value = self.attn(
             a,
             past_key_value=past_key_value,
@@ -327,7 +333,7 @@ def forward(
             needs_weights=output_attentions,
             alibi_slopes=alibi_slopes,
             flash_attn_padding_info=flash_attn_padding_info,
-            prev_layer_key_value=prev_layer_key_value,
+            **extra_kwargs,
         )
         x = x + self.resid_attn_dropout(b)
         m = x
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index d1a4cb37be..3cdd8fea12 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -913,6 +913,9 @@ def forward(
             past_key_value = (
                 past_key_values[b_idx] if past_key_values is not None else None
             )
+            extra_kwargs = {}
+            if prev_layer_key_value is not None:
+                extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
             x, attn_weights, present = block(
                 x,
                 past_key_value=past_key_value,
@@ -923,7 +926,7 @@ def forward(
                 output_attentions=bool(output_attentions),
                 alibi_slopes=alibi_slopes,
                 flash_attn_padding_info=flash_attn_padding_info,
-                prev_layer_key_value=prev_layer_key_value,
+                **extra_kwargs,
             )
             if presents is not None:
                 presents += (present,)
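The recurring change in this patch is to build the keyword arguments conditionally instead of always passing `prev_layer_key_value=None`, so that attention modules and blocks whose `forward` signature does not accept that keyword keep working unchanged. A minimal sketch of the pattern in isolation; the `call_attention` wrapper and its `attn_forward` callable are placeholders for illustration, not the real modules.

    from typing import Callable, Optional, Tuple

    import torch


    def call_attention(
        attn_forward: Callable[..., torch.Tensor],
        x: torch.Tensor,
        prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Only mention the kwarg when there is actually a KV pair to reuse;
        # callees that never reuse a previous layer's cache are called exactly
        # as before.
        extra_kwargs = {}
        if prev_layer_key_value is not None:
            extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
        return attn_forward(x, **extra_kwargs)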
From 6554a22b14d47209a3fbe8665eaffafbe2be3d18 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Sun, 30 Jun 2024 14:50:14 -0700
Subject: [PATCH 69/69] fixing test

---
 tests/models/test_model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 52a0c5bfef..8d47deddaa 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -2758,7 +2758,8 @@ def test_reuse_prev_layer_kv_cache(
     prev_layer_key_value_dict = {}

     def mock_forward(b_forward, b_idx, *args, **kwargs):  # type: ignore
-        prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value']
+        if 'prev_layer_key_value' in kwargs:
+            prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value']
         return b_forward(*args, **kwargs)

     for b_idx, block in enumerate(model.model.transformer.blocks):
@@ -2773,7 +2774,7 @@ def mock_forward(b_forward, b_idx, *args, **kwargs):  # type: ignore
     assert torch.all(
         outputs.past_key_values[0][1] == outputs.past_key_values[1][1],
     )
-    assert prev_layer_key_value_dict[0] is None
+    assert 0 not in prev_layer_key_value_dict
     assert torch.all(
         prev_layer_key_value_dict[1][0] == outputs.past_key_values[0][0],
     )
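Because the previous patch only forwards `prev_layer_key_value` when it is set, block 0 never receives the keyword at all; the test therefore records the value only when present and checks dictionary membership instead of `is None`. A rough sketch of how such a recording wrapper can be attached to each block follows; the attachment with `functools.partial` is an assumption about the test setup and is not visible in the diff.

    # prev_layer_key_value_dict collects, per block index, the KV pair the
    # block was asked to reuse; blocks that compute their own KV never appear.
    prev_layer_key_value_dict = {}


    def recording_forward(original_forward, b_idx, *args, **kwargs):
        if 'prev_layer_key_value' in kwargs:
            prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value']
        return original_forward(*args, **kwargs)


    # Hypothetical attachment, mirroring the loop visible in the hunk above:
    # for b_idx, block in enumerate(model.model.transformer.blocks):
    #     block.forward = functools.partial(recording_forward, block.forward, b_idx)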