
Commit

Reduce the device bubble introduced by heavy loop synchronization in coalesced fetch/release(z3_leaf_module) (#6694)

Depends on #6649.

When performing fetch/release operations on Z3 leaf modules, the per-parameter loop takes excessively long for fine-grained modules. Compared to non-leaf modules, a Z3 leaf module may contain a much larger number of parameters, so even though each loop iteration is cheap, the total loop time can be significant.

![image](https://github.com/user-attachments/assets/9891835a-2620-47f3-aba6-ea22b8905d1c)
**The fetch time is impacted by:**

- Post-allgather operations (narrow, slice, cat; difficult to avoid)
- Memory-pressure handling (record_stream, fetch event creation & sync)

**The release time is impacted by:**

- slice
- record_stream on the freed parameter

Because these leaf modules are fine-grained and each parameter is relatively small, we can treat the parameters within each leaf module as a single unit when handling memory pressure. This approach roughly halves the CPU time required for fetch/release operations.
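
The diff below implements this by parking the gathered partitions in a class-level buffer and retiring them with a single stream dependency per leaf module. A minimal sketch of the pattern, with illustrative names rather than the actual DeepSpeed classes:

```python
# Minimal sketch of the batched-dependency idea (illustrative names, assuming
# PyTorch streams; the real implementation lives in AllGatherCoalescedHandle).
import torch

pending_buffers = []  # gathered partitions that must outlive in-flight copies

def wait_without_per_param_bookkeeping(handle, partitions):
    handle.wait()                       # finish the all-gather for this handle
    pending_buffers.append(partitions)  # defer record_stream/event work per parameter

def finish_leaf_module(compute_stream, allgather_stream):
    # One stream-ordering point for the whole leaf module instead of one
    # record_stream/event per parameter, then drop the parked buffers.
    compute_stream.wait_stream(allgather_stream)
    pending_buffers.clear()
```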

---------

Co-authored-by: Ma, Guokai <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
4 people authored Jan 6, 2025
1 parent c5e48f4 commit b0040b6
Showing 4 changed files with 69 additions and 41 deletions.
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/mics.py
@@ -38,7 +38,7 @@ class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle):
def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None:
super().__init__(allgather_handle, params, partitions, world_size)

def wait(self) -> None:
def wait(self, **kwargs) -> None:
"""
"""
# let the current stream to op
19 changes: 11 additions & 8 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -145,6 +145,16 @@ def __init__(
module.ds_inflight_param_registry = InflightParamRegistry()
self.__inflight_param_registry = module.ds_inflight_param_registry

self.fast_sharding_for_leaf_module = False

if zero_module_granularity_threshold > 0:
self.min_granularity_value = sys.maxsize
self.min_granularity_layer = None
self.granularity_info = set()
self.z3_leaf_layers = []
self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
self.fast_sharding_for_leaf_module = True

self.param_coordinator = PartitionedParameterCoordinator(
prefetch_bucket_sz=self._prefetch_bucket_sz,
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
@@ -155,14 +165,7 @@ def __init__(
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
)

if zero_module_granularity_threshold > 0:
self.min_granularity_value = sys.maxsize
self.min_granularity_layer = None
self.granularity_info = set()
self.z3_leaf_layers = []
self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)

self.forward_hooks = []
self.backward_hooks = []
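
The hunk above moves the threshold-based leaf-module detection ahead of the coordinator construction so that the new `fast_sharding_for_leaf_module` flag can be handed to it. For context, a hedged sketch of how a user might enable this path, assuming the granularity threshold is exposed through the ZeRO stage-3 config added by #6649 (the key name below is an assumption, not shown in this diff):

```python
# Assumed ZeRO stage-3 config; a threshold > 0 marks sufficiently fine-grained
# submodules as Z3 leaf modules, which this commit then fetches/releases through
# the coalesced fast path.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "stage3_module_granularity_threshold": 500_000,  # assumed key from #6649
    },
}

# engine, *_ = deepspeed.initialize(model=model, config=ds_config)
```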
39 changes: 24 additions & 15 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -55,7 +55,7 @@ def __init__(self, param: Parameter) -> None:
non_blocking=True).view(param.ds_shape)
self.__param = param

def wait(self) -> None:
def wait(self, **kwargs) -> None:
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE
@@ -78,7 +78,7 @@ def __init__(self, params: List[Parameter]) -> None:
non_blocking=True).view(param.ds_shape)

@instrument_w_nvtx
def wait(self) -> None:
def wait(self, **kwargs) -> None:
if self.__complete:
return

@@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None:
self.__param = param
self.__quantization = quantization

def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
instrument_w_nvtx(self.__handle.wait)()
if self.__quantization:
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
@@ -650,6 +650,8 @@ def wait(self) -> None:

class AllGatherCoalescedHandle:

data_buffer = []

def __init__(
self,
allgather_handle,
@@ -672,7 +674,7 @@ def __init__(
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

@instrument_w_nvtx
def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
if self.complete:
return

@@ -704,24 +706,30 @@ def wait(self) -> None:
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
param.ds_status = ZeroParamStatus.AVAILABLE

for part_to_copy in partitions:
if not get_accelerator().is_synchronized_device():
if not get_accelerator().is_synchronized_device() and handle_dependency:
for part_to_copy in partitions:
part_to_copy.record_stream(get_accelerator().current_stream())

param_offset += ds_tensor_numel

self.complete = True
if not get_accelerator().is_synchronized_device() and not handle_dependency:
# if the device needs to handle dependencies and opts for explicit processing outside the function.
AllGatherCoalescedHandle.data_buffer.append(partitions)

@staticmethod
def free_buffer():
AllGatherCoalescedHandle.data_buffer = []


class MultipleAllGatherHandles:

def __init__(self, handles: List[AllGatherCoalescedHandle]):
self.handles = handles

def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
for handle in self.handles:
handle.wait()
handle.wait(handle_dependency)


class AllReduceCoalescedHandle:
@@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
quantization=quant_info,
)

def partition(param_list=None, hierarchy=0, has_been_updated=False):
def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
cls = param
print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
force=False)
if param_list is None:
param_list = [cls]
self._partition(param_list, has_been_updated=has_been_updated)
self._partition(param_list, has_been_updated=has_been_updated, free_data=True)

def reduce_gradients_at_owner(param_list=None, hierarchy=0):
cls = param
@@ -1527,20 +1535,20 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):

return handles

def _partition(self, param_list, force=False, has_been_updated=False):
def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
for param in param_list:
print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
if self.zero_param_process_group is not None:
self._partition_param_sec(param)
self._partition_param(param, has_been_updated=has_been_updated)
self._partition_param(param, has_been_updated=has_been_updated, free_data=True)

param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# if param.ds_tensor is not None:
# assert id(param.data) == id(param.ds_tensor.data), \
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
@instrument_w_nvtx
def _partition_param(self, param, buffer=None, has_been_updated=False):
def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
global reuse_buffers
print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
@@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):

see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
# param.data does not store anything meaningful in partitioned state
free_param(param)
if free_data:
free_param(param)
see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
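
Taken together, the new `handle_dependency` flag and the class-level `data_buffer` are intended to be used in the pattern below; this is a simplified sketch mirroring `fetch_sub_module` in the next file, not a new public API:

```python
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import AllGatherCoalescedHandle

def wait_on_leaf_module(handles, allgather_stream):
    # Skip per-partition record_stream while waiting; the gathered buffers are
    # parked in AllGatherCoalescedHandle.data_buffer instead of being tracked
    # tensor by tensor.
    for handle in handles:
        handle.wait(handle_dependency=False)

    # A single stream-ordering point covers the whole leaf module...
    get_accelerator().current_stream().wait_stream(allgather_stream)
    # ...after which the parked buffers can be released in one shot.
    AllGatherCoalescedHandle.free_buffer()
```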
50 changes: 33 additions & 17 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -76,18 +76,17 @@ class __ParamInTrace:
param: Parameter
step_id_last_used_at: int

def __init__(
self,
prefetch_bucket_sz: int,
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
timers=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
) -> None:
def __init__(self,
prefetch_bucket_sz: int,
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
timers=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
fast_sharding_for_leaf_module=False) -> None:
# mapping of param -> handle for each param that is currently in flight
self.__inflight_param_registry = inflight_param_registry
# keeps track of the number of submodules invoked so far.
@@ -130,6 +129,10 @@ def __init__(
self.__max_ongoing_fetch_events: int = 2
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)

# whether to enable fast fetch for the z3 leaf module.
# this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module

"""Tracing and Tracking
TODO. consider performing trace before initializing PartitionedParameterCoordinator
and passing trace results into constructor. This way all the code in here can
@@ -308,6 +311,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
wait_numel = 0
wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
self.__profiler.start_event(wait_event_name)
fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
@@ -321,16 +325,18 @@ def fetch_sub_module(current_submodule: Module, forward: bool) -> None:
if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
self.__ongoing_fetch_events.popleft().synchronize()

self.__inflight_param_registry.pop(param).wait()
self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch)

if not get_accelerator().handles_memory_backpressure():
if not get_accelerator().handles_memory_backpressure() and not fast_fetch:
event = get_accelerator().Event()
event.record()
self.__ongoing_fetch_events.append(event)

assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().wait_stream(self.__allgather_stream)
if fast_fetch:
AllGatherCoalescedHandle.free_buffer()
self.__profiler.stop_event(wait_event_name, wait_numel)

# kick off parameter prefetches for upcoming modules
@@ -412,10 +418,20 @@ def release_sub_module(self, submodule: Module) -> None:
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))

free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
if not free_data:
# wait for the computation to finish and launch as early as possible.
empty_buffer = torch.empty(1, device=get_accelerator().current_device())

for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param)
self.__release_param(param, free_data)
if not free_data:
if param.ds_id in params_to_release and not param.is_external_param:
# empty buffer ensures that all computations are complete
param.data = empty_buffer

@instrument_w_nvtx
@torch.no_grad()
@@ -490,11 +506,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:

@compiler.disable
@instrument_w_nvtx
def __release_param(self, param: Parameter) -> None:
def __release_param(self, param: Parameter, free_data: bool = True) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
param.partition()
param.partition(free_data=free_data)
self.__n_available_params -= param.ds_numel

@instrument_w_nvtx
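
On the release side, the fast path avoids one `record_stream` call per parameter: `partition(free_data=False)` repartitions without freeing the gathered storage, and repointing `param.data` at a single 1-element buffer afterwards lets that storage be reclaimed through normal reference counting once the compute stream has caught up. A rough sketch of the idea (simplified; `partition` here stands for the DeepSpeed-injected method shown above):

```python
import torch

def fast_release(leaf_params, device):
    # One tiny shared buffer for the whole leaf module, allocated up front as in
    # release_sub_module above.
    empty_buffer = torch.empty(1, device=device)
    for p in leaf_params:
        p.partition(free_data=False)  # repartition, but keep the full tensor alive for now
        p.data = empty_buffer         # drop the last reference to the gathered storage
```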
