diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9b34190edf..a10a9dd7b0 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -416,6 +416,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__() @@ -428,6 +429,7 @@ def __init__( self.n_heads = n_heads self.kv_n_heads = kv_n_heads self.sliding_window_size = sliding_window_size + self.reuse_kv_layer_idx = reuse_kv_layer_idx self.head_dim = d_model // n_heads @@ -458,18 +460,29 @@ def __init__( self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop - self.Wqkv = build_fc( - name=fc_type_name, - in_features=self.d_model, - out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, - fc_kwargs=fc_type, - ) - # for param init fn; enables shape based init of fused layers - fuse_splits = [ - i * self.head_dim - for i in range(1, self.n_heads + 2 * self.kv_n_heads) - ] - self.Wqkv._fused = (0, fuse_splits) + if self.reuse_kv_layer_idx is None: + self.Wqkv = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [ + i * self.head_dim + for i in range(1, self.n_heads + 2 * self.kv_n_heads) + ] + self.Wqkv._fused = (0, fuse_splits) + else: + self.Wq = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + self.Wq._fused = (0, fuse_splits) if self.qk_ln or self.qk_gn: norm_size = self.head_dim if qk_gn else d_model @@ -478,13 +491,14 @@ def __init__( normalized_shape=norm_size, device=device, ) - if qk_ln: - norm_size = self.head_dim * kv_n_heads - self.k_ln = build_norm( - name=norm_type.lower(), - normalized_shape=norm_size, - device=device, - ) + if self.reuse_kv_layer_idx is None: + if qk_ln: + norm_size = self.head_dim * kv_n_heads + self.k_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + device=device, + ) self.attn_fn = attention_implementations.get(self.attn_impl) @@ -507,9 +521,14 @@ def forward( needs_weights: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: - query, key, value = self.get_qkv(x) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + query, key, value = self.get_qkv(x, **extra_kwargs) if rotary_emb_w_meta_info is not None: query, key, value = self._apply_rotary_embeddings( @@ -546,6 +565,8 @@ def forward( def get_qkv( self, x: torch.Tensor, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Computes and returns the query, key, and value tensors. @@ -557,6 +578,27 @@ def get_qkv( key (torch.Tensor): The key tensor. value (torch.Tensor): The value tensor. 
""" + if self.reuse_kv_layer_idx is not None: + if prev_layer_key_value is None: + raise ValueError( + 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', + ) + key, value = prev_layer_key_value + + query = self.Wq(x) + if self.clip_qkv: + query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + if self.qk_ln or self.qk_gn: + # Applying layernorm to qk + q_shape = query.shape + if self.qk_gn: + b, s = query.shape[:2] + query = query.view(b, s, self.n_heads, -1) + dtype = query.dtype + query = self.q_ln(query).to(dtype).view(q_shape) + return query, key, value + qkv = self.Wqkv(x) if self.clip_qkv: @@ -591,6 +633,10 @@ def _apply_rotary_embeddings( key: torch.Tensor, value: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.reuse_kv_layer_idx is not None: + orig_key, orig_value = key, value + key, value = torch.empty_like(key), torch.empty_like(value) + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] offset_info = rotary_emb_w_meta_info['offset_info'] @@ -602,6 +648,7 @@ def _apply_rotary_embeddings( value = value.view(bsz, seqlen, -1, self.head_dim) kv = torch.stack([key, value], dim=2) + # Note: Rotates in place (https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/flash_attn/layers/rotary.py#L429) query, kv = rotary_emb( query, kv, @@ -652,6 +699,8 @@ def _apply_rotary_embeddings( query = query.view(bsz, seqlen, -1) key = key.view(bsz, seqlen, -1) + if self.reuse_kv_layer_idx is not None: + return query, orig_key, orig_value # type: ignore return query, key, value def get_implementation_specific_args( @@ -705,6 +754,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__( d_model=d_model, @@ -721,6 +771,7 @@ def __init__( device=device, bias=bias, sliding_window_size=sliding_window_size, + reuse_kv_layer_idx=reuse_kv_layer_idx, ) @@ -746,6 +797,7 @@ def __init__( device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__( d_model=d_model, @@ -762,6 +814,7 @@ def __init__( device=device, bias=bias, sliding_window_size=sliding_window_size, + reuse_kv_layer_idx=reuse_kv_layer_idx, ) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 59aa497b78..401acacfc6 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -158,8 +158,13 @@ def forward( output_attentions: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value if self.fuse_norm_attn_norm: x, m, attn_weights, past_key_value = self.norm_attn_norm( x, @@ -171,6 +176,7 @@ def forward( output_attentions=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) else: a = self.norm_1(x) @@ -184,6 +190,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) x = x + self.resid_attn_dropout(b) m = x @@ -308,9 +315,14 @@ def forward( output_attentions: bool 
= False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value b, attn_weights, past_key_value = self.attn( a, past_key_value=past_key_value, @@ -321,6 +333,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 9205c0e505..a1fdc25f50 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,6 +20,7 @@ ffn_config_defaults, init_config_defaults, ) +from llmfoundry.utils.warnings import ExperimentalWarning class MPTConfig(PretrainedConfig): @@ -48,6 +49,7 @@ def __init__( fc_type: Union[str, Dict] = 'torch', tie_word_embeddings: bool = True, use_pad_tok_in_ffn: bool = True, + block_overrides: Optional[Dict[str, Any]] = None, **kwargs: Any, ): """The MPT configuration class. @@ -117,6 +119,30 @@ def __init__( also be a dictionary that specifies the fc layer name and any kwargs for the fc layer. tie_word_embeddings (bool): Whether to tie the input embedding and output layers. use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. + block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and `order`. `order` is a nested list which describes the order of the layers. For each kind of layer, specify the `overrides` in the overrides config (default refers to a layer that does not apply any overrides). 
+ To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed: + block_overrides: + order: + - name: default + - repeat: 2 + order: + - name: sliding_window_layer + - name: sliding_window_layer_reuse + - name: sliding_window_layer + - repeat: 2 + name: sliding_window_layer_reuse + - name: reuse_kv_layer + overrides: + sliding_window_layer: + attn_config: + sliding_window_size: 1024 + sliding_window_layer_reuse: + attn_config: + sliding_window_size: 1024 + reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse + reuse_kv_layer: + attn_config: + reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse """ self.d_model = d_model self.n_heads = n_heads @@ -145,6 +171,15 @@ def __init__( init_config_defaults, ) + if 'reuse_kv_layer_idx' in self.attn_config and self.attn_config[ + 'attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) + if block_overrides is not None: + self._validate_block_overrides(block_overrides) + self.block_overrides = block_overrides + if isinstance(fc_type, str): fc_type = {'name': fc_type} self.fc_type = fc_type @@ -169,6 +204,23 @@ def __init__( self._validate_config() + def _validate_block_overrides(self, block_overrides: Dict[str, Any]): + warnings.warn(ExperimentalWarning('block_overrides')) + if 'order' not in block_overrides: + raise ValueError('`order` should be defined in block_overrides',) + if 'overrides' not in block_overrides: + raise ValueError( + '`overrides` should be defined in block_overrides', + ) + for name, override in block_overrides['overrides'].items(): + if name == 'default': + raise ValueError('block overrides cannot be named "default".',) + if 'attn_config' in override and 'reuse_kv_layer_idx' in override[ + 'attn_config'] and self.attn_config['attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) + def _set_config_defaults( self, config: Dict[str, Any], @@ -335,3 +387,12 @@ def _validate_config(self) -> None: ) self.validate_attention_config() + + @property + def allowed_block_overrides(self): + return { + 'attn_config': { + 'sliding_window_size': None, + 'reuse_kv_layer_idx': None, + }, + } diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c8884a03a1..3cdd8fea12 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy import math import warnings from functools import cached_property @@ -28,6 +29,7 @@ import torch.nn.functional as F from composer.models import HuggingFaceModel from composer.utils import dist +from tabulate import tabulate from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed @@ -440,14 +442,181 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. 
""" block_args = self.extract_block_args(config.to_dict()) + self.kv_cache_layers = set() + self.blocks_fuse_norm_attn_norm = block_args.get( + 'fuse_norm_attn_norm', + False, + ) + + if config.block_overrides is not None: + block_args_list = self._get_override_block_args_list( + config, + block_args, + ) + else: + block_args_list = [block_args for _ in range(config.n_layers)] return nn.ModuleList([ self.block_class( device=config.init_device, - **block_args, - ) for _ in range(config.n_layers) + **block_args_i, + ) for block_args_i in block_args_list ]) + def _get_override_block_args_list( + self, + config: MPTConfig, + block_args: Dict[str, Any], + ) -> List[Dict[str, Any]]: + if config.block_overrides is None: + raise ValueError( + 'config.block_overrides should not be None when calling _get_override_block_args_list.', + ) + repeat = config.block_overrides.get('repeat', 1) + model_modules_order_expanded = MPTModel._get_modules_order_expanded( + config.block_overrides['order'], + ) * repeat + if len(model_modules_order_expanded) != config.n_layers: + raise ValueError( + f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', + ) + + new_block_args_list = [] + layer_description_list = [] + + reuse_kv_layer_idx_dict = {} + for b_idx in range(config.n_layers): + module_name = model_modules_order_expanded[b_idx] + override_config = {} + if module_name != 'default': + override_config = copy.deepcopy( + config.block_overrides['overrides'][module_name], + ) + if 'reuse_kv_layer_idx' in override_config.get( + 'attn_config', + {}, + ): + reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_idx( + overrides_definition=config. + block_overrides['overrides'], + model_modules_order_expanded= + model_modules_order_expanded, + b_idx=b_idx, + override_config=override_config, + reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, + ) + override_config['attn_config']['reuse_kv_layer_idx' + ] = reuse_kv_layer_idx + self.kv_cache_layers.add(reuse_kv_layer_idx) + layer_description_list.append([ + b_idx, + module_name, + override_config, + ],) + new_block_args_list.append( + MPTModel._override_block_args( + block_args, + override_config, + config.allowed_block_overrides, + ), + ) + log.info( + 'The following is a summary of overrides per layer.\n' + tabulate( + layer_description_list, + headers=['idx', 'name', 'overrides'], + ), + ) + return new_block_args_list + + @staticmethod + def _resolve_reuse_kv_layer_idx( + overrides_definition: Dict[str, Any], + model_modules_order_expanded: List[str], + b_idx: int, + override_config: Dict[str, Any], + reuse_kv_layer_idx_dict: Dict[int, int], + ) -> int: + override_attn_config = override_config['attn_config'] + if override_attn_config['reuse_kv_layer_idx'] >= 0: + raise ValueError( + f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', + ) + reuse_kv_layer_idx = b_idx + override_attn_config['reuse_kv_layer_idx'] + if reuse_kv_layer_idx < 0: + raise ValueError( + f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', + ) + if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: + reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx] + reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx + + parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] + parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( + overrides_definition[parent_layer_name], + ) + if 'attn_config' not in 
parent_config: + parent_config['attn_config'] = {} + parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[ + 'attn_config']['reuse_kv_layer_idx'] + + if override_config != parent_config and not ( + 'allow_mismatch' in override_config and + override_config['allow_mismatch'] + ): + raise ValueError( + 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', + ) + + return reuse_kv_layer_idx + + @staticmethod + def _get_modules_order_expanded(order: List[Dict[str, Any]]) -> List[str]: + model_modules_order_expanded = [] + for item in order: + repeat = item['repeat'] if 'repeat' in item else 1 + if ('name' in item) == ('order' in item): + raise ValueError( + 'Exactly one of `order` or `name` must be specified for each block override.', + ) + + if 'name' in item: + model_modules_order_expanded.extend([item['name']] * repeat) + else: + model_modules_order_expanded.extend( + MPTModel._get_modules_order_expanded(item['order']) * + repeat, + ) + + return model_modules_order_expanded + + @staticmethod + def _override_block_args( + block_args: Dict[str, Any], + override_config: Dict[str, Any], + allowed_block_overrides: Dict[str, Any], + ) -> Dict[str, Any]: + unpermitted_keys = override_config.keys( + ) - allowed_block_overrides.keys() + if len(unpermitted_keys): + raise KeyError(f'Overriding {unpermitted_keys} is not supported.') + + new_block_args = override_config | block_args + common_keys = override_config.keys() & block_args.keys() + for k in common_keys: + if type(override_config[k]) != type(block_args[k]): + raise ValueError( + f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', + ) + if isinstance(override_config[k], dict): + new_block_args[k] = MPTModel._override_block_args( + block_args[k], + override_config[k], + allowed_block_overrides[k], + ) + else: + new_block_args[k] = override_config[k] + return new_block_args + def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: """Sets the block args.""" if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: @@ -707,7 +876,9 @@ def forward( # initialize the past key values cache if it should be used presents = () if use_cache else None - if use_cache and past_key_values is None: + if ( + use_cache or len(self.kv_cache_layers) > 0 + ) and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) ] # type: ignore @@ -724,13 +895,27 @@ def forward( attention_mask, ) + layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): + attn_block = block.norm_attn_norm.attn if self.blocks_fuse_norm_attn_norm else block.attn + if attn_block.reuse_kv_layer_idx is not None: + if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict: + raise KeyError( + f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.', + ) + prev_layer_key_value = layer_kv_cache_dict[ + attn_block.reuse_kv_layer_idx] + else: + prev_layer_key_value = None if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) past_key_value = ( past_key_values[b_idx] if past_key_values is not None else None ) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value x, attn_weights, present = block( x, past_key_value=past_key_value, @@ -741,9 +926,15 @@ def forward( 
output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) if presents is not None: presents += (present,) + if b_idx in self.kv_cache_layers: + layer_kv_cache_dict[b_idx] = [ + present[0][:, past_position:], + present[1][:, past_position:], + ] if output_attentions: assert all_self_attns is not None # pyright @@ -1233,6 +1424,12 @@ def flops_per_batch(self, batch: Mapping): # that the dataset has been constructed without padding. Additionally, we # assume the backward pass is approximately 2x the forward pass + if self.model.config.block_overrides is not None: + warnings.warn( + 'Warning, flop computation is not supported when using block overrides. Returning 0 flops per batch.', + ) + return 0 + bs, msl = batch['input_ids'].shape[0:2] params = self.n_active_params params_flops_per_token = 2 * params diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 669a6a93a1..01d982052f 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -537,3 +537,212 @@ def test_grouped_query_invalid_heads(): with pytest.raises(ValueError, match=expected_error): _ = attention.GroupedQueryAttention(**cfg) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }], +) +def test_reuse_prev_layer_kv_cache( + pos_emb_config: dict, + device: str = 'cuda', +): + """Checks reusing previous layer's kv cache.""" + alibi = pos_emb_config['alibi'] + rope = pos_emb_config['rope'] + + cfg = { + 'attn_impl': 'flash', + 'd_model': 64, + 'n_heads': 4, + 'attn_pdrop': 0, + 'clip_qkv': True, + } + + n, s, f = 2, 4, cfg['d_model'] + assert cfg['d_model'] % cfg['n_heads'] == 0 + cfg['kv_n_heads'] = 2 + + sequence_id = torch.LongTensor([ + [0] * 2 + [1] * (s - 2), + [0] * 4 + [1] * (s - 4), + ]).to(device=device) + + # Computes its own kv cache + cfg['reuse_kv_layer_idx'] = None + attn0 = build_attention_layer( + name='grouped_query_attention', + attn_kwargs=cfg, # type: ignore + ).to(device) + + # Reuses layer 0's kv cache + cfg['reuse_kv_layer_idx'] = 0 + attn1 = build_attention_layer( + name='grouped_query_attention', + attn_kwargs=cfg, # type: ignore + ).to(device) + attn0_sd = attn0.state_dict() + attn0_sd['Wq.weight'] = attn0_sd['Wqkv.weight'][:cfg['d_model']] + attn0_sd['Wq.bias'] = attn0_sd['Wqkv.bias'][:cfg['d_model']] + del attn0_sd['Wqkv.weight'] + del attn0_sd['Wqkv.bias'] + attn1.load_state_dict(attn0_sd) + + attention_mask = torch.ones(n, s).to(device).bool() + + def gen_bias(attn_impl: str): + causal = True + attn_bias = None + bs = attention.attn_bias_shape( + attn_impl, + cfg['n_heads'], + s, + alibi, + use_sequence_id=True, + causal=causal, + ) + if bs is not None: + attn_bias = torch.zeros(*bs, device=device) + attn_bias = attention.build_attn_bias( + attn_impl, + attn_bias, + cfg['n_heads'], + s, + causal=causal, + alibi=alibi, + alibi_bias_max=8, + ) + + return attn_bias + + attention_mask_in_length = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=True, + attn_impl='flash', + attention_mask=attention_mask, + ) + + flash_attn_padding_info = gen_flash_attn_padding_info( + n, + s, + 0, + torch.device(device), + attention_mask_in_length, + attention_mask, + ) + + 
x0 = torch.randn(n, s, f).to(device) + x1 = x0.clone().detach() + x0.requires_grad = True + x1.requires_grad = True + + with torch.autocast(x0.device.type): + attn_bias_0 = gen_bias('flash') + alibi_slopes_0 = None + if alibi: + alibi_slopes_0 = gen_slopes( + n_heads=cfg['n_heads'], + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + rotary_emb_w_meta_info = None + if rope: + rotary_embedding = gen_rotary_embedding( + rope_head_dim=cfg['d_model'] // cfg['n_heads'], + rope_impl=pos_emb_config['rope_impl'], + rope_theta=pos_emb_config['rope_theta'], + rope_dail_config=pos_emb_config.get('rope_dail_config', {}), + rope_hf_config=pos_emb_config.get('rope_hf_config', {}), + max_seq_len=s, + ).to(device) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + rotary_emb_w_meta_info = { + 'impl': + pos_emb_config['rope_impl'], + 'rotary_emb': + rotary_embedding, + 'offset_info': + pos if (pos_emb_config['rope_impl'] == 'hf') else 0, + 'seq_len': + s, + } + + y0, _, prev_layer_key_value = attn0( + x0, + past_key_value=(), + attn_bias=attn_bias_0, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + alibi_slopes=alibi_slopes_0, + ) + attn_bias_1 = gen_bias('flash') + alibi_slopes_1 = None + if alibi: + alibi_slopes_1 = gen_slopes( + n_heads=cfg['n_heads'], + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + + prev_layer_key_value = [ + t.clone().detach() for t in prev_layer_key_value + ] + y1, _, _ = attn1( + x1, + past_key_value=None, + attn_bias=attn_bias_1, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + alibi_slopes=alibi_slopes_1, + prev_layer_key_value=prev_layer_key_value, + ) + y0 *= attention_mask.unsqueeze(-1) + y1 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + loss1 = y1.sum() + + loss0.backward() + loss1.backward() + assert allclose_helper(y0, y1) + + torch_name_param_map = dict(attn1.named_parameters()) + for n, p in attn0.named_parameters(): + if 'Wq' in n: + tp = torch_name_param_map[n.replace('Wqkv', 'Wq')] + assert p.grad is not None + assert tp.grad is not None + assert allclose_helper(p[:cfg['d_model']], tp) + assert allclose_helper(p.grad[:cfg['d_model']], tp.grad) + else: + tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None + assert allclose_helper(p, tp) + assert allclose_helper(p.grad, tp.grad) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 2f93b1d3ce..8d47deddaa 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -5,6 +5,7 @@ import os import pathlib import warnings +from functools import partial from typing import Any, Dict, List, Optional, Union, cast from unittest import mock @@ -44,7 +45,7 @@ is_flash_v2_installed, ) from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM +from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -72,12 +73,17 @@ def _load_tokenizer_cfg(cfg: Union[Dict[str, Any], DictConfig]) -> Dict: def _get_objs( request: 
pytest.FixtureRequest, conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', + model_config_overrides: Optional[Dict] = None, + attn_impl: str = 'torch', ): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property', ) test_cfg = get_config(conf_path=conf_path) + if model_config_overrides is not None: + for k, v in model_config_overrides.items(): + test_cfg.model[k] = v # Read FSDP Config as a dict fsdp_config = test_cfg.get('fsdp_config', None) @@ -97,7 +103,7 @@ def _get_objs( device = 'cuda' if is_gpu else 'cpu' test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32' test_cfg.model.attn_config = { - 'attn_impl': 'torch', + 'attn_impl': attn_impl, } test_cfg.model.init_device = device test_cfg.device = device @@ -2617,3 +2623,288 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): output = model(batch) assert not torch.isnan(output.logits).any() + + +def test_construct_blocks(): + n_layers = 13 + + config = MPTConfig( + d_model=32, + n_heads=16, + n_layers=n_layers, + expansion_ratio=2, + max_seq_len=64, + attn_config={ + 'attn_impl': 'flash', + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 4, + }, + ) + + # override architecture taken from https://research.character.ai/optimizing-inference/ + config.block_overrides = {} + config.block_overrides['overrides'] = { + 'reuse_kv_layer': { + 'attn_config': { + 'reuse_kv_layer_idx': -6, + }, + }, + 'sliding_window_layer': { + 'attn_config': { + 'sliding_window_size': 1024, + }, + }, + 'sliding_window_layer_reuse': { + 'attn_config': { + 'sliding_window_size': 1024, + 'reuse_kv_layer_idx': -1, + }, + }, + } + config.block_overrides['order'] = [ + { + 'name': 'default', + }, + { + 'order': [ + { + 'name': 'sliding_window_layer', + }, + { + 'name': 'sliding_window_layer_reuse', + }, + { + 'name': 'sliding_window_layer', + }, + { + 'name': 'sliding_window_layer_reuse', + 'repeat': 2, + }, + { + 'name': 'reuse_kv_layer', + }, + ], + 'repeat': 2, + }, + ] + + block_list = MPTModel(config).construct_blocks(config) + + assert len(block_list) == n_layers + assert block_list[0].attn.sliding_window_size == -1 + assert block_list[0].attn.reuse_kv_layer_idx is None + + for layer_offset in [1, 7]: + assert block_list[layer_offset].attn.sliding_window_size == 1024 + assert block_list[layer_offset].attn.reuse_kv_layer_idx is None + assert block_list[layer_offset + 1].attn.sliding_window_size == 1024 + assert block_list[layer_offset + + 1].attn.reuse_kv_layer_idx == layer_offset + + assert block_list[layer_offset + 2].attn.sliding_window_size == 1024 + assert block_list[layer_offset + 2].attn.reuse_kv_layer_idx is None + assert block_list[layer_offset + 3].attn.sliding_window_size == 1024 + assert block_list[layer_offset + + 3].attn.reuse_kv_layer_idx == layer_offset + 2 + assert block_list[layer_offset + 4].attn.sliding_window_size == 1024 + assert block_list[layer_offset + + 4].attn.reuse_kv_layer_idx == layer_offset + 2 + + assert block_list[layer_offset + 5].attn.sliding_window_size == -1 + assert block_list[layer_offset + 5].attn.reuse_kv_layer_idx == 0 + + +@pytest.mark.gpu +def test_reuse_prev_layer_kv_cache( + request: pytest.FixtureRequest, + batch_size: int = 2, +): + conf_path = 'scripts/train/yamls/pretrain/testing.yaml' + model_config_overrides = { + 'block_overrides': { + 'order': [ + { + 'name': 'default', + }, + { + 'name': 'kv_reuse_layer', + }, + ], + 'overrides': { + 'kv_reuse_layer': { + 'attn_config': { + 'reuse_kv_layer_idx': -1, + }, + }, + }, + }, + 'use_cache': 
True, + } + test_cfg, model, _ = _get_objs( + request=request, + conf_path=conf_path, + model_config_overrides=model_config_overrides, + attn_impl='flash', + ) + + batch = gen_random_batch(batch_size, test_cfg) + + assert batch['input_ids'].shape == torch.Size([ + batch_size, + test_cfg.max_seq_len, + ]) + model.train() + + prev_layer_key_value_dict = {} + + def mock_forward(b_forward, b_idx, *args, **kwargs): # type: ignore + if 'prev_layer_key_value' in kwargs: + prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value'] + return b_forward(*args, **kwargs) + + for b_idx, block in enumerate(model.model.transformer.blocks): + block.forward = partial(mock_forward, block.forward, b_idx) + + with get_precision_context(test_cfg.precision): + outputs = model(batch) + assert len(outputs.past_key_values) == 2 + assert torch.all( + outputs.past_key_values[0][0] == outputs.past_key_values[1][0], + ) + assert torch.all( + outputs.past_key_values[0][1] == outputs.past_key_values[1][1], + ) + assert 0 not in prev_layer_key_value_dict + assert torch.all( + prev_layer_key_value_dict[1][0] == outputs.past_key_values[0][0], + ) + assert torch.all( + prev_layer_key_value_dict[1][1] == outputs.past_key_values[0][1], + ) + + +def test_override_block_args(): + block_args = {'a': 1, 'b': {'c': 3}, 'd': 4} + override_config = {'a': 2, 'b': {'c': 5}, 'e': 6} + allowed_block_overrides = {'a': None, 'b': {'c': None}, 'e': None} + new_config = MPTModel._override_block_args( + block_args, + override_config, + allowed_block_overrides, + ) + assert new_config['a'] == 2 + assert new_config['d'] == 4 + assert new_config['e'] == 6 + assert new_config['b']['c'] == 5 + + +def test_get_modules_order_expanded(): + order = [ + { + 'name': 'default', + }, + { + 'name': 'layer_a', + 'repeat': 2, + }, + { + 'order': [{ + 'name': 'layer_b', + },], + 'repeat': 3, + }, + { + 'name': 'layer_c', + 'repeat': 2, + }, + { + 'name': 'default', + }, + ] + expected_list = [ + 'default', + 'layer_a', + 'layer_a', + 'layer_b', + 'layer_b', + 'layer_b', + 'layer_c', + 'layer_c', + 'default', + ] + assert expected_list == MPTModel._get_modules_order_expanded(order) + + +@pytest.mark.parametrize('reuse_kv_layer_idx', [-2, -1, 0]) +def test_resolve_reuse_kv_layer_idx(reuse_kv_layer_idx: int): + layer_a_override = { + 'key_1': 'value_a', + 'attn_config': { + 'key_2': 'value_b', + }, + } + layer_b_override = { + 'key_1': 'value_c', + 'attn_config': { + 'key_2': 'value_d', + }, + } + layer_c_override = { + 'key_1': 'value_c' if reuse_kv_layer_idx == -1 else 'value_a', + 'attn_config': { + 'key_2': 'value_d' if reuse_kv_layer_idx == -1 else 'value_b', + 'reuse_kv_layer_idx': reuse_kv_layer_idx, + }, + } + block_overrides = { + 'overrides': { + 'layer_a': layer_a_override, + 'layer_b': layer_b_override, + 'layer_c': layer_c_override, + }, + } + model_modules_order_expanded = ['layer_a', 'layer_b', 'layer_c'] + if reuse_kv_layer_idx == -1: + model_modules_order_expanded = [ + 'layer_a', + 'layer_b', + 'layer_c', + 'layer_c', + 'layer_c', + 'layer_a', + 'layer_c', + ] + reuse_kv_layer_idx_dict = {} + + def _validate_helper(b_idx: int) -> int: + return MPTModel._resolve_reuse_kv_layer_idx( + overrides_definition=block_overrides['overrides'], + model_modules_order_expanded=model_modules_order_expanded, + b_idx=b_idx, + override_config=copy.deepcopy( + block_overrides['overrides'][model_modules_order_expanded[b_idx] + ], + ), + reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, + ) + + if reuse_kv_layer_idx == -1: + assert _validate_helper(b_idx=2) 
== 1 + assert _validate_helper(b_idx=3) == 1 + assert _validate_helper(b_idx=4) == 1 + with pytest.raises( + expected_exception=ValueError, + match= + 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer\.', # type: ignore + ): + _validate_helper(b_idx=6) + + elif reuse_kv_layer_idx == -2: + assert _validate_helper(b_idx=2) == 0 + else: + with pytest.raises( + expected_exception=ValueError, + match= + 'The relative index of kv layer to reuse, override_attn_config\[\"reuse_kv_layer_idx\"\]=0, should be negative\.', # type: ignore + ): + _validate_helper(b_idx=2)
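
The core mechanic introduced in attention.py, a layer that builds only a query projection and borrows its key/value tensors from an earlier layer, can be illustrated with a minimal, self-contained sketch. `SimpleReuseKVAttention` is a hypothetical name, not the llmfoundry class, and grouped-query heads, qk-norm, and qkv clipping are omitted:

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleReuseKVAttention(nn.Module):
    """Self-attention that can borrow K/V from an earlier layer."""

    def __init__(self, d_model: int, n_heads: int, reuse_kv_layer_idx: Optional[int] = None):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.reuse_kv_layer_idx = reuse_kv_layer_idx
        if reuse_kv_layer_idx is None:
            # Standard layer: fused Q/K/V projection.
            self.Wqkv = nn.Linear(d_model, 3 * d_model)
        else:
            # KV-reusing layer: only a query projection is built.
            self.Wq = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        b, s, d = x.shape
        if self.reuse_kv_layer_idx is None:
            query, key, value = self.Wqkv(x).chunk(3, dim=-1)
        else:
            if prev_layer_key_value is None:
                raise ValueError('prev_layer_key_value is required when reusing KV.')
            key, value = prev_layer_key_value
            query = self.Wq(x)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(
            split_heads(query), split_heads(key), split_heads(value), is_causal=True,
        )
        out = out.transpose(1, 2).reshape(b, s, d)
        # Return K/V so a later layer can reuse them.
        return self.out_proj(out), (key, value)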
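
The change to `_apply_rotary_embeddings` encodes a subtlety: a borrowed key was already rotated by the layer that produced it, and the flash-attention rotary kernel rotates its stacked kv input in place, so the diff rotates scratch copies and returns the original key/value untouched. Below is a hedged, framework-free sketch of the same rule; `apply_rope` is a toy interleaved RoPE written for illustration, not the flash-attn kernel:

import torch


def apply_rope(t: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embeddings to a (batch, seq, heads, head_dim) tensor."""
    b, s, h, d = t.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = torch.arange(s, dtype=torch.float32)[:, None] * inv_freq[None, :]
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    t1, t2 = t[..., 0::2], t[..., 1::2]
    rotated = torch.stack([t1 * cos - t2 * sin, t1 * sin + t2 * cos], dim=-1)
    return rotated.flatten(-2)


def rotate_for_reuse(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    reuse_kv: bool,
) -> tuple:
    # When reusing a previous layer's KV, rotate only the query; the cached key
    # was already rotated when it was first computed.
    query = apply_rope(query)
    if not reuse_kv:
        key = apply_rope(key)
    return query, key, value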
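
On the configuration side, `block_overrides['order']` is a nested, repeatable spec that must expand to exactly `n_layers` entries, and a `reuse_kv_layer_idx` override is specified relative to the current layer (so it must be negative) and resolved to a non-negative absolute index, chaining through layers that themselves reuse KV. The following standalone sketch mirrors that expansion and resolution logic; the config-equality validation the diff also performs is omitted:

from typing import Any, Dict, List


def expand_order(order: List[Dict[str, Any]]) -> List[str]:
    """Expand a nested `order` spec into a flat, per-layer list of override names."""
    expanded: List[str] = []
    for item in order:
        repeat = item.get('repeat', 1)
        if ('name' in item) == ('order' in item):
            raise ValueError('Exactly one of `name` or `order` must be given per item.')
        if 'name' in item:
            expanded.extend([item['name']] * repeat)
        else:
            expanded.extend(expand_order(item['order']) * repeat)
    return expanded


def resolve_reuse_idx(layer_idx: int, relative_idx: int, already_resolved: Dict[int, int]) -> int:
    """Turn a negative, layer-relative reuse index into an absolute layer index."""
    if relative_idx >= 0:
        raise ValueError('reuse_kv_layer_idx must be negative (relative to the current layer).')
    target = layer_idx + relative_idx
    if target < 0:
        raise ValueError('Resolved reuse_kv_layer_idx must be non-negative.')
    # If the target layer itself reuses another layer's KV cache, chain to the root layer.
    target = already_resolved.get(target, target)
    already_resolved[layer_idx] = target
    return target


# The docstring example expands to 13 layers: 1 default + 2 repeats of a 6-layer pattern.
order = [
    {'name': 'default'},
    {
        'order': [
            {'name': 'sliding_window_layer'},
            {'name': 'sliding_window_layer_reuse'},
            {'name': 'sliding_window_layer'},
            {'name': 'sliding_window_layer_reuse', 'repeat': 2},
            {'name': 'reuse_kv_layer'},
        ],
        'repeat': 2,
    },
]
print(len(expand_order(order)))  # 13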
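
In modeling_mpt.py, the forward pass stashes the freshly computed key/value of every layer whose index is in `kv_cache_layers` and hands it to downstream layers as `prev_layer_key_value`. A minimal sketch of that plumbing, reusing the hypothetical `SimpleReuseKVAttention` from the first sketch (residual connections, norms, and FFNs omitted):

import torch
import torch.nn as nn

# Assumes SimpleReuseKVAttention from the earlier sketch is in scope.


class SimpleReuseKVStack(nn.Module):
    """A stack of attention layers where some layers reuse an earlier layer's KV."""

    def __init__(self, d_model: int, n_heads: int, n_layers: int, reuse_map: dict):
        # reuse_map: layer index -> absolute index of the layer whose KV it reuses.
        super().__init__()
        self.reuse_map = reuse_map
        self.kv_cache_layers = set(reuse_map.values())
        self.blocks = nn.ModuleList([
            SimpleReuseKVAttention(d_model, n_heads, reuse_kv_layer_idx=reuse_map.get(i))
            for i in range(n_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        layer_kv = {}  # K/V stashed only for layers that a later layer will reuse
        for idx, block in enumerate(self.blocks):
            prev_kv = layer_kv.get(self.reuse_map[idx]) if idx in self.reuse_map else None
            x, present = block(x, prev_layer_key_value=prev_kv)
            if idx in self.kv_cache_layers:
                layer_kv[idx] = present
        return x


stack = SimpleReuseKVStack(d_model=64, n_heads=4, n_layers=4, reuse_map={1: 0, 3: 2})
print(stack(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])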
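
Finally, the new attention test maps a standard layer's fused `Wqkv` parameters onto the KV-reusing layer's `Wq` by slicing the first `d_model` rows, since the query rows come first in the fused projection. A small illustrative sketch of that conversion; the tensor sizes are assumptions for a plain multi-head layer:

import torch

d_model = 64
fused_sd = {
    'Wqkv.weight': torch.randn(3 * d_model, d_model),  # rows ordered [Q | K | V]
    'Wqkv.bias': torch.randn(3 * d_model),
}

# Keep only the query slice for a layer built with reuse_kv_layer_idx set;
# the first d_model rows are the query projection regardless of kv_n_heads.
reuse_sd = {k: v for k, v in fused_sd.items() if not k.startswith('Wqkv')}
reuse_sd['Wq.weight'] = fused_sd['Wqkv.weight'][:d_model]
reuse_sd['Wq.bias'] = fused_sd['Wqkv.bias'][:d_model]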