From 9068acb6fbbdbaae5429bb89e507977128496bc5 Mon Sep 17 00:00:00 2001 From: jiahao su Date: Tue, 5 Nov 2024 01:49:21 +0800 Subject: [PATCH 1/8] Update URL in README Pipeline Status for Huawei Ascend NPU (#6706) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b302e32dfd9c..31051297d10c 100755 --- a/README.md +++ b/README.md @@ -142,7 +142,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram | PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) | | Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml) | | Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) | -| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/cosdt/DeepSpeed/actions/workflows/huawei-ascend-npu.yml/badge.svg?branch=master)](https://github.com/cosdt/DeepSpeed/actions/workflows/huawei-ascend-npu.yml) | +| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml/badge.svg?branch=main)](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml) | # Installation From 6c08b7f932bc1c7acaec142c720f6f9a82e9e5c8 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:51:01 -0800 Subject: [PATCH 2/8] Pin transformers to 4.45.2 in nv-ds-chat workflow (#6710) This commit causes breaking changes we need to fix, for now we will pin the version but we will fix shortly https://github.com/huggingface/transformers/pull/33325 --- .github/workflows/nv-ds-chat.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 2ad336cac4ed..329a1060f5eb 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -12,6 +12,7 @@ on: type: string pull_request: paths: + - ".github/workflows/nv-ds-chat.yml" - "deepspeed/runtime/zero/stage_1_and_2.py" - "deepspeed/runtime/zero/stage3.py" - "deepspeed/runtime/hybrid_engine.py" @@ -42,6 +43,7 @@ jobs: - name: Install deepspeed run: | + pip install transformers==4.45.2 pip install .[dev] ds_report From 2b41d6212c160a3645691b77b210ba7dd957c23f Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Mon, 4 Nov 2024 13:51:27 -0800 Subject: [PATCH 3/8] [Bug Fix] Support threads_per_head < 64 for wavefront size of 64 (#6622) When launching apply_rotary_pos_half kernel, only threads_per_head of 64 is supported for wavefront size of 64. This change adds support for threads_per_head < 64 such as 4, 8, 16. Fixes the issue introduced in https://github.com/microsoft/DeepSpeed/pull/5402 --------- Signed-off-by: Jagadish Krishnamoorthy Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Logan Adams --- .../inference/csrc/apply_rotary_pos_emb.cu | 10 ++++- .../ops/transformer/inference/test_rope.py | 38 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 tests/unit/ops/transformer/inference/test_rope.py diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index 25a494111c54..bbb8a7f00b1f 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -101,7 +101,15 @@ __global__ void apply_rotary_pos_half(T* mixed_query, #if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 64) { \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else if (threads_per_head == 64) { \ LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ } else { \ assert(false); \ diff --git a/tests/unit/ops/transformer/inference/test_rope.py b/tests/unit/ops/transformer/inference/test_rope.py new file mode 100644 index 000000000000..1f0ca0578e04 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_rope.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("num_heads", [64, 32, 16, 8]) +def test_rope_warp_size_alignment(num_heads): + if get_accelerator().device_name() != "cuda": + pytest.skip("This test runs only on GPU") + + batch = 1 + head = 8 + seq_len = 1024 + head_dim = 32 + rotary_dim = 32 + offset = 8 + rotate_half = False + rope_theta = 2 + + cuda0 = torch.device('cuda:0') + query = torch.randn(batch, head, seq_len, head_dim, device=cuda0) + key = torch.randn(batch, head, seq_len, head_dim, device=cuda0) + + inference = InferenceBuilder().load() + # For num_heads values of 64, 32, 16, 8 + # corresponding threads_per_head (defined in apply_rotary_pos_emb.cu) values are 4, 8, 16, 32 + inference.apply_rotary_pos_emb(query, key, rotary_dim, offset, num_heads, rotate_half, rope_theta) From 351569dd4a00dea0e00040a816cbc44b1e38a214 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:53:01 -0800 Subject: [PATCH 4/8] Use one param coordinator for both train/inference scenarios (#6662) The parameter coordinator in ZeRO3 throws a "backward pass is invalid for module in evaluation mode" error when the training mode is unexpected, as it expects all modules to be in training mode during the backward pass. This is an unnecessarily strict restriction. This PR relaxes the restriction by using a single parameter coordinator (instead of separate ones for training and evaluation modes) and resetting the prefetch state before starting a forward pass. Use of `is_compiling` needs to be fixed after #6663 is merged. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/parameter_offload.py | 57 +++++++++---------- .../zero/partitioned_param_coordinator.py | 12 +++- deepspeed/runtime/zero/stage3.py | 8 +-- tests/unit/runtime/zero/test_zero.py | 45 +++++++++++++++ 4 files changed, 85 insertions(+), 37 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 90afaf03550a..4b0ddb7679a9 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -133,7 +133,6 @@ def __init__( self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold, self.model_persistence_threshold) - self.param_coordinators = {} self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) @@ -141,12 +140,21 @@ def __init__( ) if overlap_comm else get_accelerator().default_stream() if not hasattr(module, "ds_inflight_param_registry"): - module.ds_inflight_param_registry = dict() - # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator - module.ds_inflight_param_registry[True] = InflightParamRegistry() - module.ds_inflight_param_registry[False] = InflightParamRegistry() + module.ds_inflight_param_registry = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=self._prefetch_bucket_sz, + max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, + max_available_parameters_in_numel=self._max_available_parameters_in_numel, + allgather_stream=self.__allgather_stream, + inflight_param_registry=self.__inflight_param_registry, + prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, + timers=self.timers, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, + ) + self.forward_hooks = [] self.backward_hooks = [] self.setup_zero_stage3_hooks() @@ -161,26 +169,13 @@ def partition_all_parameters(self): """Partitioning Parameters that were not partitioned usually if parameters of modules whose input parameters do not require grad computation do not trigger post call and will therefore will remain unpartitioned""" - self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module) + self.get_param_coordinator().release_and_reset_all(self.module) for param in iter_params(self.module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: raise RuntimeError(f"{param.ds_summary()} expected to be released") - def get_param_coordinator(self, training): - if not training in self.param_coordinators: - self.param_coordinators[training] = PartitionedParameterCoordinator( - prefetch_bucket_sz=self._prefetch_bucket_sz, - max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, - max_available_parameters_in_numel=self._max_available_parameters_in_numel, - allgather_stream=self.__allgather_stream, - inflight_param_registry=self.__inflight_param_registry[training], - prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, - timers=self.timers, - zero_quantized_weights=self.zero_quantized_weights, - zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, - ) - - return self.param_coordinators[training] + def get_param_coordinator(self): + return self.param_coordinator def empty_partition_cache(self): self.partition_all_parameters() @@ -228,14 +223,14 @@ def setup_zero_stage3_hooks(self): #reset step if in inference mode @instrument_w_nvtx - def _end_of_forward_hook(module, *args): + def _start_of_forward_hook(module, *args): + + self.get_param_coordinator().reset_step() - if not torch._C.is_grad_enabled(): - self.get_param_coordinator(training=False).reset_step() + self.module.register_forward_pre_hook(_start_of_forward_hook) #likely one of them should be enough but just to be safe self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) # Add top module to stack trace global FWD_MODULE_STACK @@ -447,7 +442,7 @@ def pre_sub_module_forward_function(self, sub_module): global FWD_MODULE_STACK FWD_MODULE_STACK.append(sub_module) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -460,7 +455,7 @@ def post_sub_module_forward_function(self, sub_module): see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.release_sub_module(sub_module) see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", @@ -468,8 +463,8 @@ def post_sub_module_forward_function(self, sub_module): @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" - param_coordinator = self.get_param_coordinator(training=True) + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -477,12 +472,12 @@ def pre_sub_module_backward_function(self, sub_module): @torch.no_grad() def post_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - self.get_param_coordinator(training=True).release_sub_module(sub_module) + self.get_param_coordinator().release_sub_module(sub_module) see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 5780b2afd6de..49f477cc4a1b 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -18,6 +18,7 @@ from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id from deepspeed.accelerator import get_accelerator import deepspeed.runtime.compiler as compiler +from deepspeed.runtime.compiler import is_compiling import logging @@ -92,7 +93,7 @@ def __init__( # keeps track of the number of submodules invoked so far. self.__step_id: int = 0 # network tracing mode - self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD + self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID # sequence of submodules/parameters in forward pass + backward pass self.__submodule_order: Iterable[Module] = [] self.__param_order: Iterable[__class__.__ParamInTrace] = [] @@ -188,6 +189,9 @@ def trace_prologue(self, sub_module: Module) -> None: @compiler.disable def record_module(self, sub_module: Module) -> None: """adds sub module to trace""" + if is_compiling(): + return + if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") @@ -195,6 +199,8 @@ def record_module(self, sub_module: Module) -> None: self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id) def record_parameters(self, sub_module: Module) -> None: + if is_compiling(): + return """adds sub module to trace""" if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") @@ -209,8 +215,12 @@ def construct_parameter_trace_from_module_trace(self): for sub_module in self.__submodule_order: self.record_parameters(sub_module) + @compiler.disable def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" + if is_compiling(): + return + self._clean_inflight_param_registry() if not self.is_complete_trace(): # not self.trace_complete: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 65460eb72a2f..2c0c9d498d13 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -593,8 +593,8 @@ def defragment(tensors: List[Tensor]) -> Tensor: return device_buffer - def _get_param_coordinator(self, training): - return self.parameter_offload.get_param_coordinator(training) + def _get_param_coordinator(self): + return self.parameter_offload.get_param_coordinator() def _configure_offloading(self, offload_optimizer_config, offload_param_config): ###################### offload optimizer setup ################################## @@ -1874,7 +1874,7 @@ def _pre_step(self): see_memory_usage(f"In step before checking overflow", force=False) print_rank_0("Finished Tracing at Beginning of Step") - self._get_param_coordinator(training=True).hierarchy = 0 + self._get_param_coordinator().hierarchy = 0 print_rank_0("Finished Tracing at Beginning of Step") @@ -2258,8 +2258,6 @@ def backward(self, loss, retain_graph=False): else: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - self._get_param_coordinator(training=True).reset_step() - if self.swap_optimizer: self.optimizer_swapper.post_backward() diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 7262a1b2c998..5dffd70aab68 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1628,3 +1628,48 @@ def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_grou optimizer=optimizer, config=config_dict, ) + + +class TestZero3SwitchModes(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0]) + def test(self, prefetch_ratio, zero_stage=3): + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio) + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "zero_optimization": { + "stage": zero_stage, + "stage3_prefetch_bucket_size": prefetch_bucket_size + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) + + for _ in range(3): + model.train() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.eval() + with torch.no_grad(): + for batch in data_loader: + loss = model(batch[0], batch[1]) From d2a4718946b544ab5d4f334f05a4ace9670e3ddd Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:57:12 -0800 Subject: [PATCH 5/8] Update yapf version (#6721) This update is needed to support eventually running on ubuntu-24.04 from GitHub, specifically because the python version is updated to 3.12 and results in the following error: `ModuleNotFoundError: No module named 'lib2to3'` since that package is deprecated. --- .pre-commit-config.yaml | 2 +- deepspeed/__init__.py | 4 ++-- deepspeed/autotuning/autotuner.py | 12 ++++++------ deepspeed/elasticity/elastic_agent.py | 4 ++-- deepspeed/module_inject/replace_module.py | 7 ++++--- deepspeed/runtime/config.py | 4 ++-- deepspeed/runtime/eigenvalue.py | 4 ++-- deepspeed/runtime/pipe/engine.py | 7 ++++--- deepspeed/runtime/utils.py | 4 ++-- tests/unit/runtime/zero/test_zero_context.py | 6 +++--- 10 files changed, 28 insertions(+), 26 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b89c872eefe5..e249411e4868 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/google/yapf - rev: v0.32.0 + rev: v0.40.0 hooks: - id: yapf diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index d8655299282f..de405dc40edb 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -165,8 +165,8 @@ def initialize(args=None, if hasattr(args, "deepscale_config") and args.deepscale_config is not None: logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") if hasattr(args, "deepspeed_config"): - assert (args.deepspeed_config is - None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + assert (args.deepspeed_config + is None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config args.deepscale_config = None diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index dfd195bc37eb..a72b3c951e97 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -248,8 +248,8 @@ def mp_size(self): return self.autotuning_config.mp_size def max_train_micro_batch_size_per_gpu(self): - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // ( self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1 return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size) @@ -964,8 +964,8 @@ def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_m low = mid + 1 self.update_records(tuning_space_name, exp, metric_val, 1) used_micro_batch_sizes.append(mid) - if prev_metric_val and ( - (metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST: + if prev_metric_val and ((metric_val - prev_metric_val) / + prev_metric_val) < METRIC_PERCENT_DIFF_CONST: logger.info(f"performance plateaus at mbs = {low}") break prev_metric_val = metric_val @@ -1026,8 +1026,8 @@ def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} )) # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) )) # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) )) - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus * self.exp_num_nodes) else: diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py index c6a69dd2a49f..8fd4293d312c 100644 --- a/deepspeed/elasticity/elastic_agent.py +++ b/deepspeed/elasticity/elastic_agent.py @@ -160,8 +160,8 @@ def _invoke_run(self, role: str = "default") -> RunResult: f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.") self._exit_barrier() return run_result - elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED - } or len(participants) > len(rdzv_handler._state_holder.state.participants): + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len( + rdzv_handler._state_holder.state.participants): if self._remaining_restarts > 0: log.info(f"[{role}] Worker group {state.name}. " f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 8b1455f20c69..1c5745dcf168 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -496,9 +496,10 @@ def conv2d_parallel_shard_weights(model, rank, world_size): if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( - OrderedDict({k: v - for k, v in dict(replaced_module.state_dict()).items() - if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') + OrderedDict({ + k: v + for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k + }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') dtype_reprs = { torch.float32: 'float32', diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 8be2f7ac4055..fb786f29722d 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -1012,8 +1012,8 @@ def _do_error_check(self): self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: - assert (self.zero_optimization_stage <= - ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( + assert (self.zero_optimization_stage + <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( ZeroStageEnum.max_stage) if self.fp16_master_weights_and_gradients: diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index 36300eb904dd..a82d8b1d5c7a 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -114,8 +114,8 @@ def compute_eigenvalue(self, module, device=None, scale=1.0): eigenvalue_current, eigenvalue_previous = 1., 0. while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >= - self.tol): # test convergence criteria + (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) + >= self.tol): # test convergence criteria eigenvalue_previous = eigenvalue_current Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index b75270cbd306..deb44c2e71eb 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -640,9 +640,10 @@ def _aggregate_total_loss(self): self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() if additional_losses is not None: - self.agg_additional_losses = OrderedDict( - {name: losses[2 + i].clone().detach() - for i, name in enumerate(additional_losses.keys())}) + self.agg_additional_losses = OrderedDict({ + name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys()) + }) return agg_loss def set_dataloader(self, loader): diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index b9617d3e632f..f48adb58c9bf 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -257,8 +257,8 @@ def has_overflow(self, params, has_moe_params=None): elif self.mpu is not None: if self.deepspeed is not None: using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') - if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or ( - not using_pipeline and self.deepspeed.enable_backward_allreduce is False): + if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce + is False) or (not using_pipeline and self.deepspeed.enable_backward_allreduce is False): dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group()) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False: diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index ec9e9e94aeaf..1d4fcd60022c 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -218,9 +218,9 @@ def test_throughput_calculation(self): engine.tput_timer.stop(global_step=global_step) duration = engine.tput_timer.end_time - engine.tput_timer.start_time # step elapsed time is reset after gradient accumulation steps - assert engine.tput_timer.step_elapsed_time == ( - 0 if engine.tput_timer.global_step_count != engine.tput_timer.start_step else current_duration + - duration) + assert engine.tput_timer.step_elapsed_time == (0 if engine.tput_timer.global_step_count + != engine.tput_timer.start_step else current_duration + + duration) assert engine.tput_timer.total_elapsed_time == total_duration + duration def test_ext_param_getattr(self): From 3beda32e94854e61a138c985e99bde2b3288b1d7 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Wed, 6 Nov 2024 15:17:48 -0800 Subject: [PATCH 6/8] Update flake8 version (#6722) This PR is useful for updating the flake8 checks we run, but is mostly needed to update flake8 so that it can run on newer versions of python which are included in newer ubuntu-latest versions from GitHub that we update to in #6717 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e249411e4868..b5d8afa8e0b4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,7 +65,7 @@ repos: ] - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 args: ['--config=.flake8'] From a1b0c35a1def4bfc20fc3eeb89d6f5831fbc4ae8 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Wed, 6 Nov 2024 20:37:52 -0800 Subject: [PATCH 7/8] Switch what versions of python are supported (#5676) Add support for testing compilation with python 3.11/3.12. Also add the dockerfiles used to build those images. --------- Co-authored-by: Michael Wyatt --- .github/workflows/python.yml | 4 +-- docker/gh-builder/Dockerfile.py311 | 35 +++++++++++++++++++++ docker/gh-builder/Dockerfile.py312 | 35 +++++++++++++++++++++ docker/{Dockerfile.rocm => rocm/Dockerfile} | 0 4 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 docker/gh-builder/Dockerfile.py311 create mode 100644 docker/gh-builder/Dockerfile.py312 rename docker/{Dockerfile.rocm => rocm/Dockerfile} (100%) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3103e3f36e84..37b68f1dbe80 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -21,10 +21,10 @@ jobs: unit-tests: strategy: matrix: - pyVersion: ["3.7", "3.8", "3.9", "3.10"] + pyVersion: ["3.8", "3.9", "3.10", "3.11", "3.12"] fail-fast: false - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 container: image: deepspeed/gh-builder:py${{ matrix.pyVersion }} diff --git a/docker/gh-builder/Dockerfile.py311 b/docker/gh-builder/Dockerfile.py311 new file mode 100644 index 000000000000..603fb614314f --- /dev/null +++ b/docker/gh-builder/Dockerfile.py311 @@ -0,0 +1,35 @@ +# Start with NGC container +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +# Set noninteractive mode for apt-get +ARG DEBIAN_FRONTEND=noninteractive + +# Install necessary dependencies for building Python +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + libncursesw5-dev \ + libgdbm-dev \ + libc6-dev \ + libffi-dev \ + tk-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and install Python 3.11 +RUN wget https://www.python.org/ftp/python/3.11.9/Python-3.11.9.tgz \ + && tar xzf Python-3.11.9.tgz \ + && cd Python-3.11.9 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && cd .. \ + && rm -rf Python-3.11.9 Python-3.11.9.tgz + +# Set Python 3.11 as the default Python version +RUN update-alternatives --install /usr/bin/python python /usr/local/bin/python3.11 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.11 1 diff --git a/docker/gh-builder/Dockerfile.py312 b/docker/gh-builder/Dockerfile.py312 new file mode 100644 index 000000000000..a0a7193201d4 --- /dev/null +++ b/docker/gh-builder/Dockerfile.py312 @@ -0,0 +1,35 @@ +# Start with NGC container +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +# Set noninteractive mode for apt-get +ARG DEBIAN_FRONTEND=noninteractive + +# Install necessary dependencies for building Python +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + libncursesw5-dev \ + libgdbm-dev \ + libc6-dev \ + libffi-dev \ + tk-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and install Python 3.12 +RUN wget https://www.python.org/ftp/python/3.12.5/Python-3.12.5.tgz \ + && tar xzf Python-3.12.5.tgz \ + && cd Python-3.12.5 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && cd .. \ + && rm -rf Python-3.12.5 Python-3.12.5.tgz + +# Set Python 3.12 as the default Python version +RUN update-alternatives --install /usr/bin/python python /usr/local/bin/python3.12 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.12 1 diff --git a/docker/Dockerfile.rocm b/docker/rocm/Dockerfile similarity index 100% rename from docker/Dockerfile.rocm rename to docker/rocm/Dockerfile From 057d25be6775105f4b9e1d41e6c21981a157c849 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Fri, 8 Nov 2024 08:34:20 -0800 Subject: [PATCH 8/8] Update version.txt after 0.15.4 release (#6731) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.15.4 Author - @loadams Co-authored-by: loadams --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 7ffdfa1cad65..1282fff53bfa 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.15.4 +0.15.5