simplify trainer
eitanturok committed Sep 26, 2024
1 parent c9c2455 · commit 33bbf9b
Showing 1 changed file with 3 additions and 6 deletions.
llmfoundry/command_utils/train.py: 9 changes (3 additions, 6 deletions)
@@ -521,12 +521,9 @@ def train(cfg: DictConfig) -> Trainer:

     # TP config
     if tp_config is not None:
-        if 'layer_plan' not in tp_config:
-            tp_config['layer_plan'] = {}
-        if 'strategy' in tp_config:
-            strategy = tp_config.pop('strategy')
-            strategy_layer_plan = build_tp_strategy(strategy, model)
-            tp_config['layer_plan'] = strategy_layer_plan
+        strategy = tp_config.pop('strategy', None)
+        assert isinstance(strategy, str), "`strategy` must be in `tp_config`."
+        tp_config['layer_plan'] = build_tp_strategy(strategy, model)

     # Parallelism config
     parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}
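What the simplified logic does: after this change, `strategy` is required whenever `tp_config` is provided. The old fallback to an empty `layer_plan` is gone, and a missing (or non-string) `strategy` now fails the assertion before `build_tp_strategy` is ever called. Below is a minimal standalone sketch of that control flow; it stubs out `build_tp_strategy` and uses a placeholder `'ffn'` strategy name, so it is an illustration of the new behavior, not the llmfoundry implementation.

def build_tp_strategy(strategy, model):
    # Stand-in for llmfoundry's build_tp_strategy, which returns a TP layer plan.
    return {'strategy': strategy, 'model': type(model).__name__}

def resolve_tp_config(tp_config, model):
    # Mirrors the new train() logic: require `strategy`, always build the layer plan.
    if tp_config is not None:
        strategy = tp_config.pop('strategy', None)
        assert isinstance(strategy, str), "`strategy` must be in `tp_config`."
        tp_config['layer_plan'] = build_tp_strategy(strategy, model)
    return tp_config

# A config that names a strategy always gets a layer plan built for it.
print(resolve_tp_config({'strategy': 'ffn', 'tensor_parallel_degree': 2}, object()))

# Omitting `strategy` raises AssertionError instead of silently defaulting
# `layer_plan` to {} as the old code did.
try:
    resolve_tp_config({'tensor_parallel_degree': 2}, object())
except AssertionError as err:
    print(err)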
