Skip to content

Commit

Permalink
fix freeing activation
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed Nov 13, 2024
1 parent 6738f5b commit e779af5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
9 changes: 8 additions & 1 deletion deepspeed/runtime/zero/compile/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,16 @@ def add_free_activations(graph_id: int, graph: Graph, activation_node_names: Lis
for node, last_user in node_to_last_use.items():
last_user_to_uses[last_user].append(node)

def _should_free(node: Node) -> bool:
if not hasattr(node, "meta"):
return False
if not "tensor_meta" in node.meta:
return False
return True

for last_user, used_nodes in last_user_to_uses.items():
activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]

activation_args = [an for an in used_nodes if an in activation_nodes_set]
if len(activation_args) == 0:
continue

Expand Down
35 changes: 21 additions & 14 deletions deepspeed/runtime/zero/compile/passes/offload_activation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Callable, Any, Dict, Tuple, Set
from typing import List, Dict, Set
import random
from collections import defaultdict

Expand All @@ -8,15 +8,14 @@
from ..fx import get_output_node, get_last_uses, move_primals_to_head
from ..graph_param import DSGraphParamManager


OFFLOAD_THRESHOLD = 1024 * 1024 # 1MB

OFFLOAD_THRESHOLD = 1024 * 1024 # 1MB

value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
used_ids: Set[int] = set()


def get_random_id() -> int:

def _gen():
# generate random int
return random.randint(10000, 2**31)
Expand All @@ -41,16 +40,16 @@ def _should_offload(node: Node) -> bool:
return False
if not "tensor_meta" in node.meta:
return False

shape = node.meta["tensor_meta"].shape
numel = 1
for s in shape:
numel *= s
return numel > OFFLOAD_THRESHOLD


def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[Node], graph_order: List[int], mem_budget: float,
param_manager: DSGraphParamManager) -> Graph:
def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[Node], graph_order: List[int],
mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
# output_node = get_output_node(graph)
# output_value_nodes = set(output_node.args[0])

Expand All @@ -59,7 +58,7 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[N
import copy
cl_graph = copy.deepcopy(graph)
cl_graph.erase_node(get_output_node(cl_graph))

node_to_last_use, _ = get_last_uses(cl_graph)
node_name_to_last_use_name = {k.name: v.name for k, v in node_to_last_use.items()}
name_to_node = {n.name: n for n in graph.nodes}
Expand All @@ -74,9 +73,13 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[N

val_id = get_random_id()
with graph.inserting_after(node):
offload_node = graph.create_node('call_function', torch.ops.native_z3.offload_tensor, (node, graph_id, val_id), {}, name=f"offload_{node.name}_{val_id}")
offload_node = graph.create_node('call_function',
torch.ops.native_z3.offload_tensor, (node, graph_id, val_id), {},
name=f"offload_{node.name}_{val_id}")
with graph.inserting_after(offload_node):
wait_node = graph.create_node('call_function', torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {}, name=f"wait_copy_{node.name}_{val_id}")
wait_node = graph.create_node('call_function',
torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")

output_node = get_output_node(graph)
output_node.replace_input_with(node, wait_node)
Expand All @@ -90,8 +93,8 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload: List[N


def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
param_manager: DSGraphParamManager) -> Graph:
param_manager: DSGraphParamManager) -> Graph:

graph_value_to_id = value_to_id[graph_id]
act_names = set(graph_value_to_id.keys())
act_nodes = [n for n in graph.nodes if n.name in act_names]
Expand All @@ -107,9 +110,13 @@ def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], m
val_id = graph_value_to_id[node.name]

with graph.inserting_before(node_to_first_user[node]):
reload_node = graph.create_node('call_function', torch.ops.native_z3.reload_tensor, (node, graph_id, val_id), {}, name=f"reload_{node.name}_{val_id}")
reload_node = graph.create_node('call_function',
torch.ops.native_z3.reload_tensor, (node, graph_id, val_id), {},
name=f"reload_{node.name}_{val_id}")
with graph.inserting_after(reload_node):
wait_node = graph.create_node('call_function', torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {}, name=f"wait_copy_{node.name}_{val_id}")
wait_node = graph.create_node('call_function',
torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")

# replace all uses of node with wait_node
users = {}
Expand Down

0 comments on commit e779af5

Please sign in to comment.