From 722d78f3b4691005410c42dbacf3895e534b2a4b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 20 Nov 2024 18:36:32 +0000 Subject: [PATCH] hook compiled function class to take bwd inputs --- .../zero/compile/patch_compiled_func.py | 250 ++++++++++++++++++ .../runtime/zero/compile/stage3_backend.py | 122 +++------ deepspeed/runtime/zero/compile/util.py | 35 +-- 3 files changed, 285 insertions(+), 122 deletions(-) create mode 100644 deepspeed/runtime/zero/compile/patch_compiled_func.py diff --git a/deepspeed/runtime/zero/compile/patch_compiled_func.py b/deepspeed/runtime/zero/compile/patch_compiled_func.py new file mode 100644 index 000000000000..27c0b9d1defe --- /dev/null +++ b/deepspeed/runtime/zero/compile/patch_compiled_func.py @@ -0,0 +1,250 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import torch +from torch._prims_common import CUDARngStateHelper +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from torch._functorch._aot_autograd.schemas import ( + OutputType, + SubclassCreationMeta, +) +from torch._functorch._aot_autograd.subclass_utils import ( + get_types_for_subclass, + unwrap_tensor_subclasses, +) +from torch._functorch._aot_autograd.runtime_wrappers import AOTDispatchAutograd +from torch._subclasses import FakeTensor + +backward_inputs = [] + + +# Copied from torch._functorch._aot_autograd.runtime_wrappers +def make_backward_input(CompiledFunction, ctx, flat_args): + num_intermediate_bases = (CompiledFunction.metadata.num_intermediate_bases) + num_mutated_runtime_inps = (CompiledFunction.metadata.num_mutated_inp_runtime_indices) + expected_grad_outs = (CompiledFunction.metadata.num_outputs + num_mutated_runtime_inps + num_intermediate_bases) + deterministic = CompiledFunction.metadata.deterministic + global_deterministic = torch.are_deterministic_algorithms_enabled() + if deterministic is not None: + torch._check( + not (not deterministic and global_deterministic), + lambda: ("This compiled backward function is being run with " + "torch.use_deterministic_algorithms(True), " + "but it was previously generated during the forward function while " + "torch.use_deterministic_algorithms(False) was set."), + ) + + assert len(flat_args) == expected_grad_outs + out_info = CompiledFunction.metadata.output_info + + inp_tangents, out_tangents, intermediate_base_tangents = ( + flat_args[:num_mutated_runtime_inps], + flat_args[num_mutated_runtime_inps:num_mutated_runtime_inps + CompiledFunction.metadata.num_outputs], + flat_args[num_mutated_runtime_inps + CompiledFunction.metadata.num_outputs:], + ) + # input_info contains info on *every* input, + # But in the backward(), we are only given grad outputs for every mutated input + # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad + input_info = CompiledFunction.metadata.input_info + inp_tangents_filtered = [ + x for x, info_idx in zip( + inp_tangents, + CompiledFunction.metadata.mutated_inp_runtime_indices, + ) if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad + ] + # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates + out_tangents_filtered = [ + x for x, info in zip(out_tangents, out_info) if info.output_type in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] and issubclass(info.raw_type, torch.Tensor) and info.requires_grad + ] + # 
intermediate bases always require gradients, and always participate in the backward graph. + flat_bw_args_with_grads = [ + *inp_tangents_filtered, + *out_tangents_filtered, + *intermediate_base_tangents, + ] + num_flat_bw_args_with_grads = len(flat_bw_args_with_grads) + + # sanity asserts + # metadata_only_inps = [ + # x for x, info_idx in zip(inp_tangents, mutated_inp_indices) + # if not input_info[info_idx].mutates_data + # ] + # aliased_outputs = [ + # x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias] + # assert all(x is None for x in metadata_only_inps) + # assert all(x is None for x in aliased_outputs) + # TODO: replace this with FunctionalizedRngRuntimeWrapper + rng_args = [] + if CompiledFunction.metadata.is_rng_op_functionalized: + # Add the seed and offset to args + rng_args = CUDARngStateHelper.get_torch_state_as_tuple() + + bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens + + # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first + # in the bw output order. + + # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls + # There are tests that count these calls, saving to var. + ctx_saved_tensors = ctx.saved_tensors + num_ctx_saved_tensors = len(ctx_saved_tensors) + all_args = [ + *ctx.symints, + *ctx_saved_tensors, + *flat_bw_args_with_grads, + *bw_tokens, + *rng_args, + ] + + del ctx_saved_tensors + + # Note: [AOTAutograd Backward Guards] + # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph. + # Doing so requires us to "guess" about some of the metadata of our grad_outputs. + # + # In particular: if an output to the forward is a plain tensor or a subclass, + # its corresponding grad_output in the backward **may or may not** be + # a plain tensor or a subclass. The main cases are: + # (1) If an output is a plain tensor, its grad_out will also be a plain tensor, + # *unless* the output is used in some subclass compute later in the forward graph, + # which will cause its grad_output to become a subclass + # (2) If an output is a subclass, its grad_out will also be a subclass, + # *unless* the output of the forward did not actually participate in the gradient computation, + # in which case autograd will insert a plain tensor of zeros for the grad_output. + # We could avoid this case with `torch.autograd.Function.set_materialize_grads`, + # although this is not turned on today in AOTAutgrad and would require more work. + # + # Today, we make a guess on subclass-ness based on the above examples, + # and hard-error in the backward if we guessed wrong. + # + # In the future, we should add backward guards that would allow us to + # properly handle this case instead of erroring: we would need to retrace the backward graph, + # since we might produce an entirely different trace if our grad_outputs are subclass or not. + assert (len(CompiledFunction.metadata.output_types) == num_flat_bw_args_with_grads) + + grad_output_types = [type(x) for x in flat_bw_args_with_grads] + # In general, we can add more asserts/guards here for when we partitioned + # with incorrect assumptions about the grad_outputs. + # Normalize FakeTensor -> torch.Tensor + # - during tracing our types are FakeTensor + # - at runtime in the backward our types are torch.Tensor... 
+ # - unless we're running compiled backward, in which case they are also FakeTensor + grad_output_types_ = [torch.Tensor if x is FakeTensor else x for x in grad_output_types] + assert (grad_output_types_ == CompiledFunction.metadata.output_types), f"""\ +We incorrectly attempted to compile the backward with incorrect subclass metadata. +If you run into this error, please file an issue. +Expected grad_output types: {str(CompiledFunction.metadata.output_types)} +Got grad_output types: {str(grad_output_types)}""" + + del flat_bw_args_with_grads + + tangents_start_idx = (len(all_args) - num_flat_bw_args_with_grads - len(rng_args) - len(bw_tokens)) + assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors + tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens) + + # TODO: figure out how to refactor the backward properly + # so I can use aot_dispatch_subclass_wrapper() here. + if CompiledFunction.maybe_subclass_metadata is not None: + tangents = all_args[tangents_start_idx:tangents_end_idx] + + def get_types_for_tangents(tangents): + infos = [] + idx = 0 + for a in tangents: + if isinstance(a, torch.Tensor) and is_traceable_wrapper_subclass(a): + infos.append(get_types_for_subclass(a)) + else: + infos.append(idx) + idx += 1 + return infos + + runtime_subclass_info = get_types_for_tangents(tangents) + + if len(runtime_subclass_info) != len(CompiledFunction.metadata.subclass_tangent_meta): + raise RuntimeError("The grad inputs should be same number as forward output tangents") + for a, b in zip( + runtime_subclass_info, + CompiledFunction.metadata.subclass_tangent_meta, + ): + # Types should match between runtime and traced tangents. + # TODO (tmanlaibaatar) Should actually call coerce_runtime_tangent + if isinstance(a, List) and (isinstance(b, SubclassCreationMeta) and b.subclass_type): + if not a == b.subclass_type: + raise RuntimeError("The grad inputs should be same tensor subclass type as forward output") + + # Get the number of tangents after unwrapping + len_tangents = len(unwrap_tensor_subclasses( + tangents, + is_joint_structure=False, + )) + assert CompiledFunction.metadata.traced_tangent_metas is not None + all_args = [(AOTDispatchAutograd.coerce_runtime_tangent( + t, + CompiledFunction.metadata.traced_tangent_metas[i - tangents_start_idx], + ) if tangents_start_idx <= i < tangents_end_idx else t) for i, t in enumerate(all_args)] + all_args = unwrap_tensor_subclasses(all_args, is_joint_structure=False) + tangents_start_idx = (len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)) + tangents_end_idx = tangents_start_idx + len_tangents + + # Make the tangents contiguous. 
Note that we must do this after subclass desugaring + # because inputs to inductor have to be contiguous + all_args = [(AOTDispatchAutograd._force_contiguous(t) if (tangents_start_idx <= i < tangents_end_idx) else t) + for i, t in enumerate(all_args)] + + return all_args + + +enabled_patched_func = False + + +def patch_compiled_func(): + base_meta = type(torch.autograd.Function) + + global enabled_patched_func + enabled_patched_func = True + + class FunctionMeta(base_meta): + + def __new__(cls, name, bases, dct): + if name == "CompiledFunction": + original_backward = dct.get("backward", None) + + def wrapped_backward(ctx, *grad_outputs): + + assert original_backward is not None + + if enabled_patched_func: + all_args = make_backward_input(wrapped_backward.owner_class, ctx, grad_outputs) + backward_inputs.append(all_args) + + return original_backward(ctx, *grad_outputs) + + wrapped_backward.owner_class = None + dct["backward"] = staticmethod(wrapped_backward) + new_class = super().__new__(cls, name, bases, dct) + wrapped_backward.owner_class = new_class + + return new_class + + return super().__new__(cls, name, bases, dct) + + class PatchedFunction(torch.autograd.Function, metaclass=FunctionMeta): + pass + + torch.autograd.Function = PatchedFunction + + return backward_inputs + + +def unpatch_compiled_func(): + global enabled_patched_func + enabled_patched_func = False diff --git a/deepspeed/runtime/zero/compile/stage3_backend.py b/deepspeed/runtime/zero/compile/stage3_backend.py index d20004e27a6d..f7505aaa3c65 100644 --- a/deepspeed/runtime/zero/compile/stage3_backend.py +++ b/deepspeed/runtime/zero/compile/stage3_backend.py @@ -24,8 +24,9 @@ from .profilers.graph_profile import ProfilingInterpreter, MemoryProfilingInterpreter from .passes import run_opt_passes from .passes.offload_activation import offload_activation_fwd, reload_activation_bwd +from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func from .list_schedule import simple_prefetch, fast_free_schedule -from .util import get_input_nodes, get_param_nodes, NodeValueOffloadHelper, count_inflight_values, exclude_from_act_offload, OutputCaptureStack, get_activation_node_names, get_bwd_inputs +from .util import get_input_nodes, get_param_nodes, count_inflight_values, exclude_from_act_offload, get_activation_node_names from .partitioner import get_wrapped_partitioner graph_counts = defaultdict(int) @@ -71,6 +72,7 @@ def dump_graph(graph: GraphModule, name: str, skip=False): profiling_results: Dict[int, ProfilingResult] = {} +remaining_bwd_compile_count = 0 enable_opt_passes = False @@ -86,16 +88,7 @@ def make_stage3_backend(opt_passes, scheduler, offload_activation=False, dump_gr nz3 = NativeZ3Builder().load() rank = dist.get_rank() - out_capture = OutputCaptureStack() - original_call = GraphModule.__call__ - - def wrapped_call(*args, **kwargs): - out = original_call(*args, **kwargs) - if out_capture.enabled: - out_capture.push(out) - return out - - GraphModule.__call__ = wrapped_call + bwd_inputs_stack = patch_compiled_func() if scheduler == "simple_prefetch": scheduler_fn = simple_prefetch @@ -107,8 +100,6 @@ def wrapped_call(*args, **kwargs): def stage3_backend(gm: GraphModule, real_inputs): graph_id = id(gm.graph) - acc_device = torch.device(get_accelerator().current_device()) - offload_helper = NodeValueOffloadHelper(acc_device) needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs) num_original_outputs = len(get_output_node(gm.graph).args[0]) @@ -146,8 +137,10 
@@ def fw(gm, sample_inputs): if not exclude_from_act_offload(node)] gm.graph = offload_activation_fwd(gm.graph, graph_id, nodes_to_offload, graph_order, get_accelerator().available_memory(), param_manager[graph_id]) - nonlocal out_capture - out_capture.enable_for_next() + + if needs_backward: + global remaining_bwd_compile_count + remaining_bwd_compile_count += 1 nz3.register_graph(graph_id, [v[1] for v in param_indices]) # Need this before profiling @@ -156,34 +149,6 @@ def create_fwd_inputs(): profiler = ProfilingInterpreter(nz3, gm, debug_log=False) real_outputs = profiler.run(*create_fwd_inputs()) - - total_activation_size = 0 - if needs_backward: - nonlocal offload_helper - output_node = get_output_node(gm.graph) - mod_output_names = [n.name for n in get_output_node(gm.graph).args[0]] - output_name_map = {n2: n1 for n1, n2 in zip(original_output_names, mod_output_names)} - for n, v in zip(output_node.args[0], real_outputs): - # Save intermediate values on CPU for backward - # We don't move ds parameters - offload_helper.save(output_name_map[n.name], v, not hasattr(v, 'ds_id')) - if torch.is_tensor(v): - total_activation_size += v.numel() * v.element_size() - if rank == 0 and debug_log: - print(f"Total activation size graph_id={graph_id} {total_activation_size / 1024 / 1024:.2f} MB") - ops_with_mem_str = [] - for n, v in zip(output_node.args[0], real_outputs): - if torch.is_tensor(v): - size = v.numel() * v.element_size() - ops_with_mem_str.append(( - size, - f" fw output {n.name} {size / total_activation_size * 100:.1f}% {v.shape} {v.dtype} {v.device} {size / 1024 / 1024:.2f} MB" - )) - else: - ops_with_mem_str.append((0, f" fw output {n.name} {v}")) - ops_with_mem_str.sort(key=lambda x: x[0], reverse=True) - print("\n".join([x[1] for x in ops_with_mem_str])) - del profiler gc.collect() get_accelerator().empty_cache() @@ -191,10 +156,11 @@ def create_fwd_inputs(): if rank == 0 and debug_log: print(f"Fwd before scheduling graph graph_id={graph_id} {gm.graph}") - gm.graph = scheduler_fn(gm.graph, - get_accelerator().available_memory(), - total_activation_size, - debug_log=debug_log) + gm.graph = scheduler_fn( + gm.graph, + get_accelerator().available_memory(), + 0, # unused + debug_log=debug_log) gm.recompile() if rank == 0 and debug_log: @@ -226,28 +192,13 @@ def create_fwd_inputs(): param_manager, False, debug_log and rank == 0) gm.recompile() + return make_boxed_func(gm.forward) def bw(gm, sample_inputs): if rank == 0 and debug_log: print(f"Bwd initial graph graph_id={graph_id} {gm.graph}") - # We profile the memory usage, but PyTorch keeps the activation on device. - # If we allocate the memory for activation, that will double the memory usage for activation. - # So we stash the captured outputs onto CPU memory and create activation using profiler. - # This allows us to profile how memory footprint grows and shrinks during the backward pass while - # avoiding keeping the duplicated activation. 
- captured_out = out_capture.pop() - offload_args = [] - for out in captured_out: - if torch.is_tensor(out): - if not hasattr(out, 'ds_id'): - device = out.device - out.data = out.data.to("cpu") - offload_args.append((out, device)) - gc.collect() - get_accelerator().empty_cache() - assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager" param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph) @@ -261,31 +212,22 @@ def bw(gm, sample_inputs): assert len(input_nodes) == len( sample_inputs), f"Expected {len(sample_inputs)} inputs, got {len(input_nodes)}" - nonlocal offload_helper + bwd_real_inputs = bwd_inputs_stack.pop() + offload_args = {} + + for v in bwd_real_inputs: + if torch.is_tensor(v) and not hasattr(v, "ds_id"): + device = v.device + v.data = v.data.to('cpu') + offload_args[v] = device def create_bwd_inputs(): # Inputs can be destroyed during profiling. So we need to materialize them every time we run - return get_bwd_inputs(input_nodes, sample_inputs, offload_helper, acc_device) + args = [a.to(offload_args[a]) if a in offload_args else a for a in bwd_real_inputs] + return tuple(args) real_outputs = ProfilingInterpreter(nz3, gm, debug_log=False).run(*create_bwd_inputs()) - output_size = sum(v.numel() * v.element_size() for v in real_outputs if torch.is_tensor(v)) - if rank == 0 and debug_log: - print(f"Total backward grad size graph_id={graph_id} {output_size / 1024 / 1024:.2f} MB") - ops_with_mem_str = [] - output_node = get_output_node(gm.graph) - for n, v in zip(output_node.args[0], real_outputs): - if torch.is_tensor(v): - size = v.numel() * v.element_size() - ops_with_mem_str.append(( - size, - f" bw output {n.name} {size / output_size * 100:.1f}% {v.shape} {v.dtype} {v.device} {size / 1024 / 1024:.2f} MB" - )) - elif v is not None: - ops_with_mem_str.append((0, f" bw output {n.name} {v}")) - ops_with_mem_str.sort(key=lambda x: x[0], reverse=True) - print("\n".join([x[1] for x in ops_with_mem_str])) - del real_outputs gc.collect() get_accelerator().empty_cache() @@ -293,7 +235,7 @@ def create_bwd_inputs(): if rank == 0 and debug_log: print(f"Bwd before scheduling graph graph_id={graph_id} {gm.graph}") - gm.graph = scheduler_fn(gm.graph, get_accelerator().available_memory(), output_size, debug_log=debug_log) + gm.graph = scheduler_fn(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log) if rank == 0 and debug_log: print(f"Bwd after scheduling graph_id={graph_id} {gm.graph}") @@ -303,8 +245,8 @@ def create_bwd_inputs(): _, ag_wait_nodes = register_and_add_wait_allgather(graph_id, gm.graph, True) nz3.register_bwd_graph_ops(graph_id, [n.name for n in ag_wait_nodes], [len(n.args) for n in ag_wait_nodes]) - add_free_activations(graph_id, gm.graph, get_activation_node_names(gm.graph, param_nodes_bw, - offload_helper)) + # add_free_activations(graph_id, gm.graph, + # get_activation_node_names(gm.graph, param_nodes_bw, output_names[graph_id])) dump_graph(gm, f"backward_aot_scheduled_{graph_id}", skip=not dump_graphs) gm.recompile() @@ -326,11 +268,13 @@ def create_bwd_inputs(): gm = run_opt_passes(nz3, graph_id, gm, create_bwd_inputs, opt_passes, graph_order, profiling_results, param_manager, True, debug_log and rank == 0) - # Move the stashed tensors back to the - for out, device in offload_args: - out.data = out.data.to(device) + for v, device in offload_args.items(): + v.data = v.data.to(device) - offload_helper.clear() + global remaining_bwd_compile_count + remaining_bwd_compile_count -= 1 + if 
remaining_bwd_compile_count == 0: + unpatch_compiled_func() gm.recompile() return make_boxed_func(gm.forward) diff --git a/deepspeed/runtime/zero/compile/util.py b/deepspeed/runtime/zero/compile/util.py index 6586319b9d19..0be31e042b91 100644 --- a/deepspeed/runtime/zero/compile/util.py +++ b/deepspeed/runtime/zero/compile/util.py @@ -224,46 +224,15 @@ def count_inflight_values(graph: Graph, file_path: str): print(f"Data successfully written to {csv_filename}") -class OutputCaptureStack: - - def __init__(self): - self.stack = [] - self.enabled = False - - def push(self, v): - self.stack.append(v) - self.enabled = False - - def pop(self): - return self.stack.pop() - - def enable_for_next(self): - self.enabled = True - - -def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], offload_helper): +def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], fwd_output_names: List[str]): input_nodes = get_input_nodes(graph) param_node_names = set([n.name for n in param_nodes_bw]) activation_node_names = [] for in_node in input_nodes: - if offload_helper.has_value(in_node.name): + if in_node.name in fwd_output_names: if in_node.name not in param_node_names: activation_node_names.append(in_node.name) return activation_node_names - - -def get_bwd_inputs(input_nodes: List[Node], sample_inputs, offload_helper, acc_device): - validated_inputs = [] - for in_node, in_val in zip(input_nodes, sample_inputs): - if offload_helper.has_value(in_node.name): - validated_inputs.append(offload_helper.load(in_node.name)) - else: - # Here we materialize the fake value on CPU to reduce the peak memory - # The values are moved to the device memory in the profiler - validated_inputs.append(materialize_fake(in_val, device=acc_device)) - validated_inputs = tuple(validated_inputs) - - return validated_inputs
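
The core mechanism this patch adds is the metaclass hook in patch_compiled_func(): torch.autograd.Function is replaced by a subclass whose metaclass intercepts creation of the CompiledFunction class that AOTAutograd generates, wraps its backward(), and records the reconstructed backward inputs into the module-level backward_inputs list that stage3_backend later pops. Below is a minimal, self-contained sketch of that interception pattern; the names install_hook, HookMeta, HookedFunction, captured, and the toy CompiledFunction are hypothetical and exist only for illustration, not part of the patch.

import torch

captured = []  # grad_outputs recorded by every wrapped backward() call


def install_hook():
    base_meta = type(torch.autograd.Function)

    class HookMeta(base_meta):

        def __new__(cls, name, bases, dct):
            # AOTAutograd defines its autograd node as a class literally named
            # "CompiledFunction"; wrap its backward() when that class is created.
            if name == "CompiledFunction" and "backward" in dct:
                original = dct["backward"]
                if isinstance(original, staticmethod):
                    original = original.__func__

                def wrapped(ctx, *grad_outputs):
                    captured.append(grad_outputs)  # record the real backward inputs
                    return original(ctx, *grad_outputs)

                dct["backward"] = staticmethod(wrapped)
            return super().__new__(cls, name, bases, dct)

    class HookedFunction(torch.autograd.Function, metaclass=HookMeta):
        pass

    # Any later `class CompiledFunction(torch.autograd.Function)` now goes through HookMeta.
    torch.autograd.Function = HookedFunction
    return captured


# Toy usage: this stand-in plays the role of AOTAutograd's generated class.
records = install_hook()


class CompiledFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2


x = torch.ones(3, requires_grad=True)
CompiledFunction.apply(x).sum().backward()
assert len(records) == 1  # the hook observed one backward() call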
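
On the stage3_backend.py side, the hook is installed once when the backend is built and disabled after the last backward graph is compiled. The following is a hedged sketch of that lifecycle using the module path added by this patch; compile_fw and compile_bw are hypothetical placeholders for the real fw/bw compiler callbacks, and the profiling/scheduling steps are elided.

from deepspeed.runtime.zero.compile.patch_compiled_func import (
    patch_compiled_func,
    unpatch_compiled_func,
)

# Installed once per backend: returns the shared `backward_inputs` list that the
# wrapped CompiledFunction.backward() appends to at runtime.
bwd_inputs_stack = patch_compiled_func()
remaining_bwd_compile_count = 0


def compile_fw(gm, needs_backward):
    global remaining_bwd_compile_count
    if needs_backward:
        # One pending backward compilation per forward graph that requires grad.
        remaining_bwd_compile_count += 1
    # ... profile and schedule the forward graph ...


def compile_bw(gm):
    global remaining_bwd_compile_count
    # Backward compilation is triggered from inside the wrapped backward(), which has
    # already appended its reconstructed inputs, so they can be popped here.
    bwd_real_inputs = bwd_inputs_stack.pop()
    # ... profile and schedule the backward graph with bwd_real_inputs ...
    remaining_bwd_compile_count -= 1
    if remaining_bwd_compile_count == 0:
        unpatch_compiled_func()  # stop recording once every backward graph is compiled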