diff --git a/csrc/compile/native_z3.cpp b/csrc/compile/native_z3.cpp
index 1e0dd25549b7..663d6d9d56c8 100644
--- a/csrc/compile/native_z3.cpp
+++ b/csrc/compile/native_z3.cpp
@@ -856,6 +856,31 @@ at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
     return at::Tensor();
 }
 
+at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
+{
+    // auto dims = tensor.sizes();
+    // std::cout << "offload_tensor graph_id=" << graph_id << " id=" << id
+    //           << " dim=" << join_as_str(dims, ",") << std::endl;
+    return tensor.to(at::kCPU);
+}
+
+at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
+{
+    // auto dims = tensor.sizes();
+    // std::cout << "reload_tensor graph_id=" << graph_id << " id=" << id
+    //           << " dim=" << join_as_str(dims, ",") << std::endl;
+
+    return tensor.to(at::kCUDA);
+}
+
+at::Tensor wait_tensor_copy(at::Tensor tensor, long graph_id, long id)
+{
+    // auto dims = tensor.sizes();
+    // std::cout << "wait_tensor_copy graph_id=" << graph_id << " id=" << id
+    //           << " dim=" << join_as_str(dims, ",") << std::endl;
+    return tensor;
+}
+
 void start_forward()
 {
     lazy_init_symm_memory();
@@ -894,6 +919,9 @@ TORCH_LIBRARY(native_z3, m)
         "wait_allgather(Tensor a, int graph_id, int id, str user, int n_args, bool bwd) -> Tensor");
     m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
     m.def("free_tensors(Tensor[] a) -> ()");
+    m.def("offload_tensor(Tensor a, int graph_id, int id) -> Tensor");
+    m.def("reload_tensor(Tensor a, int graph_id, int id) -> Tensor");
+    m.def("wait_tensor_copy(Tensor a, int graph_id, int id) -> Tensor");
     m.def("test_call(Tensor a) -> Tensor");
 }
 
@@ -906,6 +934,9 @@ TORCH_LIBRARY_IMPL(native_z3, CPU, m)
     m.impl("wait_allgather", &n3z::wait_allgather);
     m.impl("reduce_grad", &n3z::reduce_grad);
     m.impl("free_tensors", &n3z::free_tensors);
+    m.impl("offload_tensor", &n3z::offload_tensor);
+    m.impl("reload_tensor", &n3z::reload_tensor);
+    m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);
     m.impl("test_call", &n3z::test_call);
 }
 
@@ -918,6 +949,9 @@ TORCH_LIBRARY_IMPL(native_z3, CUDA, m)
     m.impl("wait_allgather", &n3z::wait_allgather);
     m.impl("reduce_grad", &n3z::reduce_grad);
     m.impl("free_tensors", &n3z::free_tensors);
+    m.impl("offload_tensor", &n3z::offload_tensor);
+    m.impl("reload_tensor", &n3z::reload_tensor);
+    m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);
     m.impl("test_call", &n3z::test_call);
 }
 
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 28475ebe597a..7a23dd9a9e6a 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3698,6 +3698,7 @@ def compile(self,
                 compile_kwargs={},
                 schedule=False,
                 scheduler="simple_prefetch",
+                offload_activation=False,
                 double_buffer=True,
                 use_symmetric_memory=False,
                 dump_graphs=False) -> None:
@@ -3770,7 +3771,10 @@ def launch_compile_passes(micro_steps=self.micro_steps,
         from deepspeed.runtime.zero.compile.patch_fake_tensor import patch_fake_tensor
         patch_fake_tensor()
 
-        backend = make_stage3_backend(opt_passes, scheduler=scheduler, dump_graphs=dump_graphs)
+        backend = make_stage3_backend(opt_passes,
+                                      scheduler=scheduler,
+                                      offload_activation=offload_activation,
+                                      dump_graphs=dump_graphs)
 
         print(f"Compiling with {scheduler}")
         if 'backend' in compile_kwargs:
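
For orientation, a minimal sketch of how the new offload_activation flag is meant to be driven from user code; this is not part of the patch. The toy model, the config contents, and the assumption that schedule=True is what routes compile() through the stage-3 backend are all illustrative.

import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},
}
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config, model_parameters=model.parameters())

# offload_activation is the new knob introduced by this patch; scheduler matches the
# default shown above, and schedule=True is assumed here to enable the compile passes.
engine.compile(schedule=True, scheduler="simple_prefetch", offload_activation=True)
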
node found") +def move_primals_to_head(graph: Graph): + + # Move primals to the head of the graph + primals = [n for n in graph.nodes if n.op == "placeholder"] + non_primals = [n for n in graph.nodes if n.op != "placeholder"] + all_nodes = primals + non_primals + + new_graph = Graph() + env = {} + for node in all_nodes: + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + new_graph.lint() + + return new_graph + + def add_args_process(graph: Graph, node: Node, fn: Callable[..., Any], @@ -84,6 +101,9 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int): extra_args=[graph_id, ds_id], name=f"allgather_ds_param_{node.target}_{ds_id}", meta=_make_node_meta(node, ds_id, True)) + + # Set the previous node back to output + # We don't want to change the output node to allgather output_node = get_output_node(graph) output_node.replace_input_with(new_node, node) return new_node @@ -139,7 +159,7 @@ def register_and_add_wait_allgather(graph_id: int, graph: Graph, bwd: bool): return ds_ids, ag_wait_nodes -def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]): +def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph: ag_nodes = [] for pn in param_nodes: ag_node = add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name]) @@ -151,15 +171,19 @@ def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nod ds_id = param_manager.ds_ids[pn.name] add_release(graph_id, graph, last_use, pn, ds_id) + return move_primals_to_head(graph) + def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node], - param_name_to_grad: Dict[str, Node]): + param_name_to_grad: Dict[str, Node]) -> Graph: add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw) for param_name in param_manager.param_names: add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name]) + return move_primals_to_head(graph) + def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]): node_to_last_use, _ = get_last_uses(graph) diff --git a/deepspeed/runtime/zero/compile/passes/offload_activation.py b/deepspeed/runtime/zero/compile/passes/offload_activation.py new file mode 100644 index 000000000000..9cab9608a888 --- /dev/null +++ b/deepspeed/runtime/zero/compile/passes/offload_activation.py @@ -0,0 +1,124 @@ +from typing import List, Callable, Any, Dict, Tuple, Set +import random +from collections import defaultdict + +import torch +from torch.fx import Graph, Node + +from ..fx import get_output_node, get_last_uses, move_primals_to_head +from ..graph_param import DSGraphParamManager + + +OFFLOAD_THRESHOLD = 1024 * 1024 # 1MB + + +value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict) +used_ids: Set[int] = set() + + +def get_random_id() -> int: + def _gen(): + # generate random int + return random.randint(10000, 2**31) + + global used_ids + v = _gen() + while v in used_ids: + v = _gen() + used_ids.add(v) + return v + + +def _make_node_meta(node: Node, comm: bool): + meta = {"comm": comm} + if "tensor_meta" in node.meta: + meta["tensor_meta"] = node.meta["tensor_meta"] + return meta + + +def _should_offload(node: Node) -> bool: + if not hasattr(node, "meta"): + return False + if not "tensor_meta" in node.meta: + return False + + shape = node.meta["tensor_meta"].shape + numel = 1 + for s in shape: + numel *= s + return numel > OFFLOAD_THRESHOLD + + +def 
diff --git a/deepspeed/runtime/zero/compile/passes/offload_activation.py b/deepspeed/runtime/zero/compile/passes/offload_activation.py
new file mode 100644
index 000000000000..9cab9608a888
--- /dev/null
+++ b/deepspeed/runtime/zero/compile/passes/offload_activation.py
@@ -0,0 +1,124 @@
+from typing import List, Callable, Any, Dict, Tuple, Set
+import random
+from collections import defaultdict
+
+import torch
+from torch.fx import Graph, Node
+
+from ..fx import get_output_node, get_last_uses, move_primals_to_head
+from ..graph_param import DSGraphParamManager
+
+
+OFFLOAD_THRESHOLD = 1024 * 1024  # 1MB
+
+
+value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
+used_ids: Set[int] = set()
+
+
+def get_random_id() -> int:
+    def _gen():
+        # generate a random int
+        return random.randint(10000, 2**31)
+
+    global used_ids
+    v = _gen()
+    while v in used_ids:
+        v = _gen()
+    used_ids.add(v)
+    return v
+
+
+def _make_node_meta(node: Node, comm: bool):
+    meta = {"comm": comm}
+    if "tensor_meta" in node.meta:
+        meta["tensor_meta"] = node.meta["tensor_meta"]
+    return meta
+
+
+def _should_offload(node: Node) -> bool:
+    if not hasattr(node, "meta"):
+        return False
+    if "tensor_meta" not in node.meta:
+        return False
+
+    shape = node.meta["tensor_meta"].shape
+    numel = 1
+    for s in shape:
+        numel *= s
+    return numel > OFFLOAD_THRESHOLD
+
+
+def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[Node], graph_order: List[int], mem_budget: float,
+                           param_manager: DSGraphParamManager) -> Graph:
+    # output_node = get_output_node(graph)
+    # output_value_nodes = set(output_node.args[0])
+
+    param_names = set(param_manager.param_names)
+
+    import copy
+    cl_graph = copy.deepcopy(graph)
+    cl_graph.erase_node(get_output_node(cl_graph))
+
+    node_to_last_use, _ = get_last_uses(cl_graph)
+    node_name_to_last_use_name = {k.name: v.name for k, v in node_to_last_use.items()}
+    name_to_node = {n.name: n for n in graph.nodes}
+
+    global value_to_id
+    for node in nodes_to_offload:
+        if node.name in param_names:
+            continue
+
+        if not _should_offload(node):
+            continue
+
+        val_id = get_random_id()
+        with graph.inserting_after(node):
+            offload_node = graph.create_node('call_function', torch.ops.native_z3.offload_tensor, (node, graph_id, val_id), {}, name=f"offload_{node.name}_{val_id}")
+        with graph.inserting_after(offload_node):
+            wait_node = graph.create_node('call_function', torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {}, name=f"wait_copy_{node.name}_{val_id}")
+
+        output_node = get_output_node(graph)
+        output_node.replace_input_with(node, wait_node)
+
+        value_to_id[graph_id][node.name] = val_id
+
+    graph = move_primals_to_head(graph)
+
+    graph.lint()
+    return graph
+
+
+def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
+                          param_manager: DSGraphParamManager) -> Graph:
+
+    graph_value_to_id = value_to_id[graph_id]
+    act_names = set(graph_value_to_id.keys())
+    act_nodes = [n for n in graph.nodes if n.name in act_names]
+
+    node_to_first_user = {}
+    for act in act_nodes:
+        for node in graph.nodes:
+            if act in node.args:
+                node_to_first_user[act] = node
+                break
+
+    for node in act_nodes:
+        val_id = graph_value_to_id[node.name]
+
+        with graph.inserting_before(node_to_first_user[node]):
+            reload_node = graph.create_node('call_function', torch.ops.native_z3.reload_tensor, (node, graph_id, val_id), {}, name=f"reload_{node.name}_{val_id}")
+        with graph.inserting_after(reload_node):
+            wait_node = graph.create_node('call_function', torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {}, name=f"wait_copy_{node.name}_{val_id}")
+
+        # replace all uses of node with wait_node
+        users = {}
+        for u in node.users.keys():
+            if u != reload_node:
+                users[u] = (node, wait_node)
+        for u, (old_in, new_in) in users.items():
+            u.replace_input_with(old_in, new_in)
+
+    graph = move_primals_to_head(graph)
+    graph.lint()
+    return graph
diff --git a/deepspeed/runtime/zero/compile/passes/prefetch.py b/deepspeed/runtime/zero/compile/passes/prefetch.py
index bdf38246245f..6a626c641ad5 100644
--- a/deepspeed/runtime/zero/compile/passes/prefetch.py
+++ b/deepspeed/runtime/zero/compile/passes/prefetch.py
@@ -145,12 +145,12 @@ def schedule_prefetch(graph: Graph, graph_id: int, graph_order: List[int], profi
                 ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
                 assert ag_tensor_size_sum == 0
 
-    assert ag_tensor_size_sum >= 0
-
     # print_rank_0(
     #     f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}"
     # )
 
+    assert ag_tensor_size_sum >= 0
+
     new_graph = Graph()
     env = {}
     for node in reversed(new_order_rev):
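
The offload_activation pass above only rewrites the FX graphs; the runtime behaviour comes from the three ops added in csrc/compile/native_z3.cpp. Their current semantics are simple enough to restate as a Python reference, e.g. for experimenting with the pass without a build of the native extension. The *_ref names below are illustrative and not part of the patch.

import torch


def offload_tensor_ref(tensor: torch.Tensor, graph_id: int, val_id: int) -> torch.Tensor:
    # Mirrors native_z3::offload_tensor: a (currently blocking) copy to host memory.
    return tensor.to("cpu")


def reload_tensor_ref(tensor: torch.Tensor, graph_id: int, val_id: int) -> torch.Tensor:
    # Mirrors native_z3::reload_tensor: bring the saved activation back to the accelerator.
    return tensor.to("cuda")


def wait_tensor_copy_ref(tensor: torch.Tensor, graph_id: int, val_id: int) -> torch.Tensor:
    # Mirrors native_z3::wait_tensor_copy: a no-op today; presumably a hook for
    # synchronizing asynchronous copies once the C++ side overlaps the transfers.
    return tensor
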
diff --git a/deepspeed/runtime/zero/compile/stage3_backend.py b/deepspeed/runtime/zero/compile/stage3_backend.py
index 95036cbae31b..2c357936fc91 100644
--- a/deepspeed/runtime/zero/compile/stage3_backend.py
+++ b/deepspeed/runtime/zero/compile/stage3_backend.py
@@ -23,6 +23,7 @@
 from .profilers import ProfilingResult
 from .profilers.graph_profile import ProfilingInterpreter, MemoryProfilingInterpreter
 from .passes import run_opt_passes
+from .passes.offload_activation import offload_activation_fwd, reload_activation_bwd
 from .list_schedule import simple_prefetch, fast_free_schedule
 from .util import get_input_nodes, get_param_nodes, NodeValueOffloadHelper, materialize_fake, count_inflight_values
 from .partitioner import get_wrapped_partitioner
@@ -80,7 +81,7 @@ def launch_opt_passes():
     reset_graph_order()
 
 
-def make_stage3_backend(opt_passes, scheduler, dump_graphs=False, debug_log=False):
+def make_stage3_backend(opt_passes, scheduler, offload_activation=False, dump_graphs=False, debug_log=False):
     from deepspeed.ops.op_builder import NativeZ3Builder
     nz3 = NativeZ3Builder().load()
     rank = dist.get_rank()
@@ -97,6 +98,7 @@ def stage3_backend(gm: GraphModule, real_inputs):
         offload_helper = NodeValueOffloadHelper(torch.device(get_accelerator().current_device()))
 
         needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
+        num_original_outputs = len(get_output_node(gm.graph).args[0])
 
         global graph_order
         graph_order.append((graph_id, needs_backward))
@@ -121,8 +123,14 @@ def fw(gm, sample_inputs):
             param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)
 
             original_output_names = [n.name for n in get_output_node(gm.graph).args[0]]
-            add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
-                                   get_param_nodes(gm.graph, param_indices))
+            gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
+                                              get_param_nodes(gm.graph, param_indices))
+
+            if needs_backward and offload_activation:
+                outputs = get_output_node(gm.graph).args[0]
+                nodes_to_offload = outputs[num_original_outputs:]
+                gm.graph = offload_activation_fwd(gm.graph, graph_id, nodes_to_offload, graph_order,
+                                                  get_accelerator().available_memory(), param_manager[graph_id])
 
             nz3.register_graph(graph_id, [v[1] for v in param_indices])  # Need this before profiling
             profiler = ProfilingInterpreter(nz3, gm, debug_log=False)
@@ -196,6 +204,7 @@ def fw(gm, sample_inputs):
             gm = run_opt_passes(nz3, graph_id, gm, real_inputs, opt_passes, graph_order, profiling_results,
                                 param_manager, False, debug_log and rank == 0)
 
+            gm.recompile()
             return make_boxed_func(gm.forward)
 
         def bw(gm, sample_inputs):
@@ -205,7 +214,11 @@ def bw(gm, sample_inputs):
             assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager"
 
             param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
-            add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)
+            gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw,
+                                             param_name_to_grad)
+            if offload_activation:
+                gm.graph = reload_activation_bwd(gm.graph, graph_id, graph_order,
+                                                 get_accelerator().available_memory(), param_manager[graph_id])
 
             input_nodes = get_input_nodes(gm.graph)
             assert len(input_nodes) == len(
@@ -285,6 +298,7 @@ def bw(gm, sample_inputs):
             gm = run_opt_passes(nz3, graph_id, gm, validated_inputs, opt_passes, graph_order, profiling_results,
                                 param_manager, True, debug_log and rank == 0)
 
+            gm.recompile()
             return make_boxed_func(gm.forward)
 
         # Call AOTAutograd
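
Finally, a quick round-trip sanity check of the new op pair outside the compiler, assuming the native_z3 extension can be built on the machine and a CUDA device is available; the graph and value ids passed here are arbitrary, and this snippet is not part of the patch.

import torch
from deepspeed.ops.op_builder import NativeZ3Builder

NativeZ3Builder().load()  # builds/loads the extension and registers torch.ops.native_z3.*

x = torch.randn(2048, 2048, device="cuda")

off = torch.ops.native_z3.offload_tensor(x, 0, 42)
off = torch.ops.native_z3.wait_tensor_copy(off, 0, 42)
assert off.device.type == "cpu"

back = torch.ops.native_z3.reload_tensor(off, 0, 42)
back = torch.ops.native_z3.wait_tensor_copy(back, 0, 42)
assert back.device.type == "cuda"
assert torch.equal(x.cpu(), back.cpu())
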