Skip to content

Commit

Permalink
prune ag node search
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed Oct 21, 2024
1 parent fdc5954 commit a5d3779
Showing 1 changed file with 43 additions and 25 deletions.
68 changes: 43 additions & 25 deletions deepspeed/runtime/zero/compile/list_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,46 +274,59 @@ def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug
for node in graph.nodes:
assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size"

# scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule(graph)
# tmp_scheduled, tmp_unscheduled = schedule_without_allgather(scheduled, unscheduled, edges)
# runnable = get_runnable_nodes(tmp_scheduled, tmp_unscheduled)

scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule_with_placeholders(
graph)

unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.native_z3.allgather_param]
release_nodes = {n.args[2]: n for n in unscheduled if n.target == torch.ops.native_z3.release_param}

ag_nodes_in_path = {}
for ag_node in unscheduled_ags:
last_use = node_to_last_use[ag_node]
required_nodes = get_node_requirements(last_use, scheduled)
ag_nodes_in_path[ag_node] = set(n for n in required_nodes if n.target == torch.ops.native_z3.allgather_param)

while len(unscheduled_ags) > 0:

ag_nodes_count = {ag_node: len(nodes) for ag_node, nodes in ag_nodes_in_path.items()}
count_list = sorted(set(ag_nodes_count.values()))

runnable_ags = []
for node in unscheduled_ags:
ds_id = node.args[2]
for ag_count in count_list:

schedule_until_ag = get_node_requirements(node, scheduled)
if schedule_until_ag is None:
print(f"Node {node} cannot be scheduled")
continue
target_unscheduled_ags = [ag for ag in unscheduled_ags if ag_nodes_count[ag] == ag_count]

last_use = node_to_last_use[node]
for node in target_unscheduled_ags:
ds_id = node.args[2]

diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag)
schedule_until_ag = get_node_requirements(node, scheduled)
if schedule_until_ag is None:
continue

allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag)
free_cost = sum(n.meta["device_time"] for n in diff_required_nodes)
allgathered_mem = node.meta["tensor_size"]
allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag
if n.target == torch.ops.native_z3.allgather_param)
free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes
if n.target == torch.ops.native_z3.allgather_param)
schedule_until_free = schedule_until_ag + diff_required_nodes + [release_nodes[ds_id]]
last_use = node_to_last_use[node]

n_scheduled_ags = len([n for n in schedule_until_free if n.target == torch.ops.native_z3.allgather_param])
diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag)

task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem,
last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free)
allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag)
free_cost = sum(n.meta["device_time"] for n in diff_required_nodes)
allgathered_mem = node.meta["tensor_size"]
allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag
if n.target == torch.ops.native_z3.allgather_param)
free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes
if n.target == torch.ops.native_z3.allgather_param)
schedule_until_free = schedule_until_ag + diff_required_nodes + [release_nodes[ds_id]]

# print(f" allgather runnable: {node} last_use: {node_to_last_use[node]} t1: {end_search_1-start_search_1:.2f} t2: {end_search_2-start_search_2:.2f} task: {task}")
runnable_ags.append(task)
n_scheduled_ags = len(
[n for n in schedule_until_free if n.target == torch.ops.native_z3.allgather_param])

task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem,
last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free)

# print(f" ag_count {ag_count} allgather runnable {i}: {node} last_use: {node_to_last_use[node]} t: {t2-t1:.2f}")
runnable_ags.append(task)

if len(runnable_ags) > 0:
break

assert len(runnable_ags) > 0, "No runnable allgather nodes"

Expand All @@ -339,6 +352,11 @@ def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug

unscheduled_ags.remove(next_ag.node)

ag_nodes_in_path.pop(next_ag.node)
for ag_node, nodes in ag_nodes_in_path.items():
if next_ag.node in nodes:
nodes.remove(next_ag.node)

# print(f"After ag scheduled: scheduled: {scheduled}")

scheduled_set = set(scheduled)
Expand Down

0 comments on commit a5d3779

Please sign in to comment.