-
Notifications
You must be signed in to change notification settings - Fork 196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
float8nocompile: add e2e fsdp test #1523
Open
danielvegamyhre
wants to merge
11
commits into
gh/danielvegamyhre/18/head
Choose a base branch
from
gh/danielvegamyhre/19/head
base: gh/danielvegamyhre/18/head
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+97
−3
Open
Changes from 8 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
2145e47
Update
danielvegamyhre 88797b3
Update
danielvegamyhre 89c0d5a
Update
danielvegamyhre d29176e
Update
danielvegamyhre f0bca8c
Update
danielvegamyhre aa50a54
Update
danielvegamyhre 2db5deb
Update
danielvegamyhre 077e8bd
Update
danielvegamyhre 5081694
Update
danielvegamyhre 2a139aa
Update
danielvegamyhre a0544db
Update
danielvegamyhre File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
###################################################################### | ||
# | ||
# To run these unit tests, use the following command: | ||
# | ||
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py | ||
# | ||
####################################################################### | ||
import os | ||
|
||
import pytest | ||
import torch | ||
import torch.distributed as dist | ||
import torch.nn as nn | ||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
|
||
from torchao.float8.float8_linear_utils import convert_to_float8_training | ||
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( | ||
convert_to_float8_nocompile_training, | ||
) | ||
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 | ||
|
||
if not TORCH_VERSION_AT_LEAST_2_5: | ||
raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") | ||
|
||
|
||
class TestModel(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layers = nn.Sequential( | ||
nn.Linear(2048, 4096, bias=False), | ||
nn.Linear(4096, 16, bias=False), | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
return self.layers(x) | ||
|
||
|
||
def setup_distributed(): | ||
rank = int(os.environ["RANK"]) | ||
world_size = int(os.environ["WORLD_SIZE"]) | ||
dist.init_process_group("nccl", rank=rank, world_size=world_size) | ||
torch.cuda.set_device(rank) | ||
|
||
|
||
@pytest.fixture | ||
def model1(): | ||
torch.manual_seed(0) | ||
return TestModel() | ||
|
||
|
||
@pytest.fixture | ||
def model2(): | ||
torch.manual_seed(0) | ||
return TestModel() | ||
|
||
|
||
def test_model_weights_and_gradients(model1, model2): | ||
assert torch.cuda.is_available() | ||
device = torch.device("cuda") | ||
|
||
setup_distributed() | ||
|
||
model1 = model1.to(torch.bfloat16).to(device) | ||
model2 = model2.to(torch.bfloat16).to(device) | ||
|
||
# compare production float8 linear conversion with no-compile version | ||
convert_to_float8_training(model2) | ||
convert_to_float8_nocompile_training(model1) | ||
|
||
# distributed training with FSDP | ||
model1 = FSDP(model1) | ||
model2 = FSDP(model2) | ||
|
||
input_tensor = torch.randn( | ||
16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device | ||
) | ||
input_copy1 = input_tensor.clone().detach().requires_grad_(True) | ||
input_copy2 = input_tensor.clone().detach().requires_grad_(True) | ||
|
||
loss_fn = nn.MSELoss() | ||
|
||
output1 = model1(input_copy1) | ||
output2 = model2(input_copy2) | ||
|
||
loss1 = loss_fn(output1, torch.zeros_like(output1)) | ||
loss2 = loss_fn(output2, torch.zeros_like(output2)) | ||
|
||
loss1.backward() | ||
loss2.backward() | ||
|
||
dist.destroy_process_group() | ||
|
||
# compare the outputs, weight gradients, and input gradients | ||
assert torch.allclose(output1, output2, atol=0, rtol=0) | ||
assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0) | ||
for param1, param2 in zip(model1.parameters(), model2.parameters()): | ||
assert torch.allclose(param1.grad, param2.grad, atol=0, rtol=0) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is FSDP1. Can we test FSDP2, which is the new, recently released version of FSDP with support for float8 all-gather?
https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
IMO testing just FSDP2 is fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to test with fsdp2