Skip to content

Commit

Permalink
Add test for logged_config transforms (#1441)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu authored Aug 12, 2024
1 parent 9cdfd6d commit 55cb0e3
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import copy
from typing import Any, Dict, List

import catalogue
Expand Down Expand Up @@ -70,3 +71,34 @@ def dummy_transform(config: Dict[str, Any]) -> Dict[str, Any]:

del catalogue.REGISTRY[
('llmfoundry', 'config_transforms', 'dummy_transform')]


def test_logged_cfg():
config = DictConfig({
'global_train_batch_size': 1,
'device_train_microbatch_size': 1,
'model': {},
'scheduler': {},
'max_seq_len': 128,
'train_loader': {},
'max_duration': 1,
'tokenizer': {},
'eval_interval': 1,
'seed': 1,
'optimizer': {},
'variables': {},
},)
logged_config, _ = make_dataclass_and_log_config(
config,
TrainConfig,
TRAIN_CONFIG_KEYS,
transforms='all',
)
expected_config = copy.deepcopy(config)
expected_config.update({
'n_gpus': 1,
'device_train_batch_size': 1,
'device_train_grad_accum': 1,
'device_eval_batch_size': 1,
})
assert expected_config == logged_config

0 comments on commit 55cb0e3

Please sign in to comment.