From ff5f7826550b47f95ad00c242e182e1440dd7167 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 21 Aug 2024 22:04:15 +0000
Subject: [PATCH] flexible validations

---
 llmfoundry/command_utils/train.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 6e5ca9726a..521b553946 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -5,7 +5,7 @@
 import os
 import time
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.distributed
@@ -187,7 +187,10 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
     log.debug('Barrier test passed with device.')
 
 
-def train(cfg: DictConfig) -> Trainer:
+def train(
+    cfg: DictConfig,
+    config_validation_fn: Callable = validate_config
+) -> Trainer:
     code_paths = cfg.get('code_paths', [])
     # Import any user provided code
     for code_path in code_paths:
@@ -226,7 +229,7 @@ def train(cfg: DictConfig) -> Trainer:
     )
 
     # Check for incompatibilities between the model and data loaders
-    validate_config(train_cfg)
+    config_validation_fn(train_cfg)
 
     cuda_alloc_conf = []
     # Get max split size mb