diff --git a/llmfoundry/models/utils/tp_strategy.py b/llmfoundry/models/utils/tp_strategy.py index f7929686c6..748ccf0c55 100644 --- a/llmfoundry/models/utils/tp_strategy.py +++ b/llmfoundry/models/utils/tp_strategy.py @@ -19,7 +19,10 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: layer for layer in TP_LAYERS for name, _ in model.named_modules() if layer in name } - assert tp_layers_in_model == TP_LAYERS, f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.' + if tp_layers_in_model != TP_LAYERS: + raise RuntimeError( + f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.' + ) # generate layer plan layer_plan: dict[str, ParallelStyle] = {}