From 33bbf9bdd3e8773c4b9d97e7be8591f275b947e4 Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Thu, 26 Sep 2024 21:05:56 +0000
Subject: [PATCH] simplify trainer

---
 llmfoundry/command_utils/train.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 9641aa44e8..effec90559 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -521,12 +521,9 @@ def train(cfg: DictConfig) -> Trainer:
 
     # TP config
     if tp_config is not None:
-        if 'layer_plan' not in tp_config:
-            tp_config['layer_plan'] = {}
-        if 'strategy' in tp_config:
-            strategy = tp_config.pop('strategy')
-            strategy_layer_plan = build_tp_strategy(strategy, model)
-            tp_config['layer_plan'] = strategy_layer_plan
+        strategy = tp_config.pop('strategy', None)
+        assert isinstance(strategy, str), "`strategy` must be in `tp_config`."
+        tp_config['layer_plan'] = build_tp_strategy(strategy, model)
 
     # Parallelism config
     parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}
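For context, below is a minimal, self-contained sketch of the behavior the patched hunk enforces: `layer_plan` is no longer a key that silently defaults to `{}`; whenever a TP config is supplied, it must name a `strategy`, and the layer plan is always derived from it. The `apply_tp_config` wrapper, the stubbed `build_tp_strategy`, and the `'ffn'` strategy name are illustrative stand-ins, not llmfoundry's actual registry entries; in `train.py` this logic runs inline just before the parallelism config is assembled.

```python
from typing import Any, Optional


def build_tp_strategy(strategy: str, model: Any) -> dict:
    """Hypothetical stand-in for llmfoundry's registry-built build_tp_strategy."""
    return {'model.layers': f'plan built from {strategy!r}'}


def apply_tp_config(tp_config: Optional[dict], model: Any) -> Optional[dict]:
    # Mirrors the patched branch: `strategy` is required whenever tp_config
    # is given, and `layer_plan` is always built from it.
    if tp_config is not None:
        strategy = tp_config.pop('strategy', None)
        assert isinstance(strategy, str), '`strategy` must be in `tp_config`.'
        tp_config['layer_plan'] = build_tp_strategy(strategy, model)
    return tp_config


print(apply_tp_config({'strategy': 'ffn', 'tensor_parallel_degree': 2}, model=None))
# -> {'tensor_parallel_degree': 2, 'layer_plan': {'model.layers': "plan built from 'ffn'"}}
print(apply_tp_config(None, model=None))  # TP disabled: passes through unchanged
```

A TP config that omits `strategy` (or supplies a non-string) now fails the assertion instead of falling back to an empty layer plan, which is the simplification the subject line refers to.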