Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor Parallelism #1521

Merged
merged 76 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 67 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
4db2a91
add tp_strategy registry
eitanturok Aug 28, 2024
9f948b2
update
eitanturok Aug 28, 2024
8229133
add ffn tp strategy
eitanturok Aug 28, 2024
841c8cf
Merge branch 'mosaicml:main' into tp
eitanturok Aug 28, 2024
3b201d8
only do layer_plan for now
eitanturok Aug 28, 2024
1c9a532
Merge branch 'mosaicml:main' into tp
eitanturok Sep 9, 2024
b0118f7
add tp_config
eitanturok Sep 9, 2024
4868c17
build tp_strategy
eitanturok Sep 9, 2024
f30cda8
update
eitanturok Sep 9, 2024
63d236c
update
eitanturok Sep 9, 2024
6dcf5e6
update
eitanturok Sep 9, 2024
eac4ad2
replace Dict with dict
eitanturok Sep 9, 2024
4bcaf96
update
eitanturok Sep 9, 2024
092f2f2
works!
eitanturok Sep 9, 2024
a935bed
tp_strategy does not require model
eitanturok Sep 9, 2024
11c0492
tp_strategy accepts model
eitanturok Sep 9, 2024
bddb165
fix validation
eitanturok Sep 9, 2024
309b96c
updatE
eitanturok Sep 9, 2024
4226916
fix logging issue
eitanturok Sep 10, 2024
86b1b81
fix yaml
eitanturok Sep 10, 2024
f2d6571
add error
eitanturok Sep 10, 2024
c6cee7f
it works!
eitanturok Sep 10, 2024
8040aa7
works with original yaml
eitanturok Sep 10, 2024
f384de7
update
eitanturok Sep 10, 2024
9f77bcf
Merge branch 'mosaicml:main' into tp
eitanturok Sep 12, 2024
3b5f935
delete file
eitanturok Sep 12, 2024
76adc48
add replication
eitanturok Sep 12, 2024
90264ea
tp-strat does not crash
eitanturok Sep 25, 2024
7b73db5
debug print
eitanturok Sep 25, 2024
19f5477
it works!
eitanturok Sep 25, 2024
2b9664e
better init for parallelism_config
eitanturok Sep 25, 2024
9f7365e
remove comment
eitanturok Sep 25, 2024
a382b5e
init tests
eitanturok Sep 25, 2024
3fa5189
tests pass
eitanturok Sep 25, 2024
5e58dbc
add test for tp training
eitanturok Sep 25, 2024
3a6dec6
remove test for tp training b/c in composer
eitanturok Sep 25, 2024
1f025d8
remove icecream
eitanturok Sep 25, 2024
c2d309a
remove more icrecream
eitanturok Sep 25, 2024
6d65a29
style
eitanturok Sep 25, 2024
3372ec0
style more
eitanturok Sep 25, 2024
c09223c
Merge branch 'main' into tp
eitanturok Sep 25, 2024
d2c9114
remove extra
eitanturok Sep 25, 2024
ae5980a
Merge branch 'tp' of https://github.com/eitanturok/llm-foundry into tp
eitanturok Sep 25, 2024
ff36f17
fix type checking
eitanturok Sep 25, 2024
1474a80
make tp yaml
eitanturok Sep 25, 2024
835d839
Merge branch 'main' into tp
eitanturok Sep 25, 2024
8996de4
no flash attn
eitanturok Sep 25, 2024
5004fe5
better comment
eitanturok Sep 25, 2024
d0f6751
Merge branch 'main' into tp
eitanturok Sep 25, 2024
ba2dd0d
docformatter
eitanturok Sep 26, 2024
7e3ad71
remove |=
eitanturok Sep 26, 2024
39a92ad
add runtimeError
eitanturok Sep 26, 2024
92719e7
Update llmfoundry/models/utils/tp_strategy.py
eitanturok Sep 26, 2024
8c0135d
Update llmfoundry/models/utils/tp_strategy.py
eitanturok Sep 26, 2024
0156dd2
explain with comments
eitanturok Sep 26, 2024
c696338
run on GPU
eitanturok Sep 26, 2024
ccdbcf4
style
eitanturok Sep 26, 2024
cb1ab31
test one gpu warning
eitanturok Sep 26, 2024
3921cda
test_no_tp_with_one_gpu
eitanturok Sep 26, 2024
4e4b6b9
test_no_tp_with_moes
eitanturok Sep 26, 2024
c9c2455
add experimental_function decorator to tp_strategy
eitanturok Sep 26, 2024
33bbf9b
simplify trainer
eitanturok Sep 26, 2024
c9e64df
tp_strategy -> tp_stratigies
eitanturok Sep 26, 2024
df169e8
make tp dir
eitanturok Sep 26, 2024
c9a8078
Merge branch 'main' into tp
eitanturok Sep 26, 2024
e6ab929
rename
eitanturok Sep 26, 2024
6caeea9
better function names
eitanturok Sep 26, 2024
3426ea3
import fix style
eitanturok Sep 26, 2024
2683c6d
delete tp yaml
eitanturok Sep 26, 2024
d5779c7
warn checkpointing does not work
eitanturok Sep 26, 2024
7ac37bc
better description
eitanturok Sep 26, 2024
67a1c7b
cleanup
eitanturok Sep 27, 2024
eb2b591
tp test directory
eitanturok Sep 27, 2024
86992f9
style
eitanturok Sep 27, 2024
24ffeb4
type checking
eitanturok Sep 27, 2024
04da536
remove assert
eitanturok Sep 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
models,
optim,
tokenizers,
tp,
utils,
)
from llmfoundry._version import __version__
Expand Down Expand Up @@ -87,5 +88,6 @@
'models',
'optim',
'tokenizers',
'tp',
'utils',
]
38 changes: 32 additions & 6 deletions llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import time
import warnings
from copy import deepcopy
from typing import Any, Optional, Union

import torch
Expand All @@ -18,7 +19,11 @@
TraceHandler,
cyclic_schedule,
)
from composer.utils import dist, get_device, reproducibility
from composer.utils import (
dist,
get_device,
reproducibility,
)
eitanturok marked this conversation as resolved.
Show resolved Hide resolved
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

Expand All @@ -43,6 +48,7 @@
build_save_planner,
build_scheduler,
build_tokenizer,
build_tp_strategies,
)
from llmfoundry.utils.config_utils import (
TRAIN_CONFIG_KEYS,
Expand Down Expand Up @@ -329,16 +335,27 @@ def train(cfg: DictConfig) -> Trainer:
changing autoresume default to True...',
)

# Warn if fsdp is enabled but user only has 1 GPU
if dist.get_world_size() == 1 and fsdp_config is not None:
# Optional tp config
tp_config: Optional[dict[str, Any]] = train_cfg.tp_config

# Warn if FSDP or TP is enabled but user only has 1 GPU
if dist.get_world_size(
) == 1 and (fsdp_config is not None or tp_config is not None):
parallelism = ''
if fsdp_config is not None:
parallelism += 'FSDP'
if tp_config is not None:
parallelism += '+TP' if fsdp_config is not None else 'TP'
warnings.warn(
'FSDP is not applicable for single-GPU training. Reverting to DDP.',
f'{parallelism} is not applicable for single-GPU training. Reverting to DDP.',
eitanturok marked this conversation as resolved.
Show resolved Hide resolved
)
fsdp_config = None
tp_config = None

# Initialize context
init_context = process_init_device(model_config, fsdp_config)
init_context = process_init_device(model_config, fsdp_config, tp_config)
logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
logged_cfg.update({'tp_config': deepcopy(tp_config)}, merge=True)

# Build tokenizer
log.info('Building tokenizer...')
Expand Down Expand Up @@ -502,6 +519,15 @@ def train(cfg: DictConfig) -> Trainer:

_log_num_params(model, logged_cfg)

# TP config
if tp_config is not None:
strategy = tp_config.pop('strategy', None)
assert isinstance(strategy, str), '`strategy` must be in `tp_config`.'
tp_config['layer_plan'] = build_tp_strategies(strategy, model)

# Parallelism config
parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}

# Optimizer
optimizer_name: str = train_cfg.optimizer.pop('name')
optimizer_cfg = train_cfg.optimizer
Expand Down Expand Up @@ -546,7 +572,7 @@ def train(cfg: DictConfig) -> Trainer:
precision=train_cfg.precision,
algorithms=algorithms,
device_train_microbatch_size=train_cfg.device_train_microbatch_size,
parallelism_config={'fsdp': fsdp_config},
parallelism_config=parallelism_config,
save_folder=train_cfg.save_folder,
save_filename=save_filename,
save_latest_filename=save_latest_filename,
Expand Down
22 changes: 22 additions & 0 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.distributed.tensor.parallel.style import ParallelStyle
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataloader
from torch.utils.data import Dataset
Expand Down Expand Up @@ -389,6 +390,26 @@
description=_save_planners_description,
)

_tp_strategies_description = (
"""The tp_strategies registry is used to register strategies for tensor parallelism.

Args:
model (ComposerModel): The model.

Returns:
layer_plan (Dict[str, ParallelStyle]): The plan used to parallelize the model.
model (ComposerModel): The model.
"""
)

tp_strategies = create_registry(
'llmfoundry',
'tp_strategies',
generic_type=Callable[[ComposerModel], dict[str, ParallelStyle]],
entry_points=True,
description=_tp_strategies_description,
)

__all__ = [
'loggers',
'callbacks',
Expand Down Expand Up @@ -416,4 +437,5 @@
'config_transforms',
'load_planners',
'save_planners',
'tp_strategies',
]
11 changes: 11 additions & 0 deletions llmfoundry/tp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.registry import tp_strategies
from llmfoundry.tp.ffn_tp_strategy import ffn_tp_strategy

tp_strategies.register('ffn', func=ffn_tp_strategy)

__all__ = [
'ffn_tp_strategy',
]
56 changes: 56 additions & 0 deletions llmfoundry/tp/ffn_tp_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.models import ComposerModel
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
)
from torch.distributed.tensor.parallel.style import ParallelStyle


def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
TP_LAYERS = {'ffn', 'ffn.up_proj', 'ffn.down_proj'}

# Validate that all TP_LAYERS are in model
tp_layers_in_model = {
layer for layer in TP_LAYERS for name, _ in model.named_modules()
if layer in name
}
if tp_layers_in_model != TP_LAYERS:
raise RuntimeError(
f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.',
)

# Generate layer plan
layer_plan: dict[str, ParallelStyle] = {}
for name, _ in model.named_modules():
# Before the ffn layer starts, distribute the input data for proper TP use
# Inputs are currently sharded across the batch dimension (dim 0) as is done in standard DDP
# Inputs will be replicated across hidden dimension (dim 1) via allgather
if name.split('.')[-1] == 'ffn':
layer_plan[name] = PrepareModuleInput(
input_layouts=Shard(0),
desired_input_layouts=Replicate(),
use_local_output=True,
)
# Shard the ffn.up_proj weight matrix across its columns
# Inputs are already replicated across each TP group
# Outputs will be sharded along the hidden dimension (dim 1) via allgather
elif name.split('.')[-2:] == ['ffn', 'up_proj']:
layer_plan[name] = ColwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(-1),
)
# Shard the ffn.down_proj weight matrix across its rows
# Inputs are sharded along the hidden dimension (dim 1)
# Outputs will be sharded along batch dimension (dim 0) via allreduce
elif name.split('.')[-2:] == ['ffn', 'down_proj']:
layer_plan[name] = RowwiseParallel(
input_layouts=Shard(-1),
output_layouts=Shard(0),
)

return layer_plan
16 changes: 16 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.distributed.tensor.parallel.style import ParallelStyle
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
Expand All @@ -37,6 +38,7 @@
)
from llmfoundry.utils.config_utils import to_dict_container, to_list_container
from llmfoundry.utils.registry_utils import construct_from_registry
from llmfoundry.utils.warnings import experimental_function

log = logging.getLogger(__name__)

Expand All @@ -52,6 +54,7 @@
'build_tokenizer',
'build_composer_model',
'build_metric',
'build_tp_strategies',
]


Expand Down Expand Up @@ -701,3 +704,16 @@ def _validate_cfg(icl_cfg: dict[str, Any]):
)

return evaluators, logger_keys


@experimental_function('tp_strategies')
def build_tp_strategies(
name: str,
model: ComposerModel,
) -> dict[str, ParallelStyle]:
return construct_from_registry(
name=name,
registry=registry.tp_strategies,
partial_function=False,
kwargs={'model': model},
)
14 changes: 13 additions & 1 deletion llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class TrainConfig:
# Distributed training parameters
dist_timeout: Union[int, float] = 600.0
fsdp_config: Optional[dict[str, Any]] = None
tp_config: Optional[dict[str, Any]] = None

# Evaluation parameters
eval_interval: Union[int, str] = 1
Expand Down Expand Up @@ -501,7 +502,11 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]:
return cfg


def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
def process_init_device(
model_cfg: dict[str, Any],
fsdp_config: Optional[dict] = None,
tp_config: Optional[dict] = None,
):
# Restrict model init_device to 'meta' and 'cpu',
# using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors
# when multiple GPUs are available.
Expand Down Expand Up @@ -533,6 +538,13 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
# Set defaults for mixed initialization
fsdp_config.setdefault('load_monolith_rank0_only', True)

# Check we are not using tensor parallelism with MoEs
if tp_config is not None and 'ffn_config' in model_cfg and model_cfg[
'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
raise ValueError(
'Tensor Parallelism is not currently supported for MoE models.',
)

# Set ffn_config.device_mesh using fsdp_config
if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg[
'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
Expand Down
Loading
Loading