From 2322ebea7a0a5d3c5ca95d2a3ba37b85add96a43 Mon Sep 17 00:00:00 2001
From: root <23239305+b-chu@users.noreply.github.com>
Date: Sat, 10 Aug 2024 22:25:31 +0000
Subject: [PATCH] Add test for logged_config transforms

---
 tests/test_utils.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index dc9bcd9baf..08123846e6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,7 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

+import copy
 from typing import Any, Dict, List

 import catalogue
@@ -70,3 +71,34 @@ def dummy_transform(config: Dict[str, Any]) -> Dict[str, Any]:

     del catalogue.REGISTRY[
         ('llmfoundry', 'config_transforms', 'dummy_transform')]
+
+
+def test_logged_cfg():
+    config = DictConfig({
+        'global_train_batch_size': 1,
+        'device_train_microbatch_size': 1,
+        'model': {},
+        'scheduler': {},
+        'max_seq_len': 128,
+        'train_loader': {},
+        'max_duration': 1,
+        'tokenizer': {},
+        'eval_interval': 1,
+        'seed': 1,
+        'optimizer': {},
+        'variables': {},
+    },)
+    logged_config, _ = make_dataclass_and_log_config(
+        config,
+        TrainConfig,
+        TRAIN_CONFIG_KEYS,
+        transforms='all',
+    )
+    expected_config = copy.deepcopy(config)
+    expected_config.update({
+        'n_gpus': 1,
+        'device_train_batch_size': 1,
+        'device_train_grad_accum': 1,
+        'device_eval_batch_size': 1,
+    })
+    assert expected_config == logged_config
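
Note on the `expected_config` values: the test asserts that the config transforms add derived batch-size keys to the logged config. Below is a minimal sketch of the arithmetic that makes all four expected values come out to 1 on a single GPU; it mirrors the relationships the test checks, not llm-foundry's actual implementation, and the eval-batch-size default is an assumption.

```python
# Hedged sketch of the derived keys asserted in test_logged_cfg, assuming
# a single-GPU run and integer division. Variable names mirror the config
# keys in the patch; the derivation is illustrative only.
n_gpus = 1
global_train_batch_size = 1
device_train_microbatch_size = 1

# Per-device train batch size: the global batch split evenly across devices.
device_train_batch_size = global_train_batch_size // n_gpus

# Gradient accumulation steps: microbatches needed to fill one device batch.
device_train_grad_accum = device_train_batch_size // device_train_microbatch_size

# Assumed here: with no eval loader configured, the eval batch size
# resolves from the train microbatch size, also 1 on one device.
device_eval_batch_size = device_train_microbatch_size

assert (n_gpus, device_train_batch_size, device_train_grad_accum,
        device_eval_batch_size) == (1, 1, 1, 1)
```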