From a5d3779eadf9723ede9958c77b5ab0c154c44d72 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 21 Oct 2024 23:54:17 +0000 Subject: [PATCH] prune ag node search --- .../runtime/zero/compile/list_schedule.py | 68 ++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/deepspeed/runtime/zero/compile/list_schedule.py b/deepspeed/runtime/zero/compile/list_schedule.py index e5ac3e8062cf..fb915cbe1534 100644 --- a/deepspeed/runtime/zero/compile/list_schedule.py +++ b/deepspeed/runtime/zero/compile/list_schedule.py @@ -274,46 +274,59 @@ def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug for node in graph.nodes: assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size" - # scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule(graph) - # tmp_scheduled, tmp_unscheduled = schedule_without_allgather(scheduled, unscheduled, edges) - # runnable = get_runnable_nodes(tmp_scheduled, tmp_unscheduled) - scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule_with_placeholders( graph) unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.native_z3.allgather_param] release_nodes = {n.args[2]: n for n in unscheduled if n.target == torch.ops.native_z3.release_param} + ag_nodes_in_path = {} + for ag_node in unscheduled_ags: + last_use = node_to_last_use[ag_node] + required_nodes = get_node_requirements(last_use, scheduled) + ag_nodes_in_path[ag_node] = set(n for n in required_nodes if n.target == torch.ops.native_z3.allgather_param) + while len(unscheduled_ags) > 0: + + ag_nodes_count = {ag_node: len(nodes) for ag_node, nodes in ag_nodes_in_path.items()} + count_list = sorted(set(ag_nodes_count.values())) + runnable_ags = [] - for node in unscheduled_ags: - ds_id = node.args[2] + for ag_count in count_list: - schedule_until_ag = get_node_requirements(node, scheduled) - if schedule_until_ag is None: - print(f"Node {node} cannot be scheduled") - continue + target_unscheduled_ags = [ag for ag in unscheduled_ags if ag_nodes_count[ag] == ag_count] - last_use = node_to_last_use[node] + for node in target_unscheduled_ags: + ds_id = node.args[2] - diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag) + schedule_until_ag = get_node_requirements(node, scheduled) + if schedule_until_ag is None: + continue - allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag) - free_cost = sum(n.meta["device_time"] for n in diff_required_nodes) - allgathered_mem = node.meta["tensor_size"] - allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag - if n.target == torch.ops.native_z3.allgather_param) - free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes - if n.target == torch.ops.native_z3.allgather_param) - schedule_until_free = schedule_until_ag + diff_required_nodes + [release_nodes[ds_id]] + last_use = node_to_last_use[node] - n_scheduled_ags = len([n for n in schedule_until_free if n.target == torch.ops.native_z3.allgather_param]) + diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag) - task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem, - last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free) + allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag) + free_cost = sum(n.meta["device_time"] for n in diff_required_nodes) + allgathered_mem = node.meta["tensor_size"] + allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag + if n.target == torch.ops.native_z3.allgather_param) + free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes + if n.target == torch.ops.native_z3.allgather_param) + schedule_until_free = schedule_until_ag + diff_required_nodes + [release_nodes[ds_id]] - # print(f" allgather runnable: {node} last_use: {node_to_last_use[node]} t1: {end_search_1-start_search_1:.2f} t2: {end_search_2-start_search_2:.2f} task: {task}") - runnable_ags.append(task) + n_scheduled_ags = len( + [n for n in schedule_until_free if n.target == torch.ops.native_z3.allgather_param]) + + task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem, + last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free) + + # print(f" ag_count {ag_count} allgather runnable {i}: {node} last_use: {node_to_last_use[node]} t: {t2-t1:.2f}") + runnable_ags.append(task) + + if len(runnable_ags) > 0: + break assert len(runnable_ags) > 0, "No runnable allgather nodes" @@ -339,6 +352,11 @@ def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug unscheduled_ags.remove(next_ag.node) + ag_nodes_in_path.pop(next_ag.node) + for ag_node, nodes in ag_nodes_in_path.items(): + if next_ag.node in nodes: + nodes.remove(next_ag.node) + # print(f"After ag scheduled: scheduled: {scheduled}") scheduled_set = set(scheduled)