diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 2ba75a95ed..3458e99432 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -631,7 +631,7 @@ def _apply_rotary_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.reuse_kv_layer_idx is not None: orig_key, orig_value = key, value - key, value = torch.zeros_like(key), torch.zeros_like(value) + key, value = torch.empty_like(key), torch.empty_like(value) rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] @@ -696,7 +696,7 @@ def _apply_rotary_embeddings( query = query.view(bsz, seqlen, -1) key = key.view(bsz, seqlen, -1) if self.reuse_kv_layer_idx is not None: - return query, orig_key, orig_value + return query, orig_key, orig_value # type: ignore return query, key, value def get_implementation_specific_args( diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 55886c9aa8..bfef0700c3 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -441,6 +441,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. """ block_args = self.extract_block_args(config.to_dict()) + self.kv_cache_layers = None if config.block_overrides is not None: return self._construct_blocks_with_overrides(config, block_args) @@ -457,6 +458,10 @@ def _construct_blocks_with_overrides( config: MPTConfig, block_args: Dict[str, Any], ) -> nn.ModuleList: + if config.block_overrides is None: + raise ValueError( + 'config.block_overrides should not be None when calling _construct_blocks_with_overrides.', + ) modules_order_expanded = {} for type in 'start', 'repeating_pattern', 'end': modules_order_expanded[type] = [] @@ -512,7 +517,7 @@ def _construct_blocks_with_overrides( ] = reuse_kv_layer_idx self.kv_cache_layers.add(reuse_kv_layer_idx) - new_block_args = self.override_block_args( + new_block_args = self._override_block_args( block_args, override_config, ) @@ -525,7 +530,7 @@ def _construct_blocks_with_overrides( return nn.ModuleList(module_list) - def override_block_args( + def _override_block_args( self, block_args: Dict[str, Any], override_config: Dict[str, Any], @@ -538,7 +543,7 @@ def override_block_args( f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', ) if isinstance(override_config[k], dict): - new_block_args[k] = self.override_block_args( + new_block_args[k] = self._override_block_args( block_args[k], override_config[k], ) @@ -856,7 +861,7 @@ def forward( ) if presents is not None: presents += (present,) - if b_idx in self.kv_cache_layers: + if self.kv_cache_layers is not None and b_idx in self.kv_cache_layers: if self.attn_impl != 'torch': layer_kv_cache_dict[b_idx] = [ present[0][:, past_position:],