-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integration with DCP #978
base: unflatten
Are you sure you want to change the base?
Integration with DCP #978
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates | ||
import torch | ||
from pippy import annotate_split_points, Pipe, SplitPoint | ||
import torch.distributed.checkpoint as dcp | ||
import tempfile | ||
|
||
|
||
d_hid = 16 | ||
|
@@ -66,6 +68,49 @@ def get_layers(module): | |
return layers | ||
|
||
|
||
def pipe_to_sd(pipe): | ||
sd = {} | ||
for stage_idx in range(pipe.num_stages): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something a little fishy about this proposal (equally so for both option 1 and 2) is that it's not likely you'd want to iterate all the stages in the pipe and load/save them. Example 1: simple pipeline with 4 gpus |
||
stage_mod = pipe.get_stage_module(stage_idx) | ||
sd[f"stage_{stage_idx}"] = stage_mod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not really clear to me why we need to add a prefix at all.
There should be no duplication of fqns between submods/stages. what are we doing about the 'submod_0' part in the fqn? when we do If the former, can't we just save/load the keys as usual? If the latter, we can still save/load without a prefix of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Former. @wconstab |
||
return sd | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
#Simulate saving the pipe | ||
# Option 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Option 1 would be more likely used than Option 2 in realistic setting. Could you please uncomment this block of code? |
||
# for stage_idx in range(pipe.num_stages): | ||
# print(f"Saving pipeline stage {stage_idx}") | ||
# stage_mod = pipe.get_stage_module(stage_idx) | ||
# dcp.save( | ||
# {f"stage_{stage_idx}": stage_mod}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious, is the dict required by API of DCP? Can a user directly save There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does this matter? i think the DCP api had reasons for interfacing with dict instead of model, adding a new variant that takes model and gets its dict should be possible, but i think it's clearer this way that the only part of the model that gets saved is the dict There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to be clear: I like saving the state dict too (instead of the module). That's more composable to me. |
||
# checkpoint_id=f"{tmpdir}_{stage_idx}" | ||
# ) | ||
# Option 2: | ||
sd = pipe_to_sd(pipe) | ||
dcp.save(state_dict, checkpoint_id=tmpdir) | ||
|
||
|
||
#Simulate loading the pipe | ||
# Option 1: | ||
# for stage_idx in range(pipe.num_stages): | ||
# print(f"Loading pipeline stage {stage_idx}") | ||
# stage_mod = pipe.get_stage_module(stage_idx) | ||
# dcp.load( | ||
# {f"stage_{stage_idx}": stage_mod}, | ||
# checkpoint_id=f"{tmpdir}_{stage_idx}" | ||
# ) | ||
|
||
#Option 2: | ||
new_pipe = Pipe.from_tracing( | ||
transformer, | ||
1, | ||
(x,), | ||
) | ||
sd = pipe_to_sd(new_pipe) | ||
dcp.load(sd, checkpoint_id=tmpdir) | ||
|
||
pipe = new_pipe | ||
|
||
# Collect all layers in pipe | ||
layers = [] | ||
for stage_idx in range(pipe.num_stages): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wz337 , might be interesting in dist state dict