From 82cacfc2adc278a7656177996c50ca4748dc8c8c Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Tue, 10 Dec 2024 11:20:01 -0500
Subject: [PATCH] Use ds-specific module id to avoid conflicts

---
 deepspeed/runtime/zero/parameter_offload.py   | 24 +++++++------
 .../zero/partitioned_param_coordinator.py     | 24 ++++++-------
 tests/unit/runtime/zero/test_zero.py          | 34 +++++++++++++++++++
 3 files changed, 59 insertions(+), 23 deletions(-)

diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index f945f5166190..7f6263cc9663 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -240,7 +240,7 @@ def _start_of_forward_hook(module, *args):
         self.module.register_forward_pre_hook(_start_of_forward_hook)
 
         #likely one of them should be enough but just to be safe
-        self._register_hooks_recursively(self.module)
+        self._register_deepspeed_module(self.module)
 
         # Add top module to stack trace
         global FWD_MODULE_STACK
@@ -266,11 +266,11 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):
 
         return persistent_params
 
-    def _register_hooks_recursively(self, module, count=[0]):
+    def _register_deepspeed_module(self, module, count=[0]):
         my_count = count[0]
-        module.id = my_count
+        module.ds_id = my_count
 
-        #print(f"{module.__class__} : {module.id}")
+        #print(f"{module.__class__} : {module.ds_id}")
 
         if z3_leaf_module(module):
             for param in module.parameters():
@@ -278,7 +278,7 @@ def _register_hooks_recursively(self, module, count=[0]):
         else:
             for child in module.children():
                 count[0] = count[0] + 1
-                self._register_hooks_recursively(child, count=count)
+                self._register_deepspeed_module(child, count=count)
 
         @instrument_w_nvtx
         def _pre_forward_module_hook(module, *args):
@@ -463,14 +463,16 @@ def pre_sub_module_forward_function(self, sub_module):
 
     @torch.no_grad()
     def post_sub_module_forward_function(self, sub_module):
-        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
-                         force=False)
+        see_memory_usage(
+            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
+            force=False)
 
         param_coordinator = self.get_param_coordinator()
         param_coordinator.release_sub_module(sub_module)
 
-        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
-                         force=False)
+        see_memory_usage(
+            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
+            force=False)
 
     @torch.no_grad()
     def pre_sub_module_backward_function(self, sub_module):
@@ -485,13 +487,13 @@ def pre_sub_module_backward_function(self, sub_module):
     def post_sub_module_backward_function(self, sub_module):
         # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         see_memory_usage(
-            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
+            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
             force=False)
 
         self.get_param_coordinator().release_sub_module(sub_module)
 
         see_memory_usage(
-            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
+            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
             force=False)
 
     def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 596d0e9c20f9..b62f101e3f8f 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -172,17 +172,17 @@ def trace_prologue(self, sub_module: Module) -> None:
         # sub_module must match expectation else invalidate trace cache
         if len(self.__submodule_order) <= self.__step_id:
             print_rank_0(
-                f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
+                f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
                 f"cache has only {len(self.__submodule_order)} modules",
                 force=True)
             self._invalidate_trace()
             return
 
         if sub_module != self.__submodule_order[self.__step_id]:
-            expected_module_id = self.__submodule_order[self.__step_id].id
+            expected_module_id = self.__submodule_order[self.__step_id].ds_id
             print_rank_0(
                 f"Invalidate trace cache @ step {self.__step_id}: "
-                f"expected module {expected_module_id}, but got module {sub_module.id}",
+                f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
                 force=True)
             self._invalidate_trace()
 
@@ -196,7 +196,7 @@ def record_module(self, sub_module: Module) -> None:
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
 
         self.__submodule_order.append(sub_module)
-        self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
+        self.__step_id_module_fetched_for[sub_module.ds_id].append(self.__step_id)
 
     def record_parameters(self, sub_module: Module) -> None:
         if is_compiling():
@@ -205,7 +205,7 @@ def record_parameters(self, sub_module: Module) -> None:
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
 
-        step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
+        step_id = self.__step_id_module_fetched_for[sub_module.ds_id].popleft()
         for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
             self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))
 
@@ -225,7 +225,7 @@ def reset_step(self) -> None:
 
         if not self.is_complete_trace():  # not self.trace_complete:
             # Make sure that recorded submodule orders are identical across ranks
-            assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
+            assert_ints_same_as_other_ranks([m.ds_id for m in self.__submodule_order])
 
             if self.is_record_trace():
                 # Successfully recorded a trace
@@ -238,7 +238,7 @@ def reset_step(self) -> None:
                 self.__param_order = tuple(self.__param_order)  # freeze
                 self.__trace_mode = ZeRoTraceMode.COMPLETE
                 print_rank_0(
-                    f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
+                    f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.ds_id for m in self.__submodule_order]}",
                     force=False)
             else:
                 # Enable trace recording for next forward/backward pass
@@ -281,7 +281,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
         """
         if logger.isEnabledFor(logging.DEBUG):
             debug_rank0(
-                f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
+                f"{self.__step_id}: M{current_submodule.ds_id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
                 + str({
                     "avail": f"{self.__n_available_params:.1e}",
                     "queue_sz": f"{len(self.__param_queue or [])}",
f"{len(self.__param_queue or [])}", @@ -294,7 +294,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: if fetch_numel > 0: event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT - self._dump_param_ids(event_name, current_submodule.id, + self._dump_param_ids(event_name, current_submodule.ds_id, [p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) self.__profiler.start_event(event_name) # kick off all gather for params in the immediately required submodule @@ -310,7 +310,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: self.__profiler.start_event(wait_event_name) # wait for parameters in the immediately needed submodule to become available for param in params_to_fetch: - param.ds_active_sub_modules.add(current_submodule.id) + param.ds_active_sub_modules.add(current_submodule.ds_id) if logger.isEnabledFor(logging.DEBUG): debug_rank0(f"-wait: {param.ds_summary()}") if param in self.__inflight_param_registry: @@ -352,7 +352,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: if discarded_from_prefetch_queue != params_not_already_fetched: raise RuntimeError( f"tracing error at step {self.__step_id}: \n" - f"module id: {current_submodule.id}, training: {current_submodule.training}\n" + f"module id: {current_submodule.ds_id}, training: {current_submodule.training}\n" f"expected the next {len(params_not_already_fetched)} parameters in the " f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n" f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.") @@ -413,7 +413,7 @@ def release_sub_module(self, submodule: Module) -> None: params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set( p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule)))) for param in iter_params(submodule, recurse=z3_leaf_module(submodule)): - param.ds_active_sub_modules.discard(submodule.id) + param.ds_active_sub_modules.discard(submodule.ds_id) if param.ds_id in params_to_release and not param.is_external_param: self.__release_param(param) diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 5dffd70aab68..d39953f40f07 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1673,3 +1673,37 @@ def test(self, prefetch_ratio, zero_stage=3): with torch.no_grad(): for batch in data_loader: loss = model(batch[0], batch[1]) + + +# Avoid overwriting client module id +# https://github.com/microsoft/DeepSpeed/issues/6772 +class TestZero3ClientModuleID(DistributedTest): + world_size = 2 + + def test_client_module_id(self): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + }, + "zero_optimization": { + "stage": 3 + }, + } + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.id = 3 # ID arbitrary client usage, e.g. GPU placement + self.fc = Linear(128, 128) + + def forward(self, x): + return self.fc(x) + + model = MyModel() + pre_init_m_id = model.id + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + post_init_m_id = model.id + assert pre_init_m_id == post_init_m_id