Skip to content

Commit

Permalink
opt loop
Browse files Browse the repository at this point in the history
  • Loading branch information
inkcherry committed Nov 12, 2024
1 parent 25df962 commit 663e637
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
37 changes: 23 additions & 14 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None:
self.__param = param
self.__quantization = quantization

def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
instrument_w_nvtx(self.__handle.wait)()
if self.__quantization:
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
Expand All @@ -650,6 +650,8 @@ def wait(self) -> None:

class AllGatherCoalescedHandle:

data_buffer = []

def __init__(
self,
allgather_handle,
Expand All @@ -672,7 +674,7 @@ def __init__(
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

@instrument_w_nvtx
def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
if self.complete:
return

Expand Down Expand Up @@ -704,24 +706,30 @@ def wait(self) -> None:
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
param.ds_status = ZeroParamStatus.AVAILABLE

for part_to_copy in partitions:
if not get_accelerator().is_synchronized_device():
part_to_copy.record_stream(get_accelerator().current_stream())
if handle_dependency:
for part_to_copy in partitions:
if not get_accelerator().is_synchronized_device():
part_to_copy.record_stream(get_accelerator().current_stream())

param_offset += ds_tensor_numel

self.complete = True
if not handle_dependency:
AllGatherCoalescedHandle.data_buffer.append(partitions)

@staticmethod
def free_buffer():
    # Drop the partition tensors that wait(handle_dependency=False) parked on
    # the class-level data_buffer (they are appended there so the gathered
    # shards stay referenced after wait() returns).
    # NOTE(review): rebinding (rather than .clear()) leaves any list object a
    # caller may still hold untouched — presumably intentional; confirm.
    AllGatherCoalescedHandle.data_buffer = []


class MultipleAllGatherHandles:
    """Aggregate several coalesced all-gather handles behind a single wait().

    The scrape interleaved the pre- and post-commit signatures of ``wait``;
    this is the coherent post-commit version.
    """

    def __init__(self, handles: "List[AllGatherCoalescedHandle]"):
        # Forward-reference annotation: avoids requiring the handle class to
        # be importable where this class is defined.
        self.handles = handles

    def wait(self, handle_dependency=True) -> None:
        """Wait on every underlying handle.

        Args:
            handle_dependency: forwarded to each handle's ``wait``; when
                False the handles skip per-partition stream bookkeeping and
                defer buffer release to the caller (see
                ``AllGatherCoalescedHandle.free_buffer``).
        """
        for handle in self.handles:
            handle.wait(handle_dependency)


class AllReduceCoalescedHandle:
Expand Down Expand Up @@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
quantization=quant_info,
)

def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
    """Partition the closed-over ``param`` (or each param in *param_list*).

    Args:
        param_list: parameters to partition; defaults to ``[param]`` from
            the enclosing scope.
        hierarchy: nesting depth, used only to indent the debug log line.
        has_been_updated: forwarded to ``self._partition``.
        free_data: forwarded to ``self._partition``; when False the
            partitioned ``param.data`` tensor is left allocated (see the
            ``if free_data: free_param(param)`` guard in ``_partition_param``).
    """
    cls = param
    print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                 force=False)
    if param_list is None:
        param_list = [cls]
    # BUGFIX: forward the caller's free_data instead of hard-coding True —
    # otherwise partition(free_data=False) still frees param.data.
    self._partition(param_list, has_been_updated=has_been_updated, free_data=free_data)

def reduce_gradients_at_owner(param_list=None, hierarchy=0):
cls = param
Expand Down Expand Up @@ -1527,20 +1535,20 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):

return handles

def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
    """Partition each parameter in *param_list* and mark it NOT_AVAILABLE.

    Args:
        param_list: parameters to partition.
        force: accepted for interface compatibility; not consulted here.
        has_been_updated: forwarded to ``_partition_param``.
        free_data: forwarded to ``_partition_param``; when False,
            ``param.data`` is left allocated (see the
            ``if free_data: free_param(param)`` guard there).
    """
    for param in param_list:
        print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
        if self.zero_param_process_group is not None:
            # Partition into the secondary (sub-group) shards first.
            self._partition_param_sec(param)
        # BUGFIX: pass free_data through instead of hard-coding True so that
        # _partition(..., free_data=False) actually preserves param.data.
        self._partition_param(param, has_been_updated=has_been_updated, free_data=free_data)

        param.ds_status = ZeroParamStatus.NOT_AVAILABLE
@instrument_w_nvtx
def _partition_param(self, param, buffer=None, has_been_updated=False):
def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
global reuse_buffers
print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
Expand All @@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):

see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
# param.data does not store anything meaningful in partitioned state
free_param(param)
if free_data:
free_param(param)
see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
Expand Down
22 changes: 17 additions & 5 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,16 +322,19 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
self.__ongoing_fetch_events.popleft().synchronize()

self.__inflight_param_registry.pop(param).wait()
self.__inflight_param_registry.pop(param).wait(
handle_dependency=not z3_leaf_module(current_submodule))

if not get_accelerator().handles_memory_backpressure():
if not get_accelerator().handles_memory_backpressure() and not z3_leaf_module(current_submodule):
event = get_accelerator().Event()
event.record()
self.__ongoing_fetch_events.append(event)

assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().wait_stream(self.__allgather_stream)
if z3_leaf_module(current_submodule):
AllGatherCoalescedHandle.free_buffer()
self.__profiler.stop_event(wait_event_name, wait_numel)

# kick off parameter prefetches for upcoming modules
Expand Down Expand Up @@ -414,9 +417,18 @@ def release_sub_module(self, submodule: Module) -> None:
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
free_data = not z3_leaf_module(submodule)
if not free_data:
# wait for the computation to finish and launch as early as possible.
empty_buffer = torch.empty(1, dtype=param.dtype, device=param.device)
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param)
self.__release_param(param, free_data=True)
if not free_data:
if param.ds_id in params_to_release and not param.is_external_param:
# empty buffer ensures that all computations are complete
# and is used for synchronization
param.data = empty_buffer

@instrument_w_nvtx
@torch.no_grad()
Expand Down Expand Up @@ -491,11 +503,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:

@compiler.disable
@instrument_w_nvtx
def __release_param(self, param: Parameter, free_data: bool = True) -> None:
    """Repartition *param* if it is available and no submodule still uses it.

    Args:
        param: the ZeRO-3 parameter to release.
        free_data: forwarded to ``param.partition``; when False the gathered
            ``param.data`` tensor is kept alive and only the partitioning
            bookkeeping is updated — presumably for the leaf-module path
            where the caller synchronizes separately (confirm against
            ``release_sub_module``).
    """
    if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
        if logger.isEnabledFor(logging.DEBUG):
            debug_rank0(f"-release: {param.ds_summary()}")
        param.partition(free_data=free_data)
        self.__n_available_params -= param.ds_numel

@instrument_w_nvtx
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/runtime/zero/test_zero_leaf_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_finegrained_optimization(self, module_granularity_threshold: int):
},
"zero_optimization": {
"stage": 3,
"stage3_prefetch_bucket_size": hidden_dim**2,
"stage3_prefetch_bucket_size": hidden_dim,
"stage3_param_persistence_threshold": 0,
"stage3_max_reuse_distance": 0,
}
Expand Down

0 comments on commit 663e637

Please sign in to comment.