Make MPTModel._override_block_args a @staticmethod (drop `self`), call it
through `MPTModel.` at both call sites (including the recursive one), and add
a unit test that calls the static method directly.

NOTE(review): this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. Hunk line counts match the @@ headers,
but the indentation of context/changed lines is reconstructed -- confirm
against the target tree (or apply with --ignore-whitespace) before relying on it.

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index f7039c8444..d76dcc20ae 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -539,7 +539,7 @@ def _construct_blocks_with_overrides(
                 module_name,
                 self._get_overrides_for_logging(override_config),
             ],)
-            new_block_args = self._override_block_args(
+            new_block_args = MPTModel._override_block_args(
                 block_args,
                 override_config,
                 config.allowed_block_overrides,
@@ -570,8 +570,8 @@ def _get_overrides_for_logging(
                 overrides_list.append({k: v})
         return overrides_list
 
+    @staticmethod
     def _override_block_args(
-        self,
         block_args: Dict[str, Any],
         override_config: Dict[str, Any],
         allowed_block_overrides: set,
@@ -584,7 +584,7 @@ def _override_block_args(
                     f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.',
                 )
             if isinstance(override_config[k], dict):
-                new_block_args[k] = self._override_block_args(
+                new_block_args[k] = MPTModel._override_block_args(
                     block_args[k],
                     override_config[k],
                     allowed_block_overrides,
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 5cd0880565..63ce070af7 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -2776,3 +2776,18 @@ def test_reuse_prev_layer_kv_cache(
     assert torch.all(
         outputs.past_key_values[0][1] == outputs.past_key_values[1][1],
     )
+
+
+def test_override_block_args():
+    block_args = {'a': 1, 'b': {'c': 3}, 'd': 4}
+    override_config = {'a': 2, 'b': {'c': 5}, 'e': 6}
+    allowed_block_overrides = {'a', 'c', 'e'}
+    new_config = MPTModel._override_block_args(
+        block_args,
+        override_config,
+        allowed_block_overrides,
+    )
+    assert new_config['a'] == 2
+    assert new_config['d'] == 4
+    assert new_config['e'] == 6
+    assert new_config['b']['c'] == 5