test_no_tp_with_one_gpu
eitanturok committed Sep 26, 2024
1 parent cb1ab31 commit 3921cda
Showing 1 changed file with 91 additions and 104 deletions.
195 changes: 91 additions & 104 deletions tests/models/utils/test_tp_strategy.py
@@ -1,124 +1,111 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
from omegaconf import OmegaConf as om
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)

from llmfoundry.command_utils.train import train
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import build_tp_strategy
from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg


@pytest.mark.gpu
def test_ffn_tp_strategy_layer_plan():
    # Actual layer plan from tp_strategy=ffn
    tp_config = {
        'strategy': 'ffn',
    }

    model_cfg = {
        'name': 'mpt_causal_lm',
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 3,
        'expansion_ratio': 1,
        'max_seq_len': 16,
        'vocab_size': 50368,
    }
    model = ComposerMPTCausalLM(**model_cfg)
    layer_plan = build_tp_strategy(tp_config['strategy'], model)

    # Expected layer plan
    _expected_layer_plan = {
        'ffn':
            PrepareModuleInput(
                input_layouts=Shard(0),
                desired_input_layouts=Replicate(),
                use_local_output=True,
            ),
        'ffn.down_proj':
            RowwiseParallel(
                input_layouts=Shard(-1),
                output_layouts=Shard(0),
            ),
        'ffn.up_proj':
            ColwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(-1),
            ),
    }
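    # The per-module plan above applies to every transformer block, so expand it
    # across all block indices to match the fully qualified module names returned
    # by build_tp_strategy.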
    expected_layer_plan = {
        f'model.transformer.blocks.{layer_idx}.{name}': layer_plan
        for name, layer_plan in _expected_layer_plan.items()
        for layer_idx in range(model_cfg['n_layers'])
    }

    # Compare expected and actual layer plans
    for (n1, lp1), (n2, lp2) in zip(
        sorted(expected_layer_plan.items()),
        sorted(layer_plan.items()),
    ):
        assert n1 == n2
        assert type(lp1) == type(lp2)
        if isinstance(
            lp1,
            PrepareModuleInput,
        ) and isinstance(lp2, PrepareModuleInput):
            assert lp1.input_layouts == lp2.input_layouts
            assert lp1.desired_input_layouts == lp2.desired_input_layouts
            assert lp1.use_local_output == lp2.use_local_output
        elif (
            isinstance(lp1, ColwiseParallel) and
            isinstance(lp2, ColwiseParallel)
        ) or (
            isinstance(lp1, RowwiseParallel) and
            isinstance(lp2, RowwiseParallel)
        ):
            assert lp1.input_layouts == lp2.input_layouts
            assert lp1.output_layouts == lp2.output_layouts
            assert lp1.use_local_output == lp2.use_local_output
        else:
            raise ValueError(f'Layer plan of wrong type: {type(lp1)}')


@pytest.mark.gpu
def test_no_tp_with_one_gpu():
    with TemporaryDirectory() as tmp_path:
        # train_cfg with ffn tensor parallelism
        train_cfg_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'
        with open(train_cfg_path, 'r', encoding='utf-8') as f:
            train_cfg = om.load(f)
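        # Create a tiny C4 dataset and a matching tiny train config, then enable
        # the ffn tensor-parallelism strategy.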
        dataset_name = create_c4_dataset_xxsmall(Path(tmp_path))
        train_cfg = gpt_tiny_cfg(dataset_name, 'gpu')
        train_cfg.tp_config = {'strategy': 'ffn'}

        # Expect a warning that we use DDP and not FSDP-TP when we have one GPU.
        with pytest.warns(
            UserWarning,
            match=
            r'FSDP\+TP is not applicable for single-GPU training. Reverting to DDP.',
        ):
            train(train_cfg)
