diff --git a/csrc/compile/native_z3.cpp b/csrc/compile/native_z3.cpp
index 663d6d9d56c8..f2a8661dd3c9 100644
--- a/csrc/compile/native_z3.cpp
+++ b/csrc/compile/native_z3.cpp
@@ -303,6 +303,8 @@ class CustomOpExecutor {
                      at::cuda::CUDAStream ag_stream,
                      at::cuda::CUDAStream rs_stream,
                      at::cuda::CUDAStream copy_stream,
+                     at::cuda::CUDAStream offload_stream,
+                     at::cuda::CUDAStream reload_stream,
                      bool pre_div_reduce)
         : process_group_(process_group),
           param_registry_(std::move(param_registry)),
@@ -312,6 +314,8 @@ class CustomOpExecutor {
           ag_stream_(ag_stream),
           rs_stream_(rs_stream),
           copy_stream_(copy_stream),
+          offload_stream_(offload_stream),
+          reload_stream_(reload_stream),
           pre_div_reduce_(pre_div_reduce)
     {
         for (long ds_id : ds_ids_) {
@@ -530,6 +534,77 @@ class CustomOpExecutor {
         return at::Tensor();
     }

+    at::Tensor offloadTensor(at::Tensor tensor, long id)
+    {
+        if (!hasKey(offload_events_, id)) {
+            offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
+            offload_comp_done_events_[id] =
+                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
+
+            const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
+            offload_buffers_[id] = at::empty_like(tensor, options);
+        }
+
+        offload_comp_done_events_[id]->record();
+        offload_comp_done_events_[id]->block(offload_stream_);
+        {
+            at::cuda::CUDAStreamGuard guard(offload_stream_);
+            offload_buffers_.at(id).copy_(tensor, true);
+        }
+
+        tensor.record_stream(offload_stream_);
+
+        offload_events_[id]->record(offload_stream_);
+        assert(hasKey(offload_buffers_, id));
+        return offload_buffers_.at(id);
+    }
+
+    at::Tensor reloadTensor(at::Tensor tensor, long id)
+    {
+        if (!hasKey(reload_events_, id)) {
+            reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
+        }
+
+        assert(hasKey(offload_buffers_, id));
+        offload_events_[id]->block(reload_stream_);
+
+        at::Tensor ten;
+        {
+            at::cuda::CUDAStreamGuard guard(reload_stream_);
+
+            assert(hasKey(offload_buffers_, id));
+            at::Tensor buf = offload_buffers_.at(id);
+            const auto options = at::TensorOptions().device(torch::kCUDA);
+            ten = at::empty_like(buf, options);
+            ten.copy_(buf, true);
+
+            reload_buffers_[id] = ten;
+        }
+
+        reload_events_[id]->record(reload_stream_);
+        return ten;
+    }
+
+    at::Tensor waitOffload(at::Tensor tensor, long id)
+    {
+        assert(hasKey(offload_events_, id));
+        offload_events_[id]->block(at::cuda::getCurrentCUDAStream());
+
+        assert(hasKey(offload_buffers_, id));
+        return offload_buffers_.at(id);
+    }
+
+    at::Tensor waitReload(at::Tensor tensor, long id)
+    {
+        assert(hasKey(reload_events_, id));
+        reload_events_[id]->block(at::cuda::getCurrentCUDAStream());
+
+        assert(hasKey(reload_buffers_, id));
+        auto ten = reload_buffers_.at(id);
+        reload_buffers_.erase(id);
+        return ten;
+    }
+
 private:
     c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
     std::shared_ptr<DSParamRegistry> param_registry_;
@@ -539,6 +614,8 @@ class CustomOpExecutor {
     at::cuda::CUDAStream ag_stream_;
     at::cuda::CUDAStream rs_stream_;
     at::cuda::CUDAStream copy_stream_;
+    at::cuda::CUDAStream offload_stream_;
+    at::cuda::CUDAStream reload_stream_;

     GraphOpStates op_states_fwd_ = GraphOpStates();
     GraphOpStates op_states_bwd_ = GraphOpStates();
@@ -547,6 +624,13 @@ class CustomOpExecutor {
     std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_comp_done_events_;
     std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_copy_done_events_;
+    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
+    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
+    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
+    std::unordered_map<long, at::Tensor> offload_buffers_;
+
+    std::unordered_map<long, at::Tensor> reload_buffers_;
+
     size_t reduce_counter_ = 0;
     bool param_updated_ = false;

     std::unordered_map<at::ScalarType, std::vector<ReduceTask>> reduce_tasks_;
@@ -647,6 +731,8 @@ c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
 static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
 static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
 static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
+static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
+static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
 static ncclComm_t nccl_comm;
 static bool use_symm_mem;
 static bool profile = false;
@@ -683,6 +769,8 @@ void register_graph(long graph_id, const std::vector<long>& ds_ids)
                                                               ag_stream,
                                                               rs_stream,
                                                               copy_stream,
+                                                              offload_stream,
+                                                              reload_stream,
                                                               pre_div_reduce);
 }

@@ -861,7 +949,7 @@ at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
     // auto dims = tensor.sizes();
     // std::cout << "offload_tensor graph_id=" << graph_id << " id=" << id
     //           << " dim=" << join_as_str(dims, ",") << std::endl;
-    return tensor.to(at::kCPU);
+    return executors[graph_id]->offloadTensor(tensor, id);
 }

 at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
@@ -869,16 +957,17 @@ at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
     // auto dims = tensor.sizes();
     // std::cout << "reload_tensor graph_id=" << graph_id << " id=" << id
     //           << " dim=" << join_as_str(dims, ",") << std::endl;
+    return executors[graph_id]->reloadTensor(tensor, id);
+}

-    return tensor.to(at::kCUDA);
+at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
+{
+    return executors[graph_id]->waitOffload(tensor, id);
 }

-at::Tensor wait_tensor_copy(at::Tensor tensor, long graph_id, long id)
+at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
 {
-    // auto dims = tensor.sizes();
-    // std::cout << "wait_tensor_copy graph_id=" << graph_id << " id=" << id
-    //           << " dim=" << join_as_str(dims, ",") << std::endl;
-    return tensor;
+    return executors[graph_id]->waitReload(tensor, id);
 }

 void start_forward()
 {
@@ -921,7 +1010,8 @@ TORCH_LIBRARY(native_z3, m)
     m.def("free_tensors(Tensor[] a) -> ()");
     m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
     m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
-    m.def("wait_tensor_copy(Tensor a, int id, int id) -> Tensor");
+    m.def("wait_offload(Tensor a, int id, int id) -> Tensor");
+    m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
     m.def("test_call(Tensor a) -> Tensor");
 }

@@ -936,7 +1026,8 @@ TORCH_LIBRARY_IMPL(native_z3, CPU, m)
     m.impl("free_tensors", &n3z::free_tensors);
     m.impl("offload_tensor", &n3z::offload_tensor);
     m.impl("reload_tensor", &n3z::reload_tensor);
-    m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);
+    m.impl("wait_offload", &n3z::wait_offload);
+    m.impl("wait_reload", &n3z::wait_reload);
     m.impl("test_call", &n3z::test_call);
 }

@@ -951,7 +1042,8 @@ TORCH_LIBRARY_IMPL(native_z3, CUDA, m)
     m.impl("free_tensors", &n3z::free_tensors);
     m.impl("offload_tensor", &n3z::offload_tensor);
     m.impl("reload_tensor", &n3z::reload_tensor);
-    m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);
+    m.impl("wait_offload", &n3z::wait_offload);
+    m.impl("wait_reload", &n3z::wait_reload);
     m.impl("test_call", &n3z::test_call);
 }

diff --git a/deepspeed/runtime/zero/compile/passes/offload_activation.py b/deepspeed/runtime/zero/compile/passes/offload_activation.py
index 2f0139a0f73f..36f8e75434ba 100644
--- a/deepspeed/runtime/zero/compile/passes/offload_activation.py
+++ b/deepspeed/runtime/zero/compile/passes/offload_activation.py
@@ -63,7 +63,7 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_na
                                              name=f"offload_{node.name}_{val_id}")
         with graph.inserting_after(offload_node):
             wait_node = graph.create_node('call_function',
-                                          torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {},
+                                          torch.ops.native_z3.wait_offload, (offload_node, graph_id, val_id), {},
                                           name=f"wait_copy_{node.name}_{val_id}")

     output_node = get_output_node(graph)
@@ -100,7 +100,7 @@ def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], m
                                            name=f"reload_{node.name}_{val_id}")
         with graph.inserting_after(reload_node):
             wait_node = graph.create_node('call_function',
-                                          torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {},
+                                          torch.ops.native_z3.wait_reload, (reload_node, graph_id, val_id), {},
                                           name=f"wait_copy_{node.name}_{val_id}")

         # replace all uses of node with wait_node
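
For reference, a minimal sketch of the call pattern the rewritten graphs produce for the new op pair, assuming the extension is built and register_graph has already been called for the graph; the graph_id/val_id values and the helper function names here are placeholders, not part of the patch:

    import torch

    # Placeholder ids; in practice they come from the compiled graph and the value index.
    graph_id, val_id = 0, 42

    def forward_side(act: torch.Tensor) -> torch.Tensor:
        # Start the async device-to-host copy on the offload stream, then block the
        # current stream on it only where the pinned host buffer is actually needed.
        buf = torch.ops.native_z3.offload_tensor(act, graph_id, val_id)
        return torch.ops.native_z3.wait_offload(buf, graph_id, val_id)

    def backward_side(buf: torch.Tensor) -> torch.Tensor:
        # Start the async host-to-device copy on the reload stream, then synchronize
        # right before the reloaded activation is consumed.
        ten = torch.ops.native_z3.reload_tensor(buf, graph_id, val_id)
        return torch.ops.native_z3.wait_reload(ten, graph_id, val_id)

Splitting each copy into a start op and a wait op is what lets the FX passes above place the wait next to the first use, so the D2H/H2D transfers overlap with compute on the dedicated offload and reload streams.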