diff --git a/llmfoundry/tp/__init__.py b/llmfoundry/tp/__init__.py index 48b788befa..323ae23727 100644 --- a/llmfoundry/tp/__init__.py +++ b/llmfoundry/tp/__init__.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.registry import tp_strategies -from llmfoundry.tp.ffn_tp_strategy import ffn +from llmfoundry.tp.ffn_tp_strategy import ffn_tp_strategy -tp_strategies.register('ffn', func=ffn) +tp_strategies.register('ffn', func=ffn_tp_strategy) __all__ = [ - 'ffn', + 'ffn_tp_strategy', ] diff --git a/llmfoundry/tp/ffn_tp_strategy.py b/llmfoundry/tp/ffn_tp_strategy.py index 2804bfc747..1de92ef6ae 100644 --- a/llmfoundry/tp/ffn_tp_strategy.py +++ b/llmfoundry/tp/ffn_tp_strategy.py @@ -11,7 +11,7 @@ from torch.distributed.tensor.parallel.style import ParallelStyle -def ffn(model: ComposerModel) -> dict[str, ParallelStyle]: +def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: TP_LAYERS = {'ffn', 'ffn.up_proj', 'ffn.down_proj'} # Validate that all TP_LAYERS are in model