diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 22a6746bb57c..7b57cf6b4e27 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None:
         self.__param = param
         self.__quantization = quantization
 
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         instrument_w_nvtx(self.__handle.wait)()
         if self.__quantization:
             instrument_w_nvtx(self.__quantization.quant_handle.wait)()
@@ -650,6 +650,8 @@ def wait(self) -> None:
 
 class AllGatherCoalescedHandle:
 
+    data_buffer = []
+
     def __init__(
         self,
         allgather_handle,
@@ -672,7 +674,7 @@ def __init__(
                 raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
 
     @instrument_w_nvtx
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         if self.complete:
             return
 
@@ -704,14 +706,20 @@ def wait(self) -> None:
                     partitions.append(part_to_copy)
             param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
             param.ds_status = ZeroParamStatus.AVAILABLE
-
-            for part_to_copy in partitions:
-                if not get_accelerator().is_synchronized_device():
-                    part_to_copy.record_stream(get_accelerator().current_stream())
+            if handle_dependency:
+                for part_to_copy in partitions:
+                    if not get_accelerator().is_synchronized_device():
+                        part_to_copy.record_stream(get_accelerator().current_stream())
 
             param_offset += ds_tensor_numel
 
         self.complete = True
+        if not handle_dependency:
+            AllGatherCoalescedHandle.data_buffer.append(partitions)
+
+    @staticmethod
+    def free_buffer():
+        AllGatherCoalescedHandle.data_buffer = []
 
 
 class MultipleAllGatherHandles:
@@ -719,9 +727,9 @@ class MultipleAllGatherHandles:
     def __init__(self, handles: List[AllGatherCoalescedHandle]):
         self.handles = handles
 
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         for handle in self.handles:
-            handle.wait()
+            handle.wait(handle_dependency)
 
 
 class AllReduceCoalescedHandle:
@@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 quantization=quant_info,
             )
 
-        def partition(param_list=None, hierarchy=0, has_been_updated=False):
+        def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
             cls = param
             print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                          force=False)
             if param_list is None:
                 param_list = [cls]
-            self._partition(param_list, has_been_updated=has_been_updated)
+            self._partition(param_list, has_been_updated=has_been_updated, free_data=free_data)
 
         def reduce_gradients_at_owner(param_list=None, hierarchy=0):
             cls = param
@@ -1527,12 +1535,12 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
 
         return handles
 
-    def _partition(self, param_list, force=False, has_been_updated=False):
+    def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
         for param in param_list:
             print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
             if self.zero_param_process_group is not None:
                 self._partition_param_sec(param)
-            self._partition_param(param, has_been_updated=has_been_updated)
+            self._partition_param(param, has_been_updated=has_been_updated, free_data=free_data)
 
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
             # if param.ds_tensor is not None:
@@ -1540,7 +1548,7 @@ def _partition(self, param_list, force=False, has_been_updated=False):
             #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
             #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
     @instrument_w_nvtx
-    def _partition_param(self, param, buffer=None, has_been_updated=False):
+    def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
         assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
         global reuse_buffers
         print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
@@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
 
                 see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
                 # param.data does not store anything meaningful in partitioned state
-                free_param(param)
+                if free_data:
+                    free_param(param)
                 see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)
 
                 if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 49f477cc4a1b..df02063fe392 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -322,9 +322,10 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
                     if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
                         self.__ongoing_fetch_events.popleft().synchronize()
 
-                    self.__inflight_param_registry.pop(param).wait()
+                    self.__inflight_param_registry.pop(param).wait(
+                        handle_dependency=not z3_leaf_module(current_submodule))
 
-                    if not get_accelerator().handles_memory_backpressure():
+                    if not get_accelerator().handles_memory_backpressure() and not z3_leaf_module(current_submodule):
                         event = get_accelerator().Event()
                         event.record()
                         self.__ongoing_fetch_events.append(event)
@@ -332,6 +333,8 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
             assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
         if not get_accelerator().resolves_data_dependency():
             get_accelerator().current_stream().wait_stream(self.__allgather_stream)
+        if z3_leaf_module(current_submodule):
+            AllGatherCoalescedHandle.free_buffer()
         self.__profiler.stop_event(wait_event_name, wait_numel)
 
         # kick off parameter prefetches for upcoming modules
@@ -414,9 +417,18 @@ def release_sub_module(self, submodule: Module) -> None:
         params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
            p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
         for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
+            free_data = not z3_leaf_module(submodule)
+            if not free_data:
+                # allocate the placeholder buffer as early as possible so it can overlap with computation.
+                empty_buffer = torch.empty(1, dtype=param.dtype, device=param.device)
             param.ds_active_sub_modules.discard(submodule.id)
             if param.ds_id in params_to_release and not param.is_external_param:
-                self.__release_param(param)
+                self.__release_param(param, free_data=free_data)
+            if not free_data:
+                if param.ds_id in params_to_release and not param.is_external_param:
+                    # the empty buffer ensures that all computations are complete
+                    # and is used for synchronization
+                    param.data = empty_buffer
 
     @instrument_w_nvtx
     @torch.no_grad()
@@ -491,11 +503,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
 
     @compiler.disable
     @instrument_w_nvtx
-    def __release_param(self, param: Parameter) -> None:
+    def __release_param(self, param: Parameter, free_data: bool = True) -> None:
         if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
             if logger.isEnabledFor(logging.DEBUG):
                 debug_rank0(f"-release: {param.ds_summary()}")
-            param.partition()
+            param.partition(free_data=free_data)
             self.__n_available_params -= param.ds_numel
 
     @instrument_w_nvtx
diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py
index 74c709883645..b3c85a67a4e3 100644
--- a/tests/unit/runtime/zero/test_zero_leaf_module.py
+++ b/tests/unit/runtime/zero/test_zero_leaf_module.py
@@ -210,7 +210,7 @@ def test_finegrained_optimization(self, module_granularity_threshold: int):
             },
             "zero_optimization": {
                 "stage": 3,
-                "stage3_prefetch_bucket_size": hidden_dim**2,
+                "stage3_prefetch_bucket_size": hidden_dim,
                 "stage3_param_persistence_threshold": 0,
                 "stage3_max_reuse_distance": 0,
             }
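For context, the deferred-free path introduced above is only taken for modules flagged as ZeRO-3 leaf modules; everything else keeps the existing record_stream/free_param behavior. The sketch below is not part of this diff: it shows how a user opts a block into that path with the existing deepspeed.utils.set_z3_leaf_modules helper, under the assumption of hypothetical MyExpertBlock/MyModel classes, a toy config, and a distributed launcher (e.g. the deepspeed CLI) providing the process group.

# Sketch only: MyExpertBlock/MyModel and the config values are illustrative assumptions,
# not part of this patch. set_z3_leaf_modules sets the flag that z3_leaf_module() checks,
# so fetch_sub_module()/release_sub_module() take the handle_dependency=False / free_data=False
# path added in this diff for parameters under the flagged block.
import torch
import deepspeed
from deepspeed.utils import set_z3_leaf_modules


class MyExpertBlock(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


class MyModel(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.block = MyExpertBlock(hidden_dim)
        self.head = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.head(self.block(x))


model = MyModel(hidden_dim=64)

# Mark MyExpertBlock as a ZeRO-3 leaf module: its parameters are gathered and released
# at the block level, and with this patch their gathered data is only dropped after the
# compute stream has synchronized with the all-gather stream (AllGatherCoalescedHandle.free_buffer).
set_z3_leaf_modules(model, [MyExpertBlock])

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config={
                                           "train_micro_batch_size_per_gpu": 1,
                                           "optimizer": {
                                               "type": "Adam",
                                               "params": {
                                                   "lr": 1e-4
                                               }
                                           },
                                           "zero_optimization": {
                                               "stage": 3
                                           }
                                       })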