Skip to content

Commit

Permalink
adding tests for override_block_args
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Jun 25, 2024
1 parent b74330e commit 5079d2e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _construct_blocks_with_overrides(
module_name,
self._get_overrides_for_logging(override_config),
],)
new_block_args = self._override_block_args(
new_block_args = MPTModel._override_block_args(
block_args,
override_config,
config.allowed_block_overrides,
Expand Down Expand Up @@ -570,8 +570,8 @@ def _get_overrides_for_logging(
overrides_list.append({k: v})
return overrides_list

@staticmethod
def _override_block_args(
self,
block_args: Dict[str, Any],
override_config: Dict[str, Any],
allowed_block_overrides: set,
Expand All @@ -584,7 +584,7 @@ def _override_block_args(
f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.',
)
if isinstance(override_config[k], dict):
new_block_args[k] = self._override_block_args(
new_block_args[k] = MPTModel._override_block_args(
block_args[k],
override_config[k],
allowed_block_overrides,
Expand Down
15 changes: 15 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2776,3 +2776,18 @@ def test_reuse_prev_layer_kv_cache(
assert torch.all(
outputs.past_key_values[0][1] == outputs.past_key_values[1][1],
)


def test_override_block_args():
block_args = {'a': 1, 'b': {'c': 3}, 'd': 4}
override_config = {'a': 2, 'b': {'c': 5}, 'e': 6}
allowed_block_overrides = {'a', 'c', 'e'}
new_config = MPTModel._override_block_args(
block_args,
override_config,
allowed_block_overrides,
)
assert new_config['a'] == 2
assert new_config['d'] == 4
assert new_config['e'] == 6
assert new_config['b']['c'] == 5

0 comments on commit 5079d2e

Please sign in to comment.