From d9c36870ffbbd0acc96133dfa9e406eaa156f5c7 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Thu, 31 Oct 2024 09:56:33 +0000
Subject: [PATCH] Reduce the device bubble introduced by heavy loop
 synchronization in coalesced fetch/release

---
 .../runtime/zero/partition_parameters.py      | 29 ++++++++++---------
 .../zero/partitioned_param_coordinator.py     | 27 +++++++++++++----
 2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 22a6746bb57cf..81167117eb2da 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -672,7 +672,7 @@ def __init__(
                 raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
 
     @instrument_w_nvtx
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         if self.complete:
             return
 
@@ -704,14 +704,16 @@ def wait(self) -> None:
                     partitions.append(part_to_copy)
             param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
             param.ds_status = ZeroParamStatus.AVAILABLE
-
-            for part_to_copy in partitions:
-                if not get_accelerator().is_synchronized_device():
-                    part_to_copy.record_stream(get_accelerator().current_stream())
+            if handle_dependency:
+                for part_to_copy in partitions:
+                    if not get_accelerator().is_synchronized_device():
+                        part_to_copy.record_stream(get_accelerator().current_stream())
 
             param_offset += ds_tensor_numel
 
         self.complete = True
+        if not handle_dependency:
+            return partitions
 
 
 class MultipleAllGatherHandles:
@@ -719,9 +721,9 @@ class MultipleAllGatherHandles:
     def __init__(self, handles: List[AllGatherCoalescedHandle]):
         self.handles = handles
 
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         for handle in self.handles:
-            handle.wait()
+            handle.wait(handle_dependency)
 
 
 class AllReduceCoalescedHandle:
@@ -1377,13 +1379,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 quantization=quant_info,
             )
 
-        def partition(param_list=None, hierarchy=0, has_been_updated=False):
+        def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
            cls = param
            print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                         force=False)
            if param_list is None:
                param_list = [cls]
-           self._partition(param_list, has_been_updated=has_been_updated)
+           self._partition(param_list, has_been_updated=has_been_updated, free_data=free_data)
 
         def reduce_gradients_at_owner(param_list=None, hierarchy=0):
             cls = param
@@ -1527,12 +1529,12 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
 
         return handles
 
-    def _partition(self, param_list, force=False, has_been_updated=False):
+    def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
         for param in param_list:
             print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
             if self.zero_param_process_group is not None:
                 self._partition_param_sec(param)
-            self._partition_param(param, has_been_updated=has_been_updated)
+            self._partition_param(param, has_been_updated=has_been_updated, free_data=free_data)
 
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
             # if param.ds_tensor is not None:
@@ -1540,7 +1542,7 @@ def _partition(self, param_list, force=False, has_been_updated=False):
             #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) @instrument_w_nvtx - def _partition_param(self, param, buffer=None, has_been_updated=False): + def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" global reuse_buffers print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) @@ -1565,7 +1567,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - free_param(param) + if free_data: + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 5780b2afd6def..2b2c57d62f060 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -300,6 +300,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT self.__profiler.start_event(wait_event_name) # wait for parameters in the immediately needed submodule to become available + managed_dependency_buffer = [] for param in params_to_fetch: param.ds_active_sub_modules.add(current_submodule.id) if logger.isEnabledFor(logging.DEBUG): @@ -312,9 +313,14 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: self.__ongoing_fetch_events.popleft().synchronize() - self.__inflight_param_registry.pop(param).wait() + if z3_leaf_module(current_submodule): + # TODO: don't use dep_buffer when overlap_comm=False. + dependency_data = self.__inflight_param_registry.pop(param).wait(handle_dependency=False) + managed_dependency_buffer.append(dependency_data) + else: + self.__inflight_param_registry.pop(param).wait() - if not get_accelerator().handles_memory_backpressure(): + if not get_accelerator().handles_memory_backpressure() and not z3_leaf_module: event = get_accelerator().Event() event.record() self.__ongoing_fetch_events.append(event) @@ -322,6 +328,8 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().wait_stream(self.__allgather_stream) + if z3_leaf_module(current_submodule): + del managed_dependency_buffer self.__profiler.stop_event(wait_event_name, wait_numel) # kick off parameter prefetches for upcoming modules @@ -404,9 +412,18 @@ def release_sub_module(self, submodule: Module) -> None: params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set( p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule)))) for param in iter_params(submodule, recurse=z3_leaf_module(submodule)): + coalesced_free = z3_leaf_module(submodule) + if coalesced_free: + # wait for the computation to finish and launch as early as possible. 
+                empty_buffer = torch.empty(1, dtype=param.dtype, device=param.device)
             param.ds_active_sub_modules.discard(submodule.id)
             if param.ds_id in params_to_release and not param.is_external_param:
-                self.__release_param(param)
+                self.__release_param(param, free_data=not coalesced_free)
+            if coalesced_free:
+                if param.ds_id in params_to_release and not param.is_external_param:
+                    # the empty buffer ensures all computations on the param have
+                    # completed and serves as the synchronization point
+                    param.data = empty_buffer
 
     @instrument_w_nvtx
     @torch.no_grad()
@@ -481,11 +498,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
 
     @compiler.disable
     @instrument_w_nvtx
-    def __release_param(self, param: Parameter) -> None:
+    def __release_param(self, param: Parameter, free_data: bool = True) -> None:
         if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
             if logger.isEnabledFor(logging.DEBUG):
                 debug_rank0(f"-release: {param.ds_summary()}")
-            param.partition()
+            param.partition(free_data=free_data)
             self.__n_available_params -= param.ds_numel
 
     @instrument_w_nvtx
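
Reviewer notes

The fetch-path change replaces many per-tensor record_stream calls with a
single stream dependency when the submodule is a z3 leaf module. The
standalone sketch below contrasts the two strategies; names such as
fetch_params and allgather_stream are illustrative, not DeepSpeed APIs, and a
CUDA device is assumed:

import torch

def fetch_params(params, allgather_stream, coalesced=False):
    # produce the "gathered" tensors on a side stream (a stand-in for the
    # real all-gather that DeepSpeed issues on its allgather stream)
    allgather_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(allgather_stream):
        gathered = [p * 2 for p in params]

    if not coalesced:
        # old path: mark every tensor as in use on the consumer stream so the
        # caching allocator defers reuse of its memory; each call is cheap,
        # but issued in a hot loop over a leaf module's many parameters the
        # calls accumulate into the device bubble this patch targets
        for g in gathered:
            g.record_stream(torch.cuda.current_stream())
    # either way the consumer stream must order itself after the all-gather;
    # in the coalesced path this single wait_stream, plus holding Python
    # references (managed_dependency_buffer in the patch), is the only
    # synchronization needed
    torch.cuda.current_stream().wait_stream(allgather_stream)
    return gathered

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    params = [torch.ones(1 << 10, device="cuda") for _ in range(16)]
    out = fetch_params(params, side, coalesced=True)
    torch.cuda.synchronize()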
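
The release-path change works by repointing param.data at a one-element
placeholder instead of freeing each gathered tensor through free_param. A
minimal sketch of that idea follows; release_params is an illustrative
helper, not DeepSpeed's API, and the real __release_param additionally
re-partitions the parameter before the swap:

import torch

def release_params(params):
    # allocate the placeholder up front, mirroring empty_buffer in the patch,
    # so the swaps below issue no further allocations
    placeholder = torch.empty(1)
    for p in params:
        # repointing p.data drops the last reference to the full tensor, so
        # its memory returns to PyTorch's caching allocator in one sweep
        # rather than through per-parameter frees with their own bookkeeping
        p.data = placeholder

params = [torch.nn.Parameter(torch.randn(64, 64)) for _ in range(4)]
release_params(params)
assert all(p.data.numel() == 1 for p in params)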