
Commit

add hooks
tohtana committed Sep 6, 2024
1 parent d7e7f9d commit 35df07d
Showing 3 changed files with 91 additions and 12 deletions.
67 changes: 61 additions & 6 deletions csrc/compile/native_z3.cpp
@@ -63,9 +63,37 @@ class DSParamRegistry {
const at::Tensor& getGatheredParam(long ds_id) const { return gathered_params_.at(ds_id); }
bool hasGatheredParam(long ds_id) const { return gathered_params_.count(ds_id) > 0; }

void registerOpNArgs(const std::string& op_name, long n_args)
{
op_n_args_.emplace(op_name, n_args);
}

void resetArgCounter(const std::string& op_name)
{
assert(op_n_args_.count(op_name) > 0);
assert(args_counter_.count(op_name) == 0);
args_counter_.emplace(op_name, op_n_args_.at(op_name));
}

void decrementArgCounter(const std::string& op_name)
{
assert(args_counter_.count(op_name) > 0);
if (args_counter_.at(op_name) == 0) return;  // never decrement below zero
args_counter_[op_name]--;
}

bool isArgCounterZero(const std::string& op_name) const
{
assert(args_counter_.count(op_name) > 0);
return args_counter_.at(op_name) == 0;
}

private:
std::unordered_map<long, DSParam> params_;
std::unordered_map<long, at::Tensor> gathered_params_;
std::unordered_map<long, at::Tensor> allgather_handles_;
std::unordered_map<std::string, long> op_n_args_;
std::unordered_map<std::string, long> args_counter_;
};

static DSParamRegistry registry = DSParamRegistry();
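
For reference, the counting scheme these methods implement is: register the number of tensor arguments for each op once, reset the counter when the op's gathers are issued, decrement it as each argument becomes ready, and treat zero as "all arguments gathered". A minimal Python sketch of the same idea (the class and method names below are illustrative only, not part of the extension):

class OpArgCounter:
    """Illustrative Python analog of the per-op argument counter in DSParamRegistry."""

    def __init__(self):
        self.n_args = {}   # op name -> number of tensor args registered up front
        self.counter = {}  # op name -> args still waiting to be gathered

    def register_op_n_args(self, op_name, n_args):
        self.n_args[op_name] = n_args

    def reset(self, op_name):
        self.counter[op_name] = self.n_args[op_name]

    def decrement(self, op_name):
        if self.counter[op_name] > 0:  # never go below zero
            self.counter[op_name] -= 1

    def is_zero(self, op_name):
        return self.counter[op_name] == 0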
@@ -94,10 +122,31 @@ void register_param(long ds_id,
registry.registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, persistent);
}

void set_process_group(c10::intrusive_ptr<c10d::ProcessGroup> pg)
{
std::cout << "set_process_group rank=" << pg->getRank() << std::endl;
process_group = pg;
}
void register_op_n_args(const std::string& op_name, long n_args)
{
registry.registerOpNArgs(op_name, n_args);
}

void set_process_group(c10::intrusive_ptr<c10d::ProcessGroup> pg) { process_group = pg; }

void start_forward()
{
// std::cout << "start_forward" << std::endl;
}

void end_forward()
{
// std::cout << "end_forward" << std::endl;
}

void start_backward(bool update)
{
// std::cout << "start_backward update=" << update << std::endl;
}

void end_backward(bool update)
{
// unused
}

at::Tensor allgather_param(at::Tensor param_tensor, long ds_id)
@@ -131,9 +180,10 @@ at::Tensor release_param(at::Tensor v, long ds_id)
return v;
}

at::Tensor wait_allgather(at::Tensor v, long ds_id, long n_args)
at::Tensor wait_allgather(at::Tensor v, long ds_id, const std::string& user, long n_args)
{
// std::cout << "wait_allgather ds_id=" << ds_id << " n_args=" << n_args << std::endl;
// std::cout << "wait_allgather ds_id=" << ds_id << " user=" << user << " n_args=" << n_args
// << std::endl;

return v;
}
@@ -165,7 +215,7 @@ TORCH_LIBRARY(native_z3, m)
m.def("test_call(Tensor a) -> Tensor");
m.def("allgather_param(Tensor a, int id) -> Tensor");
m.def("release_param(Tensor a, int id) -> Tensor");
m.def("wait_allgather(Tensor a, int id, int n_args) -> Tensor");
m.def("wait_allgather(Tensor a, int id, str user, int n_args) -> Tensor");
m.def("reduce_grad(Tensor a, int id) -> Tensor");
}
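
Once the extension is built and loaded (as engine.py does below via NativeZ3Builder().load()), these schemas become callable from Python as torch.ops.native_z3.*. A rough usage sketch, assuming the op implementations are registered for the current device; the tensor, ds_id and user name are made up for illustration:

import torch
from deepspeed.ops.op_builder import NativeZ3Builder

NativeZ3Builder().load()  # registers the native_z3 op namespace

t = torch.ones(4)
# matches the schema "wait_allgather(Tensor a, int id, str user, int n_args)"
out = torch.ops.native_z3.wait_allgather(t, 0, "some_op", 1)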

@@ -192,4 +242,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("test_call", &test_call, "Test function");
m.def("register_param", &register_param, "Register a parameter");
m.def("set_process_group", &set_process_group, "Set the process group");
m.def("register_op_n_args", &register_op_n_args, "Register the number of arguments for an op");
m.def("start_forward", &start_forward, "Start forward pass");
m.def("end_forward", &end_forward, "End forward pass");
m.def("start_backward", &start_backward, "Start backward pass");
m.def("end_backward", &end_backward, "End backward pass");
}
25 changes: 22 additions & 3 deletions deepspeed/runtime/engine.py
@@ -1897,8 +1897,27 @@ def forward(self, *inputs, **kwargs):
if self.fp16_auto_cast():
inputs = self._cast_inputs_half(inputs)

if hasattr(self, "nz3"):
self.nz3.start_forward()

loss = self.module(*inputs, **kwargs)

if hasattr(self, "nz3"):
self.nz3.end_forward()

def bwd_hook(grad):
self.nz3.start_backward(self.is_gradient_accumulation_boundary())
return grad

def set_hook(v):
if torch.is_tensor(v) and v.grad_fn is not None:
v.register_hook(bwd_hook)
return v

# `loss` can be any nested structure
from torch.utils._pytree import tree_map
loss = tree_map(set_hook, loss)
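
The hook wiring above uses only standard PyTorch APIs; a self-contained sketch of the same pattern, with a print in place of nz3.start_backward(...):

import torch
from torch.utils._pytree import tree_map

def bwd_hook(grad):
    print("backward has reached the model outputs")  # stand-in for nz3.start_backward(...)
    return grad

def set_hook(v):
    # attach the hook only to tensors that are part of the autograd graph
    if torch.is_tensor(v) and v.grad_fn is not None:
        v.register_hook(bwd_hook)
    return v

x = torch.randn(2, 3, requires_grad=True)
outputs = {"loss": (x * 2).sum(), "aux": torch.zeros(1)}  # any nested structure, like `loss` above
outputs = tree_map(set_hook, outputs)
outputs["loss"].backward()  # bwd_hook fires as soon as the gradient for the output is produced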

if self.zero_optimization_partition_weights() and not self.is_compiled:
# Disable automated discovery of external parameters
for module in self.module.modules():
@@ -3684,11 +3703,11 @@ def compile(self,
assert self.zero_optimization_stage() == ZeroStageEnum.weights, "Only stage 3 is supported for this schedule"

from deepspeed.ops.op_builder import NativeZ3Builder
nz3 = NativeZ3Builder().load()
nz3.set_process_group(self.data_parallel_group)
self.nz3 = NativeZ3Builder().load()
self.nz3.set_process_group(self.data_parallel_group)
for p in self.module.parameters():
grad_buffer = self.optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]
nz3.register_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)
self.nz3.register_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)

for m in self.module.modules():
m._parameters = m._original_parameters
11 changes: 8 additions & 3 deletions deepspeed/runtime/zero/compile/stage3_backend.py
@@ -27,6 +27,7 @@
gathered_params = {}
param_map = {}
z3_optimizer = None
nz3 = None


def add_allgather(graph: Graph, node: Node, ds_id: int):
@@ -45,11 +46,11 @@ def add_release(graph: Graph, node: Node, release_node: Node, ds_id: int):
name=f"release_ds_param_{release_node.target}_{ds_id}")


def add_wait_allgather(graph: Graph, node: Node, ds_id: int, n_args: int):
def add_wait_allgather(graph: Graph, node: Node, ds_id: int, user: str, n_args: int):
return add_postprocess(graph,
node,
torch.ops.native_z3.wait_allgather,
extra_args=[ds_id, n_args],
extra_args=[ds_id, user, n_args],
name=f"wait_allgather_ds_param_{ds_id}")


@@ -89,9 +90,10 @@ def add_gather_and_release(gm: GraphModule, param_nodes: List[Node], ds_ids: Dic
]
for user in user_nodes[pn]:
n_node_args = len([arg for arg in user.args if isinstance(arg, Node)])
nz3.register_op_n_args(user.name, n_node_args)
for arg in user.args:
if isinstance(arg, Node):
add_wait_allgather(graph, arg, ds_ids[pn.name], n_node_args)
add_wait_allgather(graph, arg, ds_ids[pn.name], user.name, n_node_args)

return allgather_nodes, release_nodes
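
add_postprocess itself is not part of this diff; the sketch below is a guess, based on the call sites above, at how such a helper can be written with torch.fx:

from typing import Callable, List, Optional
from torch.fx import Graph, Node

def add_postprocess(graph: Graph, node: Node, fn: Callable,
                    extra_args: Optional[List] = None, name: Optional[str] = None) -> Node:
    # insert fn(node, *extra_args) right after `node`
    with graph.inserting_after(node):
        new_node = graph.create_node("call_function", fn,
                                     args=(node, *(extra_args or [])), name=name)
    # reroute every other consumer of `node` to the new node,
    # keeping new_node's own reference to `node` intact
    node.replace_all_uses_with(new_node, delete_user_cb=lambda user: user is not new_node)
    return new_node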

@@ -123,6 +125,9 @@ def dump_graph(graph: GraphModule, name: str, skip=False):


def make_stage3_backend(dump_graphs=False):
from deepspeed.ops.op_builder import NativeZ3Builder
global nz3
nz3 = NativeZ3Builder().load()

def stage3_backend(gm: GraphModule, sample_inputs):
# n_params = len(list(gm.named_parameters()))
