diff --git a/scripts/train/train.py b/scripts/train/train.py index 84df79722a..4620254135 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -50,7 +50,6 @@ make_dataclass_and_log_config, pop_config, process_init_device, - update_batch_size_info, ) from llmfoundry.utils.exceptions import ( BaseContextualError, @@ -197,7 +196,7 @@ def main(cfg: DictConfig) -> Trainer: cfg, TrainConfig, TRAIN_CONFIG_KEYS, - transforms=[update_batch_size_info], + transforms='all', ) # Set logging level