Commit: addressing some comments

ShashankMosaicML committed Jun 25, 2024
1 parent 8e89db9 commit 89bc22e
Showing 2 changed files with 25 additions and 27 deletions.
42 changes: 20 additions & 22 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -20,6 +20,7 @@
     ffn_config_defaults,
     init_config_defaults,
 )
+from llmfoundry.utils.warnings import ExperimentalWarning


 class MPTConfig(PretrainedConfig):
@@ -118,7 +119,7 @@ def __init__(
                 also be a dictionary that specifies the fc layer name and any kwargs for the fc layer.
             tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
             use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
-            block_overrides: The allows for overriding default block configs for certain layers. This should contain two sub configs: order and overrides. order specifies the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the overrides in the overrides config.
+            block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and at least one of `start`, `repeating_pattern`, and `end`. `start`, `repeating_pattern`, and `end` specify the order of different kinds of layers (default refers to a layer that does not apply any overrides). For each kind of layer, specify the `overrides` in the overrides config.
                 To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed:
                 block_overrides:
                     start:
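The docstring's YAML example is collapsed in this view. As an illustrative sketch only (the override names, the {'name': ..., 'repeat': ...} entry shape, and the attention settings below are assumptions, not the file's actual example), a config in the new start/repeating_pattern/end format could look like this as a Python dict:

    # Hypothetical block_overrides value in the new format; names and
    # numbers are illustrative, not the collapsed docstring example.
    block_overrides = {
        'start': [
            {'name': 'default', 'repeat': 1},  # first layer: no overrides
        ],
        'repeating_pattern': [
            {'name': 'sliding_window_layer', 'repeat': 1},
            {'name': 'kv_reuse_layer', 'repeat': 1},
        ],
        'end': [
            {'name': 'default', 'repeat': 1},  # last layer: no overrides
        ],
        'overrides': {
            'sliding_window_layer': {
                'attn_config': {'sliding_window_size': 1024},
            },
            'kv_reuse_layer': {
                # reuse the KV cache computed by the previous layer
                'attn_config': {'reuse_kv_layer_idx': -1},
            },
        },
    }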
@@ -183,27 +184,7 @@ def __init__(
                 'reusing kv cache from a previous layer is not implemented for torch attention.',
             )
         if block_overrides is not None:
-            warnings.warn(
-                'block_overrides is an experimental feature. The YAML design may change in the future.',
-            )
-            if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides:
-                raise ValueError(
-                    'either start, repeating_pattern, or end should be defined in block_overrides',
-                )
-            if 'overrides' not in block_overrides:
-                raise ValueError(
-                    'overrides should be defined in block_overrides',
-                )
-            for name, override in block_overrides['overrides'].items():
-                if name == 'default':
-                    raise ValueError(
-                        'block overrides cannot be named "default".',
-                    )
-                if 'attn_config' in override and 'reuse_kv_layer_idx' in override[
-                    'attn_config'] and self.attn_config['attn_impl'] == 'torch':
-                    raise NotImplementedError(
-                        'reusing kv cache from a previous layer is not implemented for torch attention.',
-                    )
+            self._validate_block_overrides(block_overrides)
         self.block_overrides = block_overrides

         if isinstance(fc_type, str):
@@ -230,6 +211,23 @@

         self._validate_config()

+    def _validate_block_overrides(self, block_overrides):
+        warnings.warn(ExperimentalWarning('block_overrides'))
+        if 'start' not in block_overrides and 'repeating_pattern' not in block_overrides and 'end' not in block_overrides:
+            raise ValueError(
+                'either start, repeating_pattern, or end should be defined in block_overrides',
+            )
+        if 'overrides' not in block_overrides:
+            raise ValueError('overrides should be defined in block_overrides',)
+        for name, override in block_overrides['overrides'].items():
+            if name == 'default':
+                raise ValueError('block overrides cannot be named "default".',)
+            if 'attn_config' in override and 'reuse_kv_layer_idx' in override[
+                'attn_config'] and self.attn_config['attn_impl'] == 'torch':
+                raise NotImplementedError(
+                    'reusing kv cache from a previous layer is not implemented for torch attention.',
+                )
+
     def _set_config_defaults(
         self,
         config: Dict[str, Any],
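For reference, a minimal sketch of what the extracted _validate_block_overrides hook rejects (the config dicts below are hypothetical inputs, not from this diff; the {'name': ..., 'repeat': ...} entry shape is also an assumption):

    # Hypothetical block_overrides values and the errors the checks above raise.
    no_order = {'overrides': {'my_layer': {}}}
    # ValueError: either start, repeating_pattern, or end should be defined in block_overrides

    no_overrides = {'start': [{'name': 'my_layer', 'repeat': 1}]}
    # ValueError: overrides should be defined in block_overrides

    reserved_name = {
        'start': [{'name': 'default', 'repeat': 1}],
        'overrides': {'default': {}},  # 'default' is reserved for un-overridden layers
    }
    # ValueError: block overrides cannot be named "default".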
10 changes: 5 additions & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -464,7 +464,7 @@ def _construct_blocks_with_overrides(
             'config.block_overrides should not be None when calling _construct_blocks_with_overrides.',
         )
     modules_order_expanded = {}
-    for type in 'start', 'repeating_pattern', 'end':
+    for type in ['start', 'repeating_pattern', 'end']:
         modules_order_expanded[type] = []
         if type in config.block_overrides:
             for block in config.block_overrides[type]:
@@ -487,12 +487,12 @@
             raise ValueError(
                 'Number of layers should be divisible by the specified custom modules order.',
             )
+        num_repetitions = (
+            config.n_layers - (start_len + end_len)
+        ) // repeating_pattern_len
         modules_order_expanded[
             'repeating_pattern'
-        ] = modules_order_expanded['repeating_pattern'] * (
-            (config.n_layers -
-             (start_len + end_len)) // repeating_pattern_len
-        )
+        ] = modules_order_expanded['repeating_pattern'] * num_repetitions

     model_modules_order_expanded = modules_order_expanded[
         'start'] + modules_order_expanded['repeating_pattern'
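The refactor above gives the repetition count a name without changing the arithmetic. A worked example (layer counts and pattern names are made up for illustration):

    # Sketch of the repeating-pattern expansion with assumed sizes.
    n_layers = 24
    start_len, end_len = 2, 1          # layers pinned by 'start' and 'end'
    repeating_pattern_len = 3          # length of one copy of the pattern

    remaining = n_layers - (start_len + end_len)           # 21 layers to fill
    assert remaining % repeating_pattern_len == 0          # else the ValueError above fires
    num_repetitions = remaining // repeating_pattern_len   # 7 copies

    pattern = ['sliding_window_layer', 'kv_reuse_layer', 'default']
    expanded = pattern * num_repetitions                   # 21 entries
    assert start_len + len(expanded) + end_len == n_layers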