From 351569dd4a00dea0e00040a816cbc44b1e38a214 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:53:01 -0800 Subject: [PATCH] Use one param coordinator for both train/inference scenarios (#6662) The parameter coordinator in ZeRO3 throws a "backward pass is invalid for module in evaluation mode" error when the training mode is unexpected, as it expects all modules to be in training mode during the backward pass. This is an unnecessarily strict restriction. This PR relaxes the restriction by using a single parameter coordinator (instead of separate ones for training and evaluation modes) and resetting the prefetch state before starting a forward pass. Use of `is_compiling` needs to be fixed after #6663 is merged. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/parameter_offload.py | 57 +++++++++---------- .../zero/partitioned_param_coordinator.py | 12 +++- deepspeed/runtime/zero/stage3.py | 8 +-- tests/unit/runtime/zero/test_zero.py | 45 +++++++++++++++ 4 files changed, 85 insertions(+), 37 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 90afaf03550a..4b0ddb7679a9 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -133,7 +133,6 @@ def __init__( self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold, self.model_persistence_threshold) - self.param_coordinators = {} self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) @@ -141,12 +140,21 @@ def __init__( ) if overlap_comm else get_accelerator().default_stream() if not hasattr(module, "ds_inflight_param_registry"): - module.ds_inflight_param_registry = dict() - # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator - module.ds_inflight_param_registry[True] = InflightParamRegistry() - module.ds_inflight_param_registry[False] = InflightParamRegistry() + module.ds_inflight_param_registry = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=self._prefetch_bucket_sz, + max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, + max_available_parameters_in_numel=self._max_available_parameters_in_numel, + allgather_stream=self.__allgather_stream, + inflight_param_registry=self.__inflight_param_registry, + prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, + timers=self.timers, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, + ) + self.forward_hooks = [] self.backward_hooks = [] self.setup_zero_stage3_hooks() @@ -161,26 +169,13 @@ def partition_all_parameters(self): """Partitioning Parameters that were not partitioned usually if parameters of modules whose input parameters do not require grad computation do not trigger post call and will therefore will remain unpartitioned""" - self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module) + self.get_param_coordinator().release_and_reset_all(self.module) for param in iter_params(self.module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: raise RuntimeError(f"{param.ds_summary()} expected to be released") - def get_param_coordinator(self, training): - if not training in self.param_coordinators: - self.param_coordinators[training] = PartitionedParameterCoordinator( - prefetch_bucket_sz=self._prefetch_bucket_sz, - max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, - max_available_parameters_in_numel=self._max_available_parameters_in_numel, - allgather_stream=self.__allgather_stream, - inflight_param_registry=self.__inflight_param_registry[training], - prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, - timers=self.timers, - zero_quantized_weights=self.zero_quantized_weights, - zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, - ) - - return self.param_coordinators[training] + def get_param_coordinator(self): + return self.param_coordinator def empty_partition_cache(self): self.partition_all_parameters() @@ -228,14 +223,14 @@ def setup_zero_stage3_hooks(self): #reset step if in inference mode @instrument_w_nvtx - def _end_of_forward_hook(module, *args): + def _start_of_forward_hook(module, *args): + + self.get_param_coordinator().reset_step() - if not torch._C.is_grad_enabled(): - self.get_param_coordinator(training=False).reset_step() + self.module.register_forward_pre_hook(_start_of_forward_hook) #likely one of them should be enough but just to be safe self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) # Add top module to stack trace global FWD_MODULE_STACK @@ -447,7 +442,7 @@ def pre_sub_module_forward_function(self, sub_module): global FWD_MODULE_STACK FWD_MODULE_STACK.append(sub_module) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -460,7 +455,7 @@ def post_sub_module_forward_function(self, sub_module): see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.release_sub_module(sub_module) see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", @@ -468,8 +463,8 @@ def post_sub_module_forward_function(self, sub_module): @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" - param_coordinator = self.get_param_coordinator(training=True) + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -477,12 +472,12 @@ def pre_sub_module_backward_function(self, sub_module): @torch.no_grad() def post_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - self.get_param_coordinator(training=True).release_sub_module(sub_module) + self.get_param_coordinator().release_sub_module(sub_module) see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 5780b2afd6de..49f477cc4a1b 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -18,6 +18,7 @@ from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id from deepspeed.accelerator import get_accelerator import deepspeed.runtime.compiler as compiler +from deepspeed.runtime.compiler import is_compiling import logging @@ -92,7 +93,7 @@ def __init__( # keeps track of the number of submodules invoked so far. self.__step_id: int = 0 # network tracing mode - self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD + self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID # sequence of submodules/parameters in forward pass + backward pass self.__submodule_order: Iterable[Module] = [] self.__param_order: Iterable[__class__.__ParamInTrace] = [] @@ -188,6 +189,9 @@ def trace_prologue(self, sub_module: Module) -> None: @compiler.disable def record_module(self, sub_module: Module) -> None: """adds sub module to trace""" + if is_compiling(): + return + if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") @@ -195,6 +199,8 @@ def record_module(self, sub_module: Module) -> None: self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id) def record_parameters(self, sub_module: Module) -> None: + if is_compiling(): + return """adds sub module to trace""" if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") @@ -209,8 +215,12 @@ def construct_parameter_trace_from_module_trace(self): for sub_module in self.__submodule_order: self.record_parameters(sub_module) + @compiler.disable def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" + if is_compiling(): + return + self._clean_inflight_param_registry() if not self.is_complete_trace(): # not self.trace_complete: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 65460eb72a2f..2c0c9d498d13 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -593,8 +593,8 @@ def defragment(tensors: List[Tensor]) -> Tensor: return device_buffer - def _get_param_coordinator(self, training): - return self.parameter_offload.get_param_coordinator(training) + def _get_param_coordinator(self): + return self.parameter_offload.get_param_coordinator() def _configure_offloading(self, offload_optimizer_config, offload_param_config): ###################### offload optimizer setup ################################## @@ -1874,7 +1874,7 @@ def _pre_step(self): see_memory_usage(f"In step before checking overflow", force=False) print_rank_0("Finished Tracing at Beginning of Step") - self._get_param_coordinator(training=True).hierarchy = 0 + self._get_param_coordinator().hierarchy = 0 print_rank_0("Finished Tracing at Beginning of Step") @@ -2258,8 +2258,6 @@ def backward(self, loss, retain_graph=False): else: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - self._get_param_coordinator(training=True).reset_step() - if self.swap_optimizer: self.optimizer_swapper.post_backward() diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 7262a1b2c998..5dffd70aab68 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1628,3 +1628,48 @@ def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_grou optimizer=optimizer, config=config_dict, ) + + +class TestZero3SwitchModes(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0]) + def test(self, prefetch_ratio, zero_stage=3): + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio) + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "zero_optimization": { + "stage": zero_stage, + "stage3_prefetch_bucket_size": prefetch_bucket_size + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) + + for _ in range(3): + model.train() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.eval() + with torch.no_grad(): + for batch in data_loader: + loss = model(batch[0], batch[1])