diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 6e5ca9726a..521b553946 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -5,7 +5,7 @@ import os import time import warnings -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed @@ -187,7 +187,10 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): log.debug('Barrier test passed with device.') -def train(cfg: DictConfig) -> Trainer: +def train( + cfg: DictConfig, + config_validation_fn: Callable = validate_config +) -> Trainer: code_paths = cfg.get('code_paths', []) # Import any user provided code for code_path in code_paths: @@ -226,7 +229,7 @@ def train(cfg: DictConfig) -> Trainer: ) # Check for incompatibilities between the model and data loaders - validate_config(train_cfg) + config_validation_fn(train_cfg) cuda_alloc_conf = [] # Get max split size mb