-
Notifications
You must be signed in to change notification settings - Fork 537
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cb1ab31
commit 3921cda
Showing
1 changed file
with
91 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,124 +1,111 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from pathlib import Path | ||
from tempfile import TemporaryDirectory | ||
|
||
import pytest | ||
from omegaconf import OmegaConf as om | ||
from torch.distributed._tensor import Replicate, Shard | ||
from torch.distributed.tensor.parallel import ( | ||
ColwiseParallel, | ||
PrepareModuleInput, | ||
RowwiseParallel, | ||
ColwiseParallel, | ||
PrepareModuleInput, | ||
RowwiseParallel, | ||
) | ||
|
||
|
||
from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg | ||
from llmfoundry.command_utils.train import train | ||
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM | ||
from llmfoundry.utils.builders import build_tp_strategy | ||
|
||
|
||
from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg | ||
|
||
|
||
@pytest.mark.gpu | ||
def test_ffn_tp_strategy_layer_plan(): | ||
# Actual layer plan from tp_strategy=fnn | ||
tp_config = { | ||
'strategy': 'ffn', | ||
} | ||
|
||
|
||
model_cfg = { | ||
'name': 'mpt_causal_lm', | ||
'd_model': 128, | ||
'n_heads': 4, | ||
'n_layers': 3, | ||
'expansion_ratio': 1, | ||
'max_seq_len': 16, | ||
'vocab_size': 50368, | ||
} | ||
model = ComposerMPTCausalLM(**model_cfg) | ||
layer_plan = build_tp_strategy(tp_config['strategy'], model) | ||
|
||
|
||
# Expected layer plan | ||
_expected_layer_plan = { | ||
'ffn': | ||
PrepareModuleInput( | ||
input_layouts=Shard(0), | ||
desired_input_layouts=Replicate(), | ||
use_local_output=True, | ||
), | ||
'ffn.down_proj': | ||
RowwiseParallel( | ||
input_layouts=Shard(-1), | ||
output_layouts=Shard(0), | ||
), | ||
'ffn.up_proj': | ||
ColwiseParallel( | ||
input_layouts=Replicate(), | ||
output_layouts=Shard(-1), | ||
), | ||
} | ||
expected_layer_plan = { | ||
f'model.transformer.blocks.{layer_idx}.{name}': layer_plan | ||
for name, layer_plan in _expected_layer_plan.items() | ||
for layer_idx in range(model_cfg['n_layers']) | ||
} | ||
|
||
|
||
# Compare expected and actual layer plans | ||
for (n1, lp1), (n2, lp2) in zip( | ||
sorted(expected_layer_plan.items()), | ||
sorted(layer_plan.items()), | ||
): | ||
assert n1 == n2 | ||
assert type(lp1) == type(lp2) | ||
if isinstance( | ||
lp1, | ||
PrepareModuleInput, | ||
) and isinstance(lp2, PrepareModuleInput): | ||
assert lp1.input_layouts == lp2.input_layouts | ||
assert lp1.desired_input_layouts == lp2.desired_input_layouts | ||
assert lp1.use_local_output == lp2.use_local_output | ||
elif ( | ||
isinstance(lp1, ColwiseParallel) and | ||
isinstance(lp2, ColwiseParallel) | ||
) or ( | ||
isinstance(lp1, RowwiseParallel) and | ||
isinstance(lp2, RowwiseParallel) | ||
): | ||
assert lp1.input_layouts == lp2.input_layouts | ||
assert lp1.output_layouts == lp2.output_layouts | ||
assert lp1.use_local_output == lp2.use_local_output | ||
else: | ||
raise ValueError(f'Layer plan of wrong type: {type(layer_plan)}') | ||
|
||
|
||
# Actual layer plan from tp_strategy=fnn | ||
tp_config = { | ||
'strategy': 'ffn', | ||
} | ||
|
||
model_cfg = { | ||
'name': 'mpt_causal_lm', | ||
'd_model': 128, | ||
'n_heads': 4, | ||
'n_layers': 3, | ||
'expansion_ratio': 1, | ||
'max_seq_len': 16, | ||
'vocab_size': 50368, | ||
} | ||
model = ComposerMPTCausalLM(**model_cfg) | ||
layer_plan = build_tp_strategy(tp_config['strategy'], model) | ||
|
||
# Expected layer plan | ||
_expected_layer_plan = { | ||
'ffn': | ||
PrepareModuleInput( | ||
input_layouts=Shard(0), | ||
desired_input_layouts=Replicate(), | ||
use_local_output=True, | ||
), | ||
'ffn.down_proj': | ||
RowwiseParallel( | ||
input_layouts=Shard(-1), | ||
output_layouts=Shard(0), | ||
), | ||
'ffn.up_proj': | ||
ColwiseParallel( | ||
input_layouts=Replicate(), | ||
output_layouts=Shard(-1), | ||
), | ||
} | ||
expected_layer_plan = { | ||
f'model.transformer.blocks.{layer_idx}.{name}': layer_plan | ||
for name, layer_plan in _expected_layer_plan.items() | ||
for layer_idx in range(model_cfg['n_layers']) | ||
} | ||
|
||
# Compare expected and actual layer plans | ||
for (n1, lp1), (n2, lp2) in zip( | ||
sorted(expected_layer_plan.items()), | ||
sorted(layer_plan.items()), | ||
): | ||
assert n1 == n2 | ||
assert type(lp1) == type(lp2) | ||
if isinstance( | ||
lp1, | ||
PrepareModuleInput, | ||
) and isinstance(lp2, PrepareModuleInput): | ||
assert lp1.input_layouts == lp2.input_layouts | ||
assert lp1.desired_input_layouts == lp2.desired_input_layouts | ||
assert lp1.use_local_output == lp2.use_local_output | ||
elif ( | ||
isinstance(lp1, ColwiseParallel) and | ||
isinstance(lp2, ColwiseParallel) | ||
) or ( | ||
isinstance(lp1, RowwiseParallel) and | ||
isinstance(lp2, RowwiseParallel) | ||
): | ||
assert lp1.input_layouts == lp2.input_layouts | ||
assert lp1.output_layouts == lp2.output_layouts | ||
assert lp1.use_local_output == lp2.use_local_output | ||
else: | ||
raise ValueError(f'Layer plan of wrong type: {type(layer_plan)}') | ||
|
||
|
||
@pytest.mark.gpu | ||
# @pytest.mark.filterwarnings("error::") # treat warnings like errors | ||
def test_tp_one_gpu(): | ||
from icecream import ic | ||
# get train_cfg with tp | ||
train_cfg_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml' | ||
with open(train_cfg_path) as f: | ||
train_cfg = om.load(f) | ||
|
||
|
||
tmp_path = '/my-tmp/c4_small' | ||
dataset_name = create_c4_dataset_xxsmall(tmp_path) | ||
train_cfg = gpt_tiny_cfg(dataset_name, 'gpu') | ||
train_cfg.tp_config = {'strategy': 'ffn'} | ||
|
||
with pytest.warns(UserWarning, match='FSDP+TP is not applicable for single-GPU training. Reverting to DDP.'): | ||
train(train_cfg) | ||
|
||
|
||
|
||
|
||
|
||
|
||
# if __name__ == '__main__': | ||
# test_tp_one_gpu() | ||
def test_no_tp_with_one_gpu(): | ||
with TemporaryDirectory() as tmp_path: | ||
# train_cfg with ffn tensor parallelism | ||
train_cfg_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml' | ||
with open(train_cfg_path, 'r', encoding='utf-8') as f: | ||
train_cfg = om.load(f) | ||
dataset_name = create_c4_dataset_xxsmall(Path(tmp_path)) | ||
train_cfg = gpt_tiny_cfg(dataset_name, 'gpu') | ||
train_cfg.tp_config = {'strategy': 'ffn'} | ||
|
||
# Expect a warning that we use DDP and not FSDP-TP when we have one GPU. | ||
with pytest.warns( | ||
UserWarning, | ||
match= | ||
r'FSDP\+TP is not applicable for single-GPU training. Reverting to DDP.', | ||
): | ||
train(train_cfg) |