Commit 6738f5b: add activation offloading
tohtana committed Nov 13, 2024 (parent 84c27e0)
Showing 6 changed files with 209 additions and 9 deletions.
34 changes: 34 additions & 0 deletions csrc/compile/native_z3.cpp
@@ -856,6 +856,31 @@ at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
return at::Tensor();
}

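// Synchronous device-to-host copy; graph_id and id identify the offloaded
// activation (they are used only by the commented-out debug logging below).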
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
{
    // auto dims = tensor.sizes();
    // std::cout << "offload_tensor graph_id=" << graph_id << " id=" << id
    //           << " dim=" << join_as_str(dims, ",") << std::endl;
    return tensor.to(at::kCPU);
}

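// Synchronous host-to-device copy that brings an offloaded activation back.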
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
{
    // auto dims = tensor.sizes();
    // std::cout << "reload_tensor graph_id=" << graph_id << " id=" << id
    //           << " dim=" << join_as_str(dims, ",") << std::endl;

    return tensor.to(at::kCUDA);
}

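// Currently a pass-through: the copies above block until complete, so there
// is nothing to wait on; the op marks the synchronization point in the graph.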
at::Tensor wait_tensor_copy(at::Tensor tensor, long graph_id, long id)
{
    // auto dims = tensor.sizes();
    // std::cout << "wait_tensor_copy graph_id=" << graph_id << " id=" << id
    //           << " dim=" << join_as_str(dims, ",") << std::endl;
    return tensor;
}

void start_forward()
{
lazy_init_symm_memory();
@@ -894,6 +919,9 @@ TORCH_LIBRARY(native_z3, m)
"wait_allgather(Tensor a, int graph_id, int id, str user, int n_args, bool bwd) -> Tensor");
m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
m.def("free_tensors(Tensor[] a) -> ()");
m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("wait_tensor_copy(Tensor a, int id, int id) -> Tensor");

m.def("test_call(Tensor a) -> Tensor");
}
@@ -906,6 +934,9 @@ TORCH_LIBRARY_IMPL(native_z3, CPU, m)
m.impl("wait_allgather", &n3z::wait_allgather);
m.impl("reduce_grad", &n3z::reduce_grad);
m.impl("free_tensors", &n3z::free_tensors);
m.impl("offload_tensor", &n3z::offload_tensor);
m.impl("reload_tensor", &n3z::reload_tensor);
m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);

m.impl("test_call", &n3z::test_call);
}
@@ -918,6 +949,9 @@ TORCH_LIBRARY_IMPL(native_z3, CUDA, m)
m.impl("wait_allgather", &n3z::wait_allgather);
m.impl("reduce_grad", &n3z::reduce_grad);
m.impl("free_tensors", &n3z::free_tensors);
m.impl("offload_tensor", &n3z::offload_tensor);
m.impl("reload_tensor", &n3z::reload_tensor);
m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);

m.impl("test_call", &n3z::test_call);
}
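Once the extension is built, the new ops dispatch like any other torch operator. A minimal round-trip smoke test, assuming a CUDA device is available (the id values are arbitrary; in real runs the offload pass chooses them):

import torch
from deepspeed.ops.op_builder import NativeZ3Builder

NativeZ3Builder().load()  # registers torch.ops.native_z3.*
t = torch.randn(1024, 1024, device="cuda")
graph_id, val_id = 0, 12345

cpu_copy = torch.ops.native_z3.offload_tensor(t, graph_id, val_id)  # D2H, blocking
back = torch.ops.native_z3.reload_tensor(cpu_copy, graph_id, val_id)  # H2D, blocking
back = torch.ops.native_z3.wait_tensor_copy(back, graph_id, val_id)  # pass-through today
assert cpu_copy.device.type == "cpu" and torch.equal(t, back)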
6 changes: 5 additions & 1 deletion deepspeed/runtime/engine.py
@@ -3698,6 +3698,7 @@ def compile(self,
            compile_kwargs={},
            schedule=False,
            scheduler="simple_prefetch",
            offload_activation=False,
            double_buffer=True,
            use_symmetric_memory=False,
            dump_graphs=False) -> None:
@@ -3770,7 +3771,10 @@ def launch_compile_passes(micro_steps=self.micro_steps,

from deepspeed.runtime.zero.compile.patch_fake_tensor import patch_fake_tensor
patch_fake_tensor()
backend = make_stage3_backend(opt_passes, scheduler=scheduler, dump_graphs=dump_graphs)
backend = make_stage3_backend(opt_passes,
                              scheduler=scheduler,
                              offload_activation=offload_activation,
                              dump_graphs=dump_graphs)

print(f"Compiling with {scheduler}")
if 'backend' in compile_kwargs:
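The new flag is threaded through to the backend factory below. A minimal sketch of the user-facing call, assuming an already-initialized ZeRO-3 engine named model_engine (the engine name and the other argument values are illustrative):

model_engine.compile(
    schedule=True,
    scheduler="simple_prefetch",
    offload_activation=True,  # new in this commit: offload large activations to CPU
)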
28 changes: 26 additions & 2 deletions deepspeed/runtime/zero/compile/fx.py
@@ -19,6 +19,23 @@ def get_output_node(graph: Graph):
    raise ValueError("No output node found")


def move_primals_to_head(graph: Graph):

    # Move primals to the head of the graph
    primals = [n for n in graph.nodes if n.op == "placeholder"]
    non_primals = [n for n in graph.nodes if n.op != "placeholder"]
    all_nodes = primals + non_primals

    new_graph = Graph()
    env = {}
    for node in all_nodes:
        new_node = new_graph.node_copy(node, lambda n: env[n.name])
        env[node.name] = new_node
    new_graph.lint()

    return new_graph


def add_args_process(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
@@ -84,6 +101,9 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
                                extra_args=[graph_id, ds_id],
                                name=f"allgather_ds_param_{node.target}_{ds_id}",
                                meta=_make_node_meta(node, ds_id, True))

    # Point the graph output back at the original node;
    # the output should not be rewired to the allgather result.
    output_node = get_output_node(graph)
    output_node.replace_input_with(new_node, node)
    return new_node
@@ -139,7 +159,7 @@ def register_and_add_wait_allgather(graph_id: int, graph: Graph, bwd: bool):
    return ds_ids, ag_wait_nodes


def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]):
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
    ag_nodes = []
    for pn in param_nodes:
        ag_node = add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
@@ -151,15 +171,19 @@ def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
        ds_id = param_manager.ds_ids[pn.name]
        add_release(graph_id, graph, last_use, pn, ds_id)

    return move_primals_to_head(graph)


def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
                          param_name_to_grad: Dict[str, Node]):
                          param_name_to_grad: Dict[str, Node]) -> Graph:

    add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)

    for param_name in param_manager.param_names:
        add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])

    return move_primals_to_head(graph)


def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
    node_to_last_use, _ = get_last_uses(graph)
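move_primals_to_head rebuilds the graph with all placeholders in front: the passes above insert and reroute nodes, and downstream tooling generally assumes placeholders lead the node list. A minimal usage sketch (a freshly traced graph already satisfies the invariant, so this is illustrative only):

import torch
from torch.fx import symbolic_trace

from deepspeed.runtime.zero.compile.fx import move_primals_to_head

def f(x, y):
    return torch.relu(x) + y

gm = symbolic_trace(f)
gm.graph = move_primals_to_head(gm.graph)  # placeholders x, y lead the node list
print([(n.op, n.name) for n in gm.graph.nodes])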
124 changes: 124 additions & 0 deletions deepspeed/runtime/zero/compile/passes/offload_activation.py
@@ -0,0 +1,124 @@
from typing import List, Callable, Any, Dict, Tuple, Set
import random
from collections import defaultdict

import torch
from torch.fx import Graph, Node

from ..fx import get_output_node, get_last_uses, move_primals_to_head
from ..graph_param import DSGraphParamManager


OFFLOAD_THRESHOLD = 1024 * 1024  # 1M elements (the check counts elements, not bytes; ~4 MB at fp32)


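# graph_id -> {activation node name -> offload id}; written by the forward
# offload pass and read by the backward reload pass to pair the two up.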
value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
used_ids: Set[int] = set()


def get_random_id() -> int:

    def _gen():
        # generate random int
        return random.randint(10000, 2**31)

    global used_ids
    v = _gen()
    while v in used_ids:
        v = _gen()
    used_ids.add(v)
    return v


def _make_node_meta(node: Node, comm: bool):
    meta = {"comm": comm}
    if "tensor_meta" in node.meta:
        meta["tensor_meta"] = node.meta["tensor_meta"]
    return meta


def _should_offload(node: Node) -> bool:
    if not hasattr(node, "meta"):
        return False
    if "tensor_meta" not in node.meta:
        return False

    shape = node.meta["tensor_meta"].shape
    numel = 1
    for s in shape:
        numel *= s
    return numel > OFFLOAD_THRESHOLD


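# For each sufficiently large non-parameter activation, insert
# offload_tensor -> wait_tensor_copy after it and reroute the graph output
# to the waited CPU copy, so the value saved for backward is the offloaded one.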
def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[Node], graph_order: List[int],
                           mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
    # output_node = get_output_node(graph)
    # output_value_nodes = set(output_node.args[0])

    param_names = set(param_manager.param_names)

    import copy
    cl_graph = copy.deepcopy(graph)
    cl_graph.erase_node(get_output_node(cl_graph))

    node_to_last_use, _ = get_last_uses(cl_graph)
    node_name_to_last_use_name = {k.name: v.name for k, v in node_to_last_use.items()}
    name_to_node = {n.name: n for n in graph.nodes}

    global value_to_id
    for node in nodes_to_offload:
        if node.name in param_names:
            continue

        if not _should_offload(node):
            continue

        val_id = get_random_id()
        with graph.inserting_after(node):
            offload_node = graph.create_node('call_function',
                                             torch.ops.native_z3.offload_tensor, (node, graph_id, val_id), {},
                                             name=f"offload_{node.name}_{val_id}")
        with graph.inserting_after(offload_node):
            wait_node = graph.create_node('call_function',
                                          torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id),
                                          {},
                                          name=f"wait_copy_{node.name}_{val_id}")

        output_node = get_output_node(graph)
        output_node.replace_input_with(node, wait_node)

        value_to_id[graph_id][node.name] = val_id

    graph = move_primals_to_head(graph)

    graph.lint()
    return graph


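# Insert reload_tensor -> wait_tensor_copy before the first use of each
# offloaded activation in the backward graph and repoint every consumer
# at the reloaded copy.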
def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
                          param_manager: DSGraphParamManager) -> Graph:

    graph_value_to_id = value_to_id[graph_id]
    act_names = set(graph_value_to_id.keys())
    act_nodes = [n for n in graph.nodes if n.name in act_names]

    node_to_first_user = {}
    for act in act_nodes:
        for node in graph.nodes:
            if act in node.args:
                node_to_first_user[act] = node
                break

    for node in act_nodes:
        val_id = graph_value_to_id[node.name]

        with graph.inserting_before(node_to_first_user[node]):
            reload_node = graph.create_node('call_function',
                                            torch.ops.native_z3.reload_tensor, (node, graph_id, val_id), {},
                                            name=f"reload_{node.name}_{val_id}")
        with graph.inserting_after(reload_node):
            wait_node = graph.create_node('call_function',
                                          torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id),
                                          {},
                                          name=f"wait_copy_{node.name}_{val_id}")

        # replace all uses of node with wait_node
        users = {}
        for u in node.users.keys():
            if u != reload_node:
                users[u] = (node, wait_node)
        for u, (old_in, new_in) in users.items():
            u.replace_input_with(old_in, new_in)

    graph = move_primals_to_head(graph)
    graph.lint()
    return graph
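The forward rewrite is easiest to see on a toy graph. A minimal sketch, assuming the native_z3 extension is built so torch.ops.native_z3 resolves (the ids 0 and 42 stand in for graph_id and the random value id):

import torch
from torch.fx import symbolic_trace

from deepspeed.ops.op_builder import NativeZ3Builder

NativeZ3Builder().load()  # make torch.ops.native_z3.* resolvable

def f(x):
    a = torch.relu(x)  # stand-in for a large saved activation
    return a.sum(), a  # `a` is also returned, i.e. saved for backward

gm = symbolic_trace(f)
g = gm.graph
relu = next(n for n in g.nodes if n.target is torch.relu)
with g.inserting_after(relu):
    off = g.call_function(torch.ops.native_z3.offload_tensor, (relu, 0, 42))
with g.inserting_after(off):
    wait = g.call_function(torch.ops.native_z3.wait_tensor_copy, (off, 0, 42))
out = next(n for n in g.nodes if n.op == "output")
out.replace_input_with(relu, wait)  # the saved activation is now the CPU copy
g.lint()
gm.recompile()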
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/compile/passes/prefetch.py
@@ -145,12 +145,12 @@ def schedule_prefetch(graph: Graph, graph_id: int, graph_order: List[int], profi
ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
assert ag_tensor_size_sum == 0

assert ag_tensor_size_sum >= 0

# print_rank_0(
# f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}"
# )

assert ag_tensor_size_sum >= 0

new_graph = Graph()
env = {}
for node in reversed(new_order_rev):
22 changes: 18 additions & 4 deletions deepspeed/runtime/zero/compile/stage3_backend.py
@@ -23,6 +23,7 @@
from .profilers import ProfilingResult
from .profilers.graph_profile import ProfilingInterpreter, MemoryProfilingInterpreter
from .passes import run_opt_passes
from .passes.offload_activation import offload_activation_fwd, reload_activation_bwd
from .list_schedule import simple_prefetch, fast_free_schedule
from .util import get_input_nodes, get_param_nodes, NodeValueOffloadHelper, materialize_fake, count_inflight_values
from .partitioner import get_wrapped_partitioner
@@ -80,7 +81,7 @@ def launch_opt_passes():
reset_graph_order()


def make_stage3_backend(opt_passes, scheduler, dump_graphs=False, debug_log=False):
def make_stage3_backend(opt_passes, scheduler, offload_activation=False, dump_graphs=False, debug_log=False):
    from deepspeed.ops.op_builder import NativeZ3Builder
    nz3 = NativeZ3Builder().load()
    rank = dist.get_rank()
@@ -97,6 +98,7 @@ def stage3_backend(gm: GraphModule, real_inputs):

offload_helper = NodeValueOffloadHelper(torch.device(get_accelerator().current_device()))
needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
num_original_outputs = len(get_output_node(gm.graph).args[0])

global graph_order
graph_order.append((graph_id, needs_backward))
@@ -121,8 +123,14 @@ def fw(gm, sample_inputs):
param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)
original_output_names = [n.name for n in get_output_node(gm.graph).args[0]]

add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
                       get_param_nodes(gm.graph, param_indices))
gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
                                  get_param_nodes(gm.graph, param_indices))

if needs_backward and offload_activation:
    outputs = get_output_node(gm.graph).args[0]
    nodes_to_offload = outputs[num_original_outputs:]
    gm.graph = offload_activation_fwd(gm.graph, graph_id, nodes_to_offload, graph_order,
                                      get_accelerator().available_memory(), param_manager[graph_id])

nz3.register_graph(graph_id, [v[1] for v in param_indices]) # Need this before profiling
profiler = ProfilingInterpreter(nz3, gm, debug_log=False)
@@ -196,6 +204,7 @@ def fw(gm, sample_inputs):
gm = run_opt_passes(nz3, graph_id, gm, real_inputs, opt_passes, graph_order, profiling_results,
                    param_manager, False, debug_log and rank == 0)

gm.recompile()
return make_boxed_func(gm.forward)

def bw(gm, sample_inputs):
@@ -205,7 +214,11 @@ def bw(gm, sample_inputs):
assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager"
param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)

add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)
gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw,
                                 param_name_to_grad)
if offload_activation:
    gm.graph = reload_activation_bwd(gm.graph, graph_id, graph_order,
                                     get_accelerator().available_memory(), param_manager[graph_id])

input_nodes = get_input_nodes(gm.graph)
assert len(input_nodes) == len(
@@ -285,6 +298,7 @@ def bw(gm, sample_inputs):
gm = run_opt_passes(nz3, graph_id, gm, validated_inputs, opt_passes, graph_order, profiling_results,
                    param_manager, True, debug_log and rank == 0)

gm.recompile()
return make_boxed_func(gm.forward)

# Call AOTAutograd
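The added gm.recompile() calls matter because a GraphModule executes generated Python code, which goes stale when a pass mutates the underlying graph in place; recompile() regenerates it. A self-contained illustration, independent of DeepSpeed:

import operator

import torch
from torch.fx import symbolic_trace

def f(x):
    return x + 1

gm = symbolic_trace(f)
for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is operator.add:
        n.target = operator.mul  # in-place graph mutation
print(gm(torch.tensor(3.0)))  # tensor(4.) -- the stale generated code still adds
gm.recompile()
print(gm(torch.tensor(3.0)))  # tensor(3.) -- the mutated graph takes effect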
