diff --git a/deepspeed/runtime/zero/compile/fx.py b/deepspeed/runtime/zero/compile/fx.py
index 1061320fe0f7..7b405cb94208 100644
--- a/deepspeed/runtime/zero/compile/fx.py
+++ b/deepspeed/runtime/zero/compile/fx.py
@@ -193,9 +193,16 @@ def add_free_activations(graph_id: int, graph: Graph, activation_node_names: Lis
     for node, last_user in node_to_last_use.items():
         last_user_to_uses[last_user].append(node)
 
+    def _should_free(node: Node) -> bool:
+        if not hasattr(node, "meta"):
+            return False
+        if not "tensor_meta" in node.meta:
+            return False
+        return True
+
     for last_user, used_nodes in last_user_to_uses.items():
+        activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]
-        activation_args = [an for an in used_nodes if an in activation_nodes_set]
 
         if len(activation_args) == 0:
             continue
 
diff --git a/deepspeed/runtime/zero/compile/passes/offload_activation.py b/deepspeed/runtime/zero/compile/passes/offload_activation.py
index 9cab9608a888..afbcebf69012 100644
--- a/deepspeed/runtime/zero/compile/passes/offload_activation.py
+++ b/deepspeed/runtime/zero/compile/passes/offload_activation.py
@@ -1,4 +1,4 @@
-from typing import List, Callable, Any, Dict, Tuple, Set
+from typing import List, Dict, Set
 import random
 from collections import defaultdict
 
@@ -8,15 +8,14 @@
 from ..fx import get_output_node, get_last_uses, move_primals_to_head
 from ..graph_param import DSGraphParamManager
 
-
-OFFLOAD_THRESHOLD = 1024 * 1024 # 1MB
-
+OFFLOAD_THRESHOLD = 1024 * 1024  # 1MB
 
 value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
 used_ids: Set[int] = set()
 
 
 def get_random_id() -> int:
+
     def _gen():
         # generate random int
         return random.randint(10000, 2**31)
@@ -41,7 +40,7 @@ def _should_offload(node: Node) -> bool:
         return False
     if not "tensor_meta" in node.meta:
         return False
-
+
     shape = node.meta["tensor_meta"].shape
     numel = 1
     for s in shape:
@@ -49,8 +48,8 @@ def _should_offload(node: Node) -> bool:
     return numel > OFFLOAD_THRESHOLD
 
 
-def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[Node], graph_order: List[int], mem_budget: float,
-                           param_manager: DSGraphParamManager) -> Graph:
+def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[Node], graph_order: List[int],
+                           mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
 
     # output_node = get_output_node(graph)
     # output_value_nodes = set(output_node.args[0])
@@ -59,7 +58,7 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[N
     import copy
     cl_graph = copy.deepcopy(graph)
     cl_graph.erase_node(get_output_node(cl_graph))
-
+
     node_to_last_use, _ = get_last_uses(cl_graph)
     node_name_to_last_use_name = {k.name: v.name for k, v in node_to_last_use.items()}
     name_to_node = {n.name: n for n in graph.nodes}
@@ -74,9 +73,13 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[N
         val_id = get_random_id()
 
         with graph.inserting_after(node):
-            offload_node = graph.create_node('call_function', torch.ops.native_z3.offload_tensor, (node, graph_id, val_id), {}, name=f"offload_{node.name}_{val_id}")
+            offload_node = graph.create_node('call_function',
+                                             torch.ops.native_z3.offload_tensor, (node, graph_id, val_id), {},
+                                             name=f"offload_{node.name}_{val_id}")
         with graph.inserting_after(offload_node):
-            wait_node = graph.create_node('call_function', torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {}, name=f"wait_copy_{node.name}_{val_id}")
+            wait_node = graph.create_node('call_function',
+                                          torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {},
+                                          name=f"wait_copy_{node.name}_{val_id}")
 
         output_node = get_output_node(graph)
         output_node.replace_input_with(node, wait_node)
@@ -90,8 +93,8 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[N
 
 
 def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
-                           param_manager: DSGraphParamManager) -> Graph:
-
+                          param_manager: DSGraphParamManager) -> Graph:
+
     graph_value_to_id = value_to_id[graph_id]
     act_names = set(graph_value_to_id.keys())
     act_nodes = [n for n in graph.nodes if n.name in act_names]
@@ -107,9 +110,13 @@ def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], m
         val_id = graph_value_to_id[node.name]
 
        with graph.inserting_before(node_to_first_user[node]):
-            reload_node = graph.create_node('call_function', torch.ops.native_z3.reload_tensor, (node, graph_id, val_id), {}, name=f"reload_{node.name}_{val_id}")
+            reload_node = graph.create_node('call_function',
+                                            torch.ops.native_z3.reload_tensor, (node, graph_id, val_id), {},
+                                            name=f"reload_{node.name}_{val_id}")
         with graph.inserting_after(reload_node):
-            wait_node = graph.create_node('call_function', torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {}, name=f"wait_copy_{node.name}_{val_id}")
+            wait_node = graph.create_node('call_function',
+                                          torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {},
+                                          name=f"wait_copy_{node.name}_{val_id}")
 
         # replace all uses of node with wait_node
         users = {}
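For reference, the forward-pass rewrite above always follows the same pattern: insert an offload node right after the activation's producer, a wait node right after the offload, and then re-point only the graph output at the wait node (the backward pass later replaces the remaining uses via reload_activation_bwd). Below is a minimal, self-contained sketch of that insertion pattern on a toy torch.fx graph. fake_offload and fake_wait are hypothetical stand-ins for the torch.ops.native_z3.offload_tensor / wait_tensor_copy custom ops, which are only available once the DeepSpeed extension is built; everything else uses the stock torch.fx API shown in the diff.

import torch
from torch import fx


def fake_offload(t: torch.Tensor, graph_id: int, val_id: int) -> torch.Tensor:
    # Hypothetical stand-in for torch.ops.native_z3.offload_tensor: start a copy to host memory.
    return t.to("cpu", non_blocking=True)


def fake_wait(t: torch.Tensor, graph_id: int, val_id: int) -> torch.Tensor:
    # Hypothetical stand-in for torch.ops.native_z3.wait_tensor_copy: wait for the copy to finish.
    return t


class ToyModel(torch.nn.Module):

    def forward(self, x):
        act = torch.relu(x)  # pretend this activation was selected for offloading
        return act * 2, act


gm = fx.symbolic_trace(ToyModel())
graph = gm.graph
node = next(n for n in graph.nodes if n.name == "relu")

# Same insertion pattern as offload_activation_fwd: offload right after the producer,
# wait right after the offload, then make only the graph output consume the wait node.
with graph.inserting_after(node):
    offload_node = graph.create_node('call_function', fake_offload, (node, 0, 123), {}, name="offload_relu_123")
with graph.inserting_after(offload_node):
    wait_node = graph.create_node('call_function', fake_wait, (offload_node, 0, 123), {}, name="wait_copy_relu_123")

output_node = next(n for n in graph.nodes if n.op == "output")
output_node.replace_input_with(node, wait_node)

graph.lint()
gm.recompile()
print(gm.code)  # the generated forward now routes the saved activation through offload/wait

Note that only the output node is redirected here, so intermediate consumers (act * 2) still read the original tensor; it is the value handed to the backward graph that goes through the offload/wait pair, mirroring what the pass does for activations saved across the forward/backward boundary.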