Skip to content

Commit

Permalink
Reduce the device bubble introduced by heavy loop synchronization in coalesced fetch/release
Browse files Browse the repository at this point in the history
  • Loading branch information
inkcherry committed Oct 31, 2024
1 parent 53584ca commit d9c3687
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
29 changes: 16 additions & 13 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def __init__(
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

@instrument_w_nvtx
def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
if self.complete:
return

Expand Down Expand Up @@ -704,24 +704,26 @@ def wait(self) -> None:
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
param.ds_status = ZeroParamStatus.AVAILABLE

for part_to_copy in partitions:
if not get_accelerator().is_synchronized_device():
part_to_copy.record_stream(get_accelerator().current_stream())
if handle_dependency:
for part_to_copy in partitions:
if not get_accelerator().is_synchronized_device():
part_to_copy.record_stream(get_accelerator().current_stream())

param_offset += ds_tensor_numel

self.complete = True
if not handle_dependency:
return partitions


class MultipleAllGatherHandles:
    """Aggregate several ``AllGatherCoalescedHandle`` objects so callers can
    wait on them as a single handle."""

    def __init__(self, handles: "List[AllGatherCoalescedHandle]"):
        # Forward-reference annotation: AllGatherCoalescedHandle is defined
        # earlier in this module; quoting avoids evaluation-order issues.
        self.handles = handles

    def wait(self, handle_dependency=True) -> None:
        """Wait on every wrapped handle.

        When ``handle_dependency`` is False the underlying handles skip the
        ``record_stream`` bookkeeping and instead return the partition
        buffers whose lifetime the caller must manage. Previously those
        buffers were silently dropped here, defeating the caller's
        dependency-buffer tracking; now they are collected and returned so
        the caller can keep them alive until the compute stream has
        consumed the gathered parameters.
        """
        collected = []
        for handle in self.handles:
            parts = handle.wait(handle_dependency)
            if parts is not None:
                collected.extend(parts)
        if not handle_dependency:
            return collected


class AllReduceCoalescedHandle:
Expand Down Expand Up @@ -1377,13 +1379,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
quantization=quant_info,
)

def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
    # Closure patched onto the parameter object: partitions this param (or an
    # explicit list of params). free_data=False keeps param.data allocated so a
    # later coalesced release can substitute its own synchronization buffer.
    cls = param
    print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                 force=False)
    targets = [cls] if param_list is None else param_list
    self._partition(targets, has_been_updated=has_been_updated, free_data=free_data)

def reduce_gradients_at_owner(param_list=None, hierarchy=0):
cls = param
Expand Down Expand Up @@ -1527,20 +1529,20 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):

return handles

def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
    """Partition each parameter in ``param_list`` and mark it NOT_AVAILABLE.

    ``free_data`` is forwarded to ``_partition_param`` so callers can defer
    freeing ``param.data`` (used by the coalesced-release path).
    """
    for param in param_list:
        print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
        # Secondary-partition pass only applies when a ZeRO param process
        # group has been configured.
        if self.zero_param_process_group is not None:
            self._partition_param_sec(param)
        self._partition_param(param, has_been_updated=has_been_updated, free_data=free_data)
        param.ds_status = ZeroParamStatus.NOT_AVAILABLE
        # if param.ds_tensor is not None:
        #    assert id(param.data) == id(param.ds_tensor.data), \
        #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
        #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
@instrument_w_nvtx
def _partition_param(self, param, buffer=None, has_been_updated=False):
def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
global reuse_buffers
print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
Expand All @@ -1565,7 +1567,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):

see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
# param.data does not store anything meaningful in partitioned state
free_param(param)
if free_data:
free_param(param)
see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
Expand Down
27 changes: 22 additions & 5 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
self.__profiler.start_event(wait_event_name)
# wait for parameters in the immediately needed submodule to become available
managed_dependency_buffer = []
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
if logger.isEnabledFor(logging.DEBUG):
Expand All @@ -312,16 +313,23 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
self.__ongoing_fetch_events.popleft().synchronize()

self.__inflight_param_registry.pop(param).wait()
if z3_leaf_module(current_submodule):
# TODO: don't use dep_buffer when overlap_comm=False.
dependency_data = self.__inflight_param_registry.pop(param).wait(handle_dependency=False)
managed_dependency_buffer.append(dependency_data)
else:
self.__inflight_param_registry.pop(param).wait()

if not get_accelerator().handles_memory_backpressure():
if not get_accelerator().handles_memory_backpressure() and not z3_leaf_module:
event = get_accelerator().Event()
event.record()
self.__ongoing_fetch_events.append(event)

assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().wait_stream(self.__allgather_stream)
if z3_leaf_module(current_submodule):
del managed_dependency_buffer
self.__profiler.stop_event(wait_event_name, wait_numel)

# kick off parameter prefetches for upcoming modules
Expand Down Expand Up @@ -404,9 +412,18 @@ def release_sub_module(self, submodule: Module) -> None:
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
coalesced_free = z3_leaf_module(submodule)
if coalesced_free:
# wait for the computation to finish and launch as early as possible.
empty_buffer = torch.empty(1, dtype=param.dtype, device=param.device)
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param)
self.__release_param(param, free_data=not coalesced_free)
if coalesced_free:
if param.ds_id in params_to_release and not param.is_external_param:
# empty buffer ensures that all computations are complete
# and is used for synchronization
param.data = empty_buffer

@instrument_w_nvtx
@torch.no_grad()
Expand Down Expand Up @@ -481,11 +498,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:

@compiler.disable
@instrument_w_nvtx
def __release_param(self, param: Parameter, free_data: bool = True) -> None:
    """Re-partition an available, no-longer-active parameter.

    ``free_data=False`` leaves ``param.data`` allocated so the caller can
    replace it with its own synchronization buffer (coalesced free path).
    """
    # Only release params that are gathered and not used by any live submodule.
    if param.ds_status != ZeroParamStatus.AVAILABLE or param.ds_active_sub_modules:
        return
    if logger.isEnabledFor(logging.DEBUG):
        debug_rank0(f"-release: {param.ds_summary()}")
    param.partition(free_data=free_data)
    self.__n_available_params -= param.ds_numel

@instrument_w_nvtx
Expand Down

0 comments on commit d9c3687

Please sign in to comment.