Commit

fixes
ShashankMosaicML committed Jun 22, 2024
1 parent 81e2930 commit 9b3d813
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -441,20 +441,19 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
             nn.ModuleList: The list of Transformer blocks.
         """
         block_args = self.extract_block_args(config.to_dict())
-        module_list = []
-        self.kv_cache_layers = set()

         modules_order_expanded = {}
         for type in 'start', 'repeating_pattern', 'end':
             modules_order_expanded[type] = []
-            for block in config['block_overrides'][type]:
-                if not isinstance(block['repetitions'],
-                                  int) or block['repetitions'] < 1:
-                    raise ValueError(
-                        'repetitions should be a positive integer.',
-                    )
-                modules_order_expanded[type].extend([block['name']] *
-                                                    block['repetitions'])
+            if type in config['block_overrides']:
+                for block in config['block_overrides'][type]:
+                    if not isinstance(block['repeat'],
+                                      int) or block['repeat'] < 1:
+                        raise ValueError(
+                            'repeat should be a positive integer.',
+                        )
+                    modules_order_expanded[type].extend([block['name']] *
+                                                        block['repeat'])

         start_len = len(modules_order_expanded['start'])
         repeating_pattern_len = len(modules_order_expanded['repeating_pattern'])
@@ -466,14 +465,15 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
             raise ValueError(
                 'Number of layers should be divisible by the specified custom modules order.',
             )
-        attn_modules_order_expanded = modules_order_expanded[
+        model_modules_order_expanded = modules_order_expanded[
             'start'] + modules_order_expanded['repeating_pattern'] * (
-                (config.n_layers -
-                 (start_len + end_len)) // repeating_pattern_len
-            ) + modules_order_expanded['end']
+                config.n_layers - (start_len + end_len)
+            ) // repeating_pattern_len + modules_order_expanded['end']

+        self.kv_cache_layers = set()
+        module_list = []
         for i in range(config.n_layers):
-            module_name = attn_modules_order_expanded[i]
+            module_name = model_modules_order_expanded[i]

             override_config = {}
             if module_name != 'default':
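
For reference, the module-order expansion these two hunks implement can be sketched standalone as below. This is a minimal sketch: the helper name expand_module_order, the example override names, and the exact config shape are illustrative assumptions based on the keys the diff reads ('start', 'repeating_pattern', 'end', 'name', 'repeat'), not code from this file.

# Standalone sketch of the block-order expansion (illustrative only).
# Assumes a non-empty 'repeating_pattern' section.
def expand_module_order(block_overrides: dict, n_layers: int) -> list:
    modules_order_expanded = {}
    for section in ('start', 'repeating_pattern', 'end'):
        modules_order_expanded[section] = []
        for block in block_overrides.get(section, []):
            if not isinstance(block['repeat'], int) or block['repeat'] < 1:
                raise ValueError('repeat should be a positive integer.')
            modules_order_expanded[section].extend(
                [block['name']] * block['repeat'],
            )

    start_len = len(modules_order_expanded['start'])
    end_len = len(modules_order_expanded['end'])
    repeating_pattern_len = len(modules_order_expanded['repeating_pattern'])
    if (n_layers - (start_len + end_len)) % repeating_pattern_len != 0:
        raise ValueError(
            'Number of layers should be divisible by the specified custom modules order.',
        )
    num_repeats = (n_layers - (start_len + end_len)) // repeating_pattern_len
    return (
        modules_order_expanded['start'] +
        modules_order_expanded['repeating_pattern'] * num_repeats +
        modules_order_expanded['end']
    )

# Hypothetical 5-layer example; 'default' and 'reuse_kv_layer' are made-up names.
overrides = {
    'start': [{'name': 'default', 'repeat': 1}],
    'repeating_pattern': [
        {'name': 'default', 'repeat': 1},
        {'name': 'reuse_kv_layer', 'repeat': 1},
    ],
    'end': [],
}
print(expand_module_order(overrides, n_layers=5))
# ['default', 'default', 'reuse_kv_layer', 'default', 'reuse_kv_layer']
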
@@ -521,12 +521,12 @@ def override_block_args(
                     f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.',
                 )
             if isinstance(override_config[k], dict):
-                new_block_args = self.override_block_args(
+                new_block_args[k] = self.override_block_args(
                     block_args[k],
                     override_config[k],
                 )
             else:
-                new_block_args = override_config[k]
+                new_block_args[k] = override_config[k]
         return new_block_args

     def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]:
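
The change in this hunk makes the recursive merge write each overridden value into its own key instead of replacing the whole result with the last value processed. A minimal sketch of the intended per-key overlay follows; merge_overrides is a simplified stand-in for override_block_args, and the example config keys are illustrative, not taken from this commit.

from copy import deepcopy

def merge_overrides(block_args: dict, override_config: dict) -> dict:
    # Recursively overlay override_config onto block_args, key by key.
    new_block_args = deepcopy(block_args)
    for k in override_config:
        if isinstance(override_config[k], dict):
            new_block_args[k] = merge_overrides(block_args[k], override_config[k])
        else:
            new_block_args[k] = override_config[k]
    return new_block_args

base = {'attn_config': {'attn_impl': 'flash', 'sliding_window_size': -1}}
override = {'attn_config': {'sliding_window_size': 1024}}
print(merge_overrides(base, override))
# {'attn_config': {'attn_impl': 'flash', 'sliding_window_size': 1024}}
# With the pre-fix assignment (new_block_args = ... instead of new_block_args[k] = ...),
# the merged result would have been overwritten by the last processed value.
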
@@ -805,7 +805,6 @@ def forward(
             )

         layer_kv_cache_dict = {}
-        prev_layer_key_value = None
         for b_idx, block in enumerate(self.blocks):
             attn_block = block.attn if hasattr(
                 block,
@@ -818,6 +817,8 @@
                     )
                 prev_layer_key_value = layer_kv_cache_dict[
                     block.reuse_kv_layer_idx]
+            else:
+                prev_layer_key_value = None
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
                 all_hidden_states = all_hidden_states + (x,)
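
The last two hunks move the reset of prev_layer_key_value inside the loop: a block that does not declare reuse_kv_layer_idx now gets None on every iteration, so a key/value pair cached by an earlier layer can no longer leak into a block that should not reuse one. A runnable sketch of the pattern follows; the Block dataclass and the string stand-ins for key/value tensors are simplifications for illustration.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Block:
    # Only the field the loop inspects; real MPT blocks carry much more state.
    reuse_kv_layer_idx: Optional[int] = None

blocks = [Block(), Block(), Block(reuse_kv_layer_idx=1)]
layer_kv_cache_dict = {}
for b_idx, block in enumerate(blocks):
    if block.reuse_kv_layer_idx is not None:
        prev_layer_key_value = layer_kv_cache_dict[block.reuse_kv_layer_idx]
    else:
        prev_layer_key_value = None  # reset each iteration; previously set once, before the loop
    # Run the block here; pretend it produced a key/value pair worth caching.
    layer_kv_cache_dict[b_idx] = f'kv_from_layer_{b_idx}'
    print(b_idx, prev_layer_key_value)
# Prints: 0 None, 1 None, 2 kv_from_layer_1
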
