Commit 886caa8

ghstack-source-id: da38b4a141de7dfee4cec9132967ad76d7d6dc20
ghstack-comment-id: 2576459235
Pull Request resolved: #1523
1 parent: b82e72c
Showing 2 changed files with 97 additions and 3 deletions.
The first of the two changed files was deleted; its contents are not shown. The second file is new and is shown in full below.
@@ -0,0 +1,97 @@
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py
#
######################################################################
import os

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2048, 4096, bias=False),
            nn.Linear(4096, 16, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


@pytest.fixture
def model1():
    torch.manual_seed(0)
    return TestModel()


@pytest.fixture
def model2():
    torch.manual_seed(0)
    return TestModel()


def test_model_weights_and_gradients(model1, model2):
    assert torch.cuda.is_available()
    device = torch.device("cuda")

    setup_distributed()

    model1 = model1.to(torch.bfloat16).to(device)
    model2 = model2.to(torch.bfloat16).to(device)

    # compare production float8 linear conversion with no-compile version
    convert_to_float8_training(model2)
    convert_to_float8_nocompile_training(model1)

    # distributed training with FSDP
    model1 = FSDP(model1)
    model2 = FSDP(model2)

    input_tensor = torch.randn(
        16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device
    )
    input_copy1 = input_tensor.clone().detach().requires_grad_(True)
    input_copy2 = input_tensor.clone().detach().requires_grad_(True)

    loss_fn = nn.MSELoss()

    output1 = model1(input_copy1)
    output2 = model2(input_copy2)

    loss1 = loss_fn(output1, torch.zeros_like(output1))
    loss2 = loss_fn(output2, torch.zeros_like(output2))

    loss1.backward()
    loss2.backward()

    dist.destroy_process_group()

    # compare the outputs, weight gradients, and input gradients
    assert torch.allclose(output1, output2, atol=0, rtol=0)
    assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0)
    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        assert torch.allclose(param1.grad, param2.grad, atol=0, rtol=0)
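
For reference, the comparison at the heart of the test does not require FSDP or a torchrun launch. The following is a minimal single-process sketch of the same pattern, assuming one CUDA device and the same torchao APIs imported above; the helper name compare_float8_conversions is illustrative and is not part of this commit.

# Minimal single-process sketch (not part of the commit): compare the
# production float8 conversion against the no-compile prototype on one GPU.
import torch
import torch.nn as nn

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)


def compare_float8_conversions():
    # Build two bf16 copies of the same two-layer model from the same seed.
    torch.manual_seed(0)
    model_nocompile = nn.Sequential(
        nn.Linear(2048, 4096, bias=False),
        nn.Linear(4096, 16, bias=False),
    ).to(torch.bfloat16).cuda()
    torch.manual_seed(0)
    model_production = nn.Sequential(
        nn.Linear(2048, 4096, bias=False),
        nn.Linear(4096, 16, bias=False),
    ).to(torch.bfloat16).cuda()

    # Convert one copy with each float8 conversion path (both mutate the
    # module in place, as in the test above).
    convert_to_float8_nocompile_training(model_nocompile)
    convert_to_float8_training(model_production)

    x = torch.randn(16, 2048, dtype=torch.bfloat16, device="cuda")
    out_nocompile = model_nocompile(x)
    out_production = model_production(x)

    # The two conversion paths are expected to match exactly, mirroring the
    # zero-tolerance checks in the test.
    assert torch.allclose(out_nocompile, out_production, atol=0, rtol=0)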