Commit

fixing
ShashankMosaicML committed Jun 22, 2024
1 parent c774a4b commit 8dee35e
Showing 2 changed files with 11 additions and 6 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -631,7 +631,7 @@ def _apply_rotary_embeddings(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.reuse_kv_layer_idx is not None:
             orig_key, orig_value = key, value
-            key, value = torch.zeros_like(key), torch.zeros_like(value)
+            key, value = torch.empty_like(key), torch.empty_like(value)
 
         rotary_emb = rotary_emb_w_meta_info['rotary_emb']
         seq_len = rotary_emb_w_meta_info['seq_len']
@@ -696,7 +696,7 @@ def _apply_rotary_embeddings(
         query = query.view(bsz, seqlen, -1)
         key = key.view(bsz, seqlen, -1)
         if self.reuse_kv_layer_idx is not None:
-            return query, orig_key, orig_value
+            return query, orig_key, orig_value  # type: ignore
         return query, key, value
 
     def get_implementation_specific_args(
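Why `empty_like` rather than `zeros_like`: when `reuse_kv_layer_idx` is set, the freshly allocated `key`/`value` tensors are shape-and-dtype placeholders fed through the rotary-embedding machinery, and the function ultimately returns `orig_key`/`orig_value` (second hunk above), so the placeholders' contents are never read. `torch.empty_like` skips the redundant zero-fill. The added `# type: ignore` presumably silences the checker's complaint that `orig_key`/`orig_value` are bound only inside the `if` branch. A minimal standalone sketch of the allocation difference (not from the commit):

```python
import torch

k = torch.randn(2, 16, 64)

# zeros_like allocates and writes zeros; empty_like only allocates,
# leaving the memory uninitialized, which is fine for a throwaway
# placeholder whose contents are never read.
placeholder = torch.empty_like(k)
zeroed = torch.zeros_like(k)

assert placeholder.shape == zeroed.shape
assert placeholder.dtype == zeroed.dtype
```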
13 changes: 9 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -441,6 +441,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
             nn.ModuleList: The list of Transformer blocks.
         """
         block_args = self.extract_block_args(config.to_dict())
+        self.kv_cache_layers = None
 
         if config.block_overrides is not None:
             return self._construct_blocks_with_overrides(config, block_args)
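Initializing `self.kv_cache_layers` to `None` in `construct_blocks` guarantees the attribute exists on every path; previously it was only created inside `_construct_blocks_with_overrides` (the `.add` call in a later hunk implies a set), so a model built without `block_overrides` would hit an `AttributeError` when `forward` tested `b_idx in self.kv_cache_layers`. A small sketch of the pattern, using a hypothetical class:

```python
from typing import Optional, Set


class Cache:
    """Hypothetical stand-in illustrating the init-to-None pattern."""

    def __init__(self) -> None:
        # Without this line, the attribute would exist only when the
        # override path ran, and any later read would raise AttributeError.
        self.kv_cache_layers: Optional[Set[int]] = None


assert Cache().kv_cache_layers is None
```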
@@ -457,6 +458,10 @@ def _construct_blocks_with_overrides(
         config: MPTConfig,
         block_args: Dict[str, Any],
     ) -> nn.ModuleList:
+        if config.block_overrides is None:
+            raise ValueError(
+                'config.block_overrides should not be None when calling _construct_blocks_with_overrides.',
+            )
         modules_order_expanded = {}
         for type in 'start', 'repeating_pattern', 'end':
             modules_order_expanded[type] = []
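The new guard makes the method's precondition explicit: `config.block_overrides` is an `Optional` field, and while `construct_blocks` only calls this method when it is not `None`, raising here fails fast on any other call path and narrows the type for the code below. A minimal sketch of the narrowing idiom, with a hypothetical function name:

```python
from typing import Any, Dict, Optional


def apply_overrides(block_overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Raising on None narrows Optional[Dict[...]] to Dict[...] for the rest
    # of the function, which checkers such as mypy and pyright understand.
    if block_overrides is None:
        raise ValueError('block_overrides must not be None here.')
    return dict(block_overrides)


assert apply_overrides({'attn_config': {}}) == {'attn_config': {}}
```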
@@ -512,7 +517,7 @@ def _construct_blocks_with_overrides(
                 ] = reuse_kv_layer_idx
                 self.kv_cache_layers.add(reuse_kv_layer_idx)
 
-            new_block_args = self.override_block_args(
+            new_block_args = self._override_block_args(
                 block_args,
                 override_config,
             )
@@ -525,7 +530,7 @@ def _construct_blocks_with_overrides(
 
         return nn.ModuleList(module_list)
 
-    def override_block_args(
+    def _override_block_args(
         self,
         block_args: Dict[str, Any],
         override_config: Dict[str, Any],
@@ -538,7 +543,7 @@ def override_block_args(
                     f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.',
                 )
             if isinstance(override_config[k], dict):
-                new_block_args[k] = self.override_block_args(
+                new_block_args[k] = self._override_block_args(
                     block_args[k],
                     override_config[k],
                 )
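The leading underscore marks `override_block_args` as private to the class; the rename is applied at the definition and at both call sites, including the recursive one in this hunk. The hunk also shows the helper's shape: a type-checked, recursive merge of an override dict into the base block args. A hedged sketch of that merge pattern as a standalone function (hypothetical, not the library's code):

```python
from typing import Any, Dict


def override_args(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(base)
    for k, v in overrides.items():
        if k in base and isinstance(base[k], dict) and isinstance(v, dict):
            # Nested dicts are merged recursively rather than replaced wholesale.
            merged[k] = override_args(base[k], v)
        elif k in base and type(base[k]) is not type(v):
            raise ValueError(f'Override for {k!r} changes the value type.')
        else:
            merged[k] = v
    return merged


base = {'attn_config': {'attn_impl': 'flash'}, 'd_model': 768}
merged = override_args(base, {'attn_config': {'attn_impl': 'torch'}})
assert merged == {'attn_config': {'attn_impl': 'torch'}, 'd_model': 768}
```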
@@ -856,7 +861,7 @@ def forward(
                 )
                 if presents is not None:
                     presents += (present,)
-                if b_idx in self.kv_cache_layers:
+                if self.kv_cache_layers is not None and b_idx in self.kv_cache_layers:
                     if self.attn_impl != 'torch':
                         layer_kv_cache_dict[b_idx] = [
                             present[0][:, past_position:],
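This is the consumer side of the fix: `self.kv_cache_layers` may now legitimately be `None` (the no-overrides case), and a bare membership test against `None` raises `TypeError`, so the `is not None` check must short-circuit first. A short illustration:

```python
kv_cache_layers = None  # the no-overrides case after this commit
b_idx = 3

# `b_idx in None` would raise TypeError, so guard before the membership test.
reuse = kv_cache_layers is not None and b_idx in kv_cache_layers
assert reuse is False
```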
