float8nocompile: add e2e fsdp test #1523

Open · wants to merge 11 commits into base: gh/danielvegamyhre/18/head
Changes from 8 commits
3 changes: 0 additions & 3 deletions torchao/prototype/float8nocompile/.gitignore

This file was deleted.

97 changes: 97 additions & 0 deletions torchao/prototype/float8nocompile/test/fsdp_test.py
@@ -0,0 +1,97 @@
#######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py
#
#######################################################################
import os

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
Contributor:
This is FSDP1. Can we test FSDP2, which is the new, recently released version of FSDP with support for float8 all-gather?
https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
IMO testing just FSDP2 is fine.

Contributor Author:
Updated to test with fsdp2.
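
For context on the suggested change, below is a minimal sketch of what wrapping with the FSDP2 fully_shard API could look like. The import path, the apply_fsdp2 helper name, and the per-layer sharding pattern are assumptions based on the FSDP2 API in recent PyTorch releases; they are not part of the diff shown here.

# Hypothetical FSDP2 variant of the FSDP wrapping step in the test below (sketch, not part of this diff).
# In PyTorch 2.5, fully_shard lives under torch.distributed._composable.fsdp;
# newer releases also expose it as torch.distributed.fsdp.fully_shard.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def apply_fsdp2(model: nn.Module) -> nn.Module:
    # Shard each layer individually, then the root module, so each layer's
    # parameters form their own all-gather group.
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    return model


# In the test, this would replace `model1 = FSDP(model1)` / `model2 = FSDP(model2)`:
# model1 = apply_fsdp2(model1)
# model2 = apply_fsdp2(model2)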


from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2048, 4096, bias=False),
            nn.Linear(4096, 16, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


@pytest.fixture
def model1():
    torch.manual_seed(0)
    return TestModel()


@pytest.fixture
def model2():
    torch.manual_seed(0)
    return TestModel()


def test_model_weights_and_gradients(model1, model2):
    assert torch.cuda.is_available()
    device = torch.device("cuda")

    setup_distributed()

    model1 = model1.to(torch.bfloat16).to(device)
    model2 = model2.to(torch.bfloat16).to(device)

    # compare production float8 linear conversion with no-compile version
    convert_to_float8_training(model2)
    convert_to_float8_nocompile_training(model1)

    # distributed training with FSDP
    model1 = FSDP(model1)
    model2 = FSDP(model2)

    input_tensor = torch.randn(
        16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device
    )
    input_copy1 = input_tensor.clone().detach().requires_grad_(True)
    input_copy2 = input_tensor.clone().detach().requires_grad_(True)

    loss_fn = nn.MSELoss()

    output1 = model1(input_copy1)
    output2 = model2(input_copy2)

    loss1 = loss_fn(output1, torch.zeros_like(output1))
    loss2 = loss_fn(output2, torch.zeros_like(output2))

    loss1.backward()
    loss2.backward()

    dist.destroy_process_group()

    # compare the outputs, weight gradients, and input gradients
    assert torch.allclose(output1, output2, atol=0, rtol=0)
    assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0)
    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        assert torch.allclose(param1.grad, param2.grad, atol=0, rtol=0)