Commit
fix sync for offload
tohtana committed Nov 13, 2024
1 parent 8426e0a commit 424e1e6
Showing 2 changed files with 104 additions and 12 deletions.
112 changes: 102 additions & 10 deletions csrc/compile/native_z3.cpp
@@ -303,6 +303,8 @@ class CustomOpExecutor {
at::cuda::CUDAStream ag_stream,
at::cuda::CUDAStream rs_stream,
at::cuda::CUDAStream copy_stream,
at::cuda::CUDAStream offload_stream,
at::cuda::CUDAStream reload_stream,
bool pre_div_reduce)
: process_group_(process_group),
param_registry_(std::move(param_registry)),
@@ -312,6 +314,8 @@ class CustomOpExecutor {
ag_stream_(ag_stream),
rs_stream_(rs_stream),
copy_stream_(copy_stream),
offload_stream_(offload_stream),
reload_stream_(reload_stream),
pre_div_reduce_(pre_div_reduce)
{
for (long ds_id : ds_ids_) {
@@ -530,6 +534,77 @@ class CustomOpExecutor {
return at::Tensor();
}

at::Tensor offloadTensor(at::Tensor tensor, long id)
{
if (!hasKey(offload_events_, id)) {
offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
offload_comp_done_events_[id] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
offload_buffers_[id] = at::empty_like(tensor, options);
}

offload_comp_done_events_[id]->record();
offload_comp_done_events_[id]->block(offload_stream_);
{
at::cuda::CUDAStreamGuard guard(offload_stream_);
offload_buffers_.at(id).copy_(tensor, true);
}

tensor.record_stream(offload_stream_);

offload_events_[id]->record(offload_stream_);
assert(hasKey(offload_buffers_, id));
return offload_buffers_.at(id);
}

at::Tensor reloadTensor(at::Tensor tensor, long id)
{
if (!hasKey(reload_events_, id)) {
reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
}

assert(hasKey(offload_buffers_, id));
offload_events_[id]->block(reload_stream_);

at::Tensor ten;
{
at::cuda::CUDAStreamGuard guard(reload_stream_);

assert(hasKey(offload_buffers_, id));
at::Tensor buf = offload_buffers_.at(id);
const auto options = at::TensorOptions().device(torch::kCUDA);
ten = at::empty_like(buf, options);
ten.copy_(buf, true);

reload_buffers_[id] = ten;
}

reload_events_[id]->record(reload_stream_);
return ten;
}

at::Tensor waitOffload(at::Tensor tensor, long id)
{
assert(hasKey(offload_events_, id));
offload_events_[id]->block(at::cuda::getCurrentCUDAStream());

assert(hasKey(offload_buffers_, id));
return offload_buffers_.at(id);
}

at::Tensor waitReload(at::Tensor tensor, long id)
{
assert(hasKey(reload_events_, id));
reload_events_[id]->block(at::cuda::getCurrentCUDAStream());

assert(hasKey(reload_buffers_, id));
auto ten = reload_buffers_.at(id);
reload_buffers_.erase(id);
return ten;
}

private:
c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
std::shared_ptr<DSParamRegistry> param_registry_;
@@ -539,6 +614,8 @@ class CustomOpExecutor {
at::cuda::CUDAStream ag_stream_;
at::cuda::CUDAStream rs_stream_;
at::cuda::CUDAStream copy_stream_;
at::cuda::CUDAStream offload_stream_;
at::cuda::CUDAStream reload_stream_;
GraphOpStates op_states_fwd_ = GraphOpStates();
GraphOpStates op_states_bwd_ = GraphOpStates();

@@ -547,6 +624,13 @@ class CustomOpExecutor {
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_comp_done_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_copy_done_events_;

std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
std::unordered_map<long, at::Tensor> offload_buffers_;

std::unordered_map<long, at::Tensor> reload_buffers_;

size_t reduce_counter_ = 0;
bool param_updated_ = false;
std::unordered_map<at::ScalarType, std::vector<ReduceTask>> reduce_tasks_;
@@ -647,6 +731,8 @@ c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
static ncclComm_t nccl_comm;
static bool use_symm_mem;
static bool profile = false;
@@ -683,6 +769,8 @@ void register_graph(long graph_id, const std::vector<long>& ds_ids)
ag_stream,
rs_stream,
copy_stream,
offload_stream,
reload_stream,
pre_div_reduce);
}

@@ -861,24 +949,25 @@ at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
// auto dims = tensor.sizes();
// std::cout << "offload_tensor graph_id=" << graph_id << " id=" << id
// << " dim=" << join_as_str(dims, ",") << std::endl;
return tensor.to(at::kCPU);
return executors[graph_id]->offloadTensor(tensor, id);
}

at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
{
// auto dims = tensor.sizes();
// std::cout << "reload_tensor graph_id=" << graph_id << " id=" << id
// << " dim=" << join_as_str(dims, ",") << std::endl;
return executors[graph_id]->reloadTensor(tensor, id);
}

return tensor.to(at::kCUDA);
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
{
return executors[graph_id]->waitOffload(tensor, id);
}

at::Tensor wait_tensor_copy(at::Tensor tensor, long graph_id, long id)
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
{
// auto dims = tensor.sizes();
// std::cout << "wait_tensor_copy graph_id=" << graph_id << " id=" << id
// << " dim=" << join_as_str(dims, ",") << std::endl;
return tensor;
return executors[graph_id]->waitReload(tensor, id);
}

void start_forward()
@@ -921,7 +1010,8 @@ TORCH_LIBRARY(native_z3, m)
m.def("free_tensors(Tensor[] a) -> ()");
m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("wait_tensor_copy(Tensor a, int id, int id) -> Tensor");
m.def("wait_offload(Tensor a, int id, int id) -> Tensor");
m.def("wait_reload(Tensor a, int id, int id) -> Tensor");

m.def("test_call(Tensor a) -> Tensor");
}
@@ -936,7 +1026,8 @@ TORCH_LIBRARY_IMPL(native_z3, CPU, m)
m.impl("free_tensors", &n3z::free_tensors);
m.impl("offload_tensor", &n3z::offload_tensor);
m.impl("reload_tensor", &n3z::reload_tensor);
m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);
m.impl("wait_offload", &n3z::wait_offload);
m.impl("wait_reload", &n3z::wait_reload);

m.impl("test_call", &n3z::test_call);
}
@@ -951,7 +1042,8 @@ TORCH_LIBRARY_IMPL(native_z3, CUDA, m)
m.impl("free_tensors", &n3z::free_tensors);
m.impl("offload_tensor", &n3z::offload_tensor);
m.impl("reload_tensor", &n3z::reload_tensor);
m.impl("wait_tensor_copy", &n3z::wait_tensor_copy);
m.impl("wait_offload", &n3z::wait_offload);
m.impl("wait_reload", &n3z::wait_reload);

m.impl("test_call", &n3z::test_call);
}
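For reference, the sequence below is a minimal standalone sketch of the stream/event ordering that the new offloadTensor/reloadTensor/waitOffload/waitReload methods implement: record an event on the compute stream, copy device-to-host into a pinned buffer on a dedicated offload stream, copy back on a reload stream, and make each consumer stream wait on the matching event before touching the data. It is not the committed code; the includes, the function name offload_then_reload, and the use of c10::cuda::CUDAStreamGuard are assumptions made so the example is self-contained, and error handling is omitted.

#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>

// Offload a GPU tensor to pinned host memory and reload it, ordering the
// copies against the producing/consuming streams with CUDA events.
at::Tensor offload_then_reload(const at::Tensor& gpu_tensor)
{
    at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
    at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
    at::cuda::CUDAEvent comp_done(cudaEventDisableTiming);     // producer finished writing the tensor
    at::cuda::CUDAEvent offload_done(cudaEventDisableTiming);  // device-to-host copy finished
    at::cuda::CUDAEvent reload_done(cudaEventDisableTiming);   // host-to-device copy finished

    // Pinned host buffer so the copies can run asynchronously.
    at::Tensor host_buf = at::empty_like(
        gpu_tensor, at::TensorOptions().pinned_memory(true).device(torch::kCPU));

    // Offload: wait for the compute stream, then copy D2H on the offload stream.
    comp_done.record();  // records on the current (compute) stream
    comp_done.block(offload_stream);
    {
        c10::cuda::CUDAStreamGuard guard(offload_stream);
        host_buf.copy_(gpu_tensor, /*non_blocking=*/true);
    }
    gpu_tensor.record_stream(offload_stream);  // keep the GPU allocation alive until the copy completes
    offload_done.record(offload_stream);

    // Reload: wait for the offload copy, then copy H2D on the reload stream.
    offload_done.block(reload_stream);
    at::Tensor reloaded;
    {
        c10::cuda::CUDAStreamGuard guard(reload_stream);
        reloaded = at::empty_like(host_buf, at::TensorOptions().device(torch::kCUDA));
        reloaded.copy_(host_buf, /*non_blocking=*/true);
    }
    reload_done.record(reload_stream);

    // Consumers on the current stream must not read the tensor before the reload finishes.
    reload_done.block(at::cuda::getCurrentCUDAStream());
    return reloaded;
}

In the executor the same steps are split across the offload_tensor, wait_offload, reload_tensor, and wait_reload custom ops, with the events and the pinned buffer cached per tensor id so they can be reused across forward and backward.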
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/compile/passes/offload_activation.py
@@ -63,7 +63,7 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_na
name=f"offload_{node.name}_{val_id}")
with graph.inserting_after(offload_node):
wait_node = graph.create_node('call_function',
torch.ops.native_z3.wait_tensor_copy, (offload_node, graph_id, val_id), {},
torch.ops.native_z3.wait_offload, (offload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")

output_node = get_output_node(graph)
@@ -100,7 +100,7 @@ def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], m
name=f"reload_{node.name}_{val_id}")
with graph.inserting_after(reload_node):
wait_node = graph.create_node('call_function',
torch.ops.native_z3.wait_tensor_copy, (reload_node, graph_id, val_id), {},
torch.ops.native_z3.wait_reload, (reload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")

# replace all uses of node with wait_node
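With the renamed ops, offload_activation_fwd now pairs each offload_tensor node with a wait_offload node in the forward graph, and reload_activation_bwd pairs each reload_tensor node with a wait_reload node in the backward graph; the single wait_tensor_copy op, which previously just returned its input, is no longer used, and the synchronization now happens through the executor's per-id events.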
