fix to run multi allgather inputs
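
Lets a single graph node consume several all-gathered parameters: `wait_allgather` now takes a list of DS parameter ids instead of a single id (the `len(ag_args) == 1` assertion in `register_and_add_wait_allgather` is removed), and `release_param` moves from the `TORCH_LIBRARY` custom-op registration to the pybind11 module, with the FX pass inserting a small Python wrapper (`wrap_release_ds_param`) that calls it.
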
tohtana committed Dec 18, 2024
1 parent 3355088 commit eda6d8e
Showing 5 changed files with 39 additions and 34 deletions.
33 changes: 13 additions & 20 deletions csrc/compile/native_z3.cpp
@@ -469,7 +469,7 @@ class CustomOpExecutor {
        for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
    }

    at::Tensor releaseParam(at::Tensor v, long ds_id)
    void releaseParam(long ds_id)
    {
        const DSParam& param = param_registry_->getParam(ds_id);

@@ -484,12 +484,10 @@ class CustomOpExecutor {

            param_registry_->unregisterGatheredParam(ds_id);
        }

        return v;
    }

    at::Tensor waitAllgather(at::Tensor v,
                             long ds_id,
                             const std::vector<long>& ds_ids,
                             const std::string& user,
                             long n_args,
                             bool is_backward)
@@ -499,8 +497,10 @@
        op_states.decrementArgCounter(user);

        if (op_states.isArgCounterZero(user)) {
            assert(hasKey(ag_comm_done_events_, ds_id));
            ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
            for (long ds_id : ds_ids) {
                assert(hasKey(ag_comm_done_events_, ds_id));
                ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
            }
            op_states_fwd_.resetArgCounter();
        }

@@ -945,27 +945,22 @@ at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_
    return output_buf;
}

at::Tensor release_param(at::Tensor v, long graph_id, long ds_id)
{
    return executors[graph_id]->releaseParam(v, ds_id);
}

at::Tensor release_param_meta(at::Tensor v, long graph_id, long ds_id) { return v; }
void release_param(long graph_id, long ds_id) { executors[graph_id]->releaseParam(ds_id); }

at::Tensor wait_allgather(at::Tensor v,
                          long graph_id,
                          long ds_id,
                          const std::vector<long>& ds_ids,
                          const std::string& user,
                          long n_args,
                          bool is_backward)
{
    executors[graph_id]->waitAllgather(v, ds_id, user, n_args, is_backward);
    executors[graph_id]->waitAllgather(v, ds_ids, user, n_args, is_backward);
    return v;
}

at::Tensor wait_allgather_meta(at::Tensor v,
                               long graph_id,
                               long ds_id,
                               const std::vector<long>& ds_ids,
                               const std::string& user,
                               long n_args,
                               bool is_backward)
@@ -1057,9 +1052,9 @@ TORCH_LIBRARY(native_z3, m)
{
m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
m.def("release_param(Tensor a, int graph_id, int id) -> Tensor");
m.def(
"wait_allgather(Tensor a, int graph_id, int id, str user, int n_args, bool bwd) -> Tensor");
"wait_allgather(Tensor a, int graph_id, int[] ids, str user, int n_args, bool bwd) -> "
"Tensor");
m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
m.def("free_tensors(Tensor[] a) -> ()");
m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
@@ -1074,7 +1069,6 @@ TORCH_LIBRARY_IMPL(native_z3, CPU, m)
{
m.impl("allgather_param", &n3z::allgather_param);
m.impl("prefetch_params_fused", &n3z::prefetch_params_fused);
m.impl("release_param", &n3z::release_param);
m.impl("wait_allgather", &n3z::wait_allgather);
m.impl("reduce_grad", &n3z::reduce_grad);
m.impl("free_tensors", &n3z::free_tensors);
@@ -1090,7 +1084,6 @@ TORCH_LIBRARY_IMPL(native_z3, CUDA, m)
{
m.impl("allgather_param", &n3z::allgather_param);
m.impl("prefetch_params_fused", &n3z::prefetch_params_fused);
m.impl("release_param", &n3z::release_param);
m.impl("wait_allgather", &n3z::wait_allgather);
m.impl("reduce_grad", &n3z::reduce_grad);
m.impl("free_tensors", &n3z::free_tensors);
@@ -1105,7 +1098,6 @@ TORCH_LIBRARY_IMPL(native_z3, CUDA, m)
TORCH_LIBRARY_IMPL(native_z3, Meta, m)
{
m.impl("allgather_param", &n3z::allgather_param_meta);
m.impl("release_param", &n3z::release_param_meta);
m.impl("wait_allgather", &n3z::wait_allgather_meta);
m.impl("reduce_grad", &n3z::reduce_grad_meta);
}
@@ -1129,6 +1121,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("end_forward", &n3z::end_forward, "End forward pass");
m.def("start_backward", &n3z::start_backward, "Start backward pass");
// m.def("end_backward", &n3z::end_backward, "End backward pass");
m.def("release_param", &n3z::release_param, "Release a parameter");
m.def("reset", &n3z::reset, "Reset the state");
m.def(
"invalidate_gathered_param", &n3z::invalidate_gathered_param, "Invalidate gathered param");
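
The net effect in this file: once a consumer's argument counter reaches zero, `waitAllgather` now blocks the current CUDA stream on the allgather-done event of every id in `ds_ids` rather than on a single event, and `releaseParam` no longer takes or returns a tensor and is exposed through the pybind11 module (`release_param(graph_id, ds_id)`) instead of being registered as a dispatcher op. A minimal PyTorch-level sketch of the same record-then-wait pattern, purely for illustration (not the extension's code; requires a CUDA device, and the ids and events are made up):

```python
import torch

assert torch.cuda.is_available()

ag_stream = torch.cuda.Stream()  # side stream where allgather kernels run
done_events = {7: torch.cuda.Event(), 8: torch.cuda.Event()}  # ds_id -> "allgather done" event

with torch.cuda.stream(ag_stream):
    # ... allgather kernels for params 7 and 8 would be enqueued here ...
    for ev in done_events.values():
        ev.record(ag_stream)

# The consumer now waits on *all* of its gathered params (the old code assumed exactly one).
ds_ids = [7, 8]
for ds_id in ds_ids:
    done_events[ds_id].wait(torch.cuda.current_stream())
```
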
25 changes: 16 additions & 9 deletions deepspeed/runtime/zero/compile/fx.py
@@ -9,6 +9,7 @@
import torch
from torch.fx import Node, Graph

from deepspeed.ops.op_builder import NativeZ3Builder
from .util import get_last_uses


@@ -110,21 +111,28 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):


def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int):

    nz3 = NativeZ3Builder().load()

    def wrap_release_ds_param(x: Any, graph_id: int, ds_id: int):
        nz3.release_param(graph_id, ds_id)
        return x

    add_postprocess(graph,
                    node,
                    torch.ops.native_z3.release_param,
                    wrap_release_ds_param,
                    extra_args=[graph_id, ds_id],
                    name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
                    meta=_make_node_meta(node, ds_id, False))


def add_wait_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, user: str, n_args: int, bwd: bool):
def add_wait_allgather(graph_id: int, graph: Graph, node: Node, ds_ids: List[int], user: str, n_args: int, bwd: bool):
    add_args_process(graph,
                     node,
                     torch.ops.native_z3.wait_allgather,
                     extra_args=[graph_id, ds_id, user, n_args, bwd],
                     name=f"wait_allgather_ds_param_{ds_id}",
                     meta=_make_node_meta(node, ds_id, False))
                     extra_args=[graph_id, ds_ids, user, n_args, bwd],
                     name=f"wait_allgather_ds_param_{'_'.join([str(ds_id) for ds_id in ds_ids])}",
                     meta=_make_node_meta(node, ds_ids, False))


def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
@@ -149,12 +157,11 @@ def register_and_add_wait_allgather(graph_id: int, graph: Graph, bwd: bool):
        if node.target in ops_no_wait:
            continue

        assert len(ag_args) == 1, f"Node {node.name} takes multiple allgathered params"
        ag_wait_nodes.append(node)

        ds_id = ag_args[0].meta["ds_id"]
        add_wait_allgather(graph_id, graph, node, ds_id, node.name, len(node.args), bwd)
        ds_ids.append(ds_id)
        ds_ids = [a.meta["ds_id"] for a in ag_args]
        add_wait_allgather(graph_id, graph, node, ds_ids, node.name, len(node.args), bwd)
        ds_ids.extend(ds_ids)

    return ds_ids, ag_wait_nodes

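With the `len(ag_args) == 1` assertion gone, `register_and_add_wait_allgather` collects the `ds_id` of every all-gathered argument of a user node and emits one `wait_allgather` covering the whole list. A rough sketch of that collection step over FX node metadata (hypothetical helper, assuming each allgather node carries `meta["ds_id"]` as in the pass above):

```python
from typing import List, Set

from torch import fx


def collect_wait_ids(node: fx.Node, allgather_nodes: Set[fx.Node]) -> List[int]:
    # Instead of requiring exactly one gathered parameter per node, pick up the
    # ds_id of every all-gathered argument so a single wait_allgather node can
    # synchronize all of them.
    ag_args = [a for a in node.args if isinstance(a, fx.Node) and a in allgather_nodes]
    return [a.meta["ds_id"] for a in ag_args]
```
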
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/compile/list_schedule.py
@@ -13,7 +13,7 @@
from torch.fx.node import map_arg
from torch.utils._pytree import tree_iter

from .util import get_last_uses
from .util import get_last_uses, is_release_node
from .fx import get_output_node


@@ -281,7 +281,7 @@ def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug
graph)

    unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.native_z3.allgather_param]
    release_nodes = {n.args[2]: n for n in unscheduled if n.target == torch.ops.native_z3.release_param}
    release_nodes = {n.args[2]: n for n in unscheduled if is_release_node(n)}

    ag_nodes_in_path = {}
    for ag_node in unscheduled_ags:
7 changes: 4 additions & 3 deletions deepspeed/runtime/zero/compile/profilers/graph_profile.py
@@ -22,7 +22,7 @@

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from ..util import is_comm_op
from ..util import is_comm_op, is_release_node


def _all_real_if_tensor(args):
@@ -132,7 +132,8 @@ def rebuild_param_if_necessary(v):
n.meta["max_memory"] = max_mem
n.meta["tensor_size"] = tensor_size

run_only_once = cache_hit or n.target == torch.ops.native_z3.release_param
is_release_op = is_release_node(n)
run_only_once = cache_hit or is_release_op
iteration = 1 if run_only_once else self.iteration
accelerator = get_accelerator()
start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
@@ -189,7 +190,7 @@ def partition_param_if_necessary(v):
        self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"],
                                 n.meta["max_mem"], n.meta["tensor_size"])

        if n.target == torch.ops.native_z3.release_param:
        if is_release_op:
            n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0)

        if dist.get_rank() == 0 and self.debug_log:
4 changes: 4 additions & 0 deletions deepspeed/runtime/zero/compile/util.py
@@ -342,3 +342,7 @@ def show_memory(label: str):
msg = f"Mem {node.name}"
name = f"show_memory_{node.name}"
graph.create_node('call_function', show_memory, (msg, ), {}, name=name)


def is_release_node(n: Node) -> bool:
return hasattr(n.target, "__name__") and n.target.__name__ == "wrap_release_ds_param"
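
Because the release op is now a plain Python wrapper rather than `torch.ops.native_z3.release_param`, the scheduler and profiler identify release nodes by the wrapper's `__name__`. A small self-contained sketch of how such a node looks in an FX graph and how `is_release_node` finds it (stub wrapper and dummy ids, no DeepSpeed extension needed):

```python
from torch import fx


def wrap_release_ds_param(x, graph_id, ds_id):
    # Stand-in for the wrapper inserted by add_release; the real one calls
    # nz3.release_param(graph_id, ds_id) before passing x through unchanged.
    return x


def is_release_node(n: fx.Node) -> bool:
    return hasattr(n.target, "__name__") and n.target.__name__ == "wrap_release_ds_param"


g = fx.Graph()
x = g.placeholder("x")
rel = g.call_function(wrap_release_ds_param, (x, 0, 7))  # graph_id=0, ds_id=7 (dummies)
g.output(rel)

print([n.name for n in g.nodes if is_release_node(n)])  # -> ['wrap_release_ds_param']
```

This keeps release bookkeeping out of the dispatcher while still letting graph passes find the nodes, e.g. `release_nodes = {n.args[2]: n for n in unscheduled if is_release_node(n)}` in `list_schedule.py`, where `args[2]` is the `ds_id`.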
