From db98cc3ad1e0a20807e0c2513f0eee40f626860e Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Mon, 16 Dec 2024 11:05:55 -0800
Subject: [PATCH 1/6] Fix assertion for offloading states (#6855)

This PR fixes the assertions in the `offload_states` method mentioned in #6833.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/engine.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 0aad018528d3..5f023d87f375 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3738,6 +3738,11 @@ def offload_states(self,
         assert self.zero_optimization_stage(
         ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."
 
+        opt_offload_config = self.zero_offload_optimizer()
+        assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states."
+        param_offload_config = self.zero_offload_param()
+        assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."
+
         assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."
 
         if device == OffloadDeviceEnum.none:
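In practice, the new checks mean `offload_states` is only usable on a ZeRO-3 engine whose config has no active `offload_optimizer` / `offload_param` device. Below is a minimal sketch of the intended call pattern; the config values, the `OffloadDeviceEnum` import path, and the `reload_states` call are assumptions based on the ZeRO-3 offload-states API rather than text taken from this patch.

```python
# Minimal sketch (assumed API and import paths; illustrative config values).
import torch
import deepspeed
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    # No offload_optimizer / offload_param sections: the new assertions require
    # that optimizer and parameter offloading are disabled (device == "none").
    "zero_optimization": {"stage": 3},
}

model = torch.nn.Linear(8, 8)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Move ZeRO-3 states (parameters, gradients, optimizer states) to CPU memory ...
engine.offload_states(device=OffloadDeviceEnum.cpu, pin_memory=True, non_blocking=False)
# ... and restore them to the accelerator before training continues.
engine.reload_states()
```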
From 87c650681eb285ab34a69a011b520f756f42d4b9 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Mon, 16 Dec 2024 11:21:51 -0800
Subject: [PATCH 2/6] Remove pin from transformers version and fix
 Processing/Threading issues in tests (#6822)

Changes from https://github.com/huggingface/transformers/pull/34966 caused
the `nv-torch-latest-v100` tests to fail with the following error:

```
  File "/tmp/azureml/cr/j/e4bfd57a509846d6bbc4914639ad248d/exe/wd/actions-runner/_work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3941, in from_pretrained
    raise EnvironmentError(
OSError: Can't load the model for 'hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2' is the correct path to a directory containing a file named pytorch_model.bin, tf_model.h5, model.ckpt or flax_model.msgpack.
```

Sample failure here:
https://github.com/microsoft/DeepSpeed/actions/runs/12169422174/job/33942348835?pr=6794#step:8:3506

This was resolved on the Transformers side here:
https://github.com/huggingface/transformers/pull/35236
---
 .github/workflows/cpu-torch-latest.yml      | 2 +-
 .github/workflows/nv-torch-latest-v100.yml  | 2 +-
 .github/workflows/nv-torch-nightly-v100.yml | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml
index 51bc60c2c2ae..78a51905834b 100644
--- a/.github/workflows/cpu-torch-latest.yml
+++ b/.github/workflows/cpu-torch-latest.yml
@@ -42,7 +42,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6c3f168b3
+          # git checkout 6c3f168b3
           git rev-parse --short HEAD
           pip install .

diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml
index 2d69d0b94cb5..a1ba4937d164 100644
--- a/.github/workflows/nv-torch-latest-v100.yml
+++ b/.github/workflows/nv-torch-latest-v100.yml
@@ -38,7 +38,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6c3f168b3
+          # git checkout 6c3f168b3
           git rev-parse --short HEAD
           pip install .

diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml
index c2d10a454f94..0a9570a1ceaa 100644
--- a/.github/workflows/nv-torch-nightly-v100.yml
+++ b/.github/workflows/nv-torch-nightly-v100.yml
@@ -37,7 +37,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6c3f168b3
+          # git checkout 6c3f168b3
           git rev-parse --short HEAD
           pip install .

From da771ed42e41a44d5047813ca4672f1cfe9d1731 Mon Sep 17 00:00:00 2001
From: Yejing-Lai
Date: Tue, 17 Dec 2024 06:14:53 +0800
Subject: [PATCH 3/6] Add MLP/lm_head tp grain size setting. (#6828)

This PR adds an MLP/lm_head tensor-parallel grain-size setting to the
deepspeed.init_inference() API, making the MLP/lm_head sharding granularity
configurable. DNN libraries favor tensor sizes that are powers of 2, so we
pick 64 as the default. This is a preliminary solution; if there is a better
one, we can discuss it together. Thanks~

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase
---
 deepspeed/inference/config.py             |  3 +++
 deepspeed/module_inject/replace_module.py |  5 ++++-
 deepspeed/module_inject/tp_shard.py       | 11 ++++++++---
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index c7c7684fff79..42ffebbc4386 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
     tp_size: int = 1
     """ Number of devices to split the model across using tensor parallelism. """
 
+    tp_grain_size: int = 64
+    "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size."
+
     mpu: object = None
     """
     A model parallelism unit object that implements

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 7afe6ca903fb..e59f84bc8453 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -17,7 +17,7 @@
 from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d
 
 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         if hasattr(model_config, 'num_attention_heads'):
             set_num_attention_heads(getattr(model_config, 'num_attention_heads'))
 
+        # 4.4 set tp_grain_size
+        set_tp_grain_size(config.tensor_parallel.tp_grain_size)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()
 
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 57be0c793856..3e6fc2b63ef1 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -22,6 +22,11 @@ def set_n_embd(num):
     n_embd = num
 
 
+def set_tp_grain_size(num):
+    global tp_grain_size
+    tp_grain_size = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        if total_size >= 64:
-            grain_size = total_size // 64
-            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
+        if total_size >= tp_grain_size:
+            grain_size = total_size // tp_grain_size
+            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
         else:
             return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
 
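The new `tp_grain_size` field flows from `DeepSpeedTPConfig` into `get_shard_size`, which rounds each rank's MLP/lm_head shard down to a multiple of the grain size. A hedged sketch of passing the knob through `deepspeed.init_inference` follows; the model name, the value 128, and the worked numbers are illustrative, not part of the patch.

```python
# Hedged sketch: configuring the new tp_grain_size via init_inference().
# Model name and grain size are placeholders for illustration only.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2, "tp_grain_size": 128},  # tp_grain_size added by this patch
    dtype=torch.float16,
    replace_with_kernel_inject=False,  # auto-TP path, where get_shard_size() applies
)

# Example of the rounding in get_shard_size(): a weight with 11008 columns,
# tp_grain_size=128 and tp_size=2 gives 11008 // 128 = 86 grains, so each rank
# receives 43 grains, i.e. 43 * 128 = 5504 columns.
```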
From a964e435532699908e5750abdb027ae583ff793d Mon Sep 17 00:00:00 2001
From: Aviv Keshet
Date: Tue, 17 Dec 2024 09:33:09 -0800
Subject: [PATCH 4/6] Fix --enable_each_rank_log when used with PDSH
 multi-node runner (#6863)

This PR fixes https://github.com/microsoft/DeepSpeed/issues/6859 by threading
this argument into the deepspeed launcher command built by PDSHRunner.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/launcher/multinode_runner.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 74d20a6d53e5..fe2fa1b476be 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -104,6 +104,8 @@ def get_cmd(self, environment, active_resources):
             deepspeed_launch.append("--no_local_rank")
         if self.args.save_pid:
            deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if self.args.enable_each_rank_log:
+            deepspeed_launch.append(f"--enable_each_rank_log={self.args.enable_each_rank_log}")
         if self.args.elastic_training:
             deepspeed_launch.append("--enable_elastic_training")
             deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
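For context, a small self-contained sketch of the pattern this patch adds: the PDSH runner builds one `deepspeed.launcher.launch` command per node and now forwards the per-rank log directory to it when the user supplied `--enable_each_rank_log`. The class and values below are illustrative stand-ins, not the actual runner code.

```python
# Illustrative stand-in for the forwarding logic (not the actual runner code).
import os


class FakeArgs:  # stand-in for the parsed deepspeed launcher arguments
    save_pid = False
    enable_each_rank_log = "rank_logs"  # directory for per-rank stdout/stderr


args = FakeArgs()
deepspeed_launch = ["python", "-u", "-m", "deepspeed.launcher.launch"]
if args.save_pid:
    deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
if args.enable_each_rank_log:
    # Only appended when the user passed the flag; before this patch it was not
    # threaded into the per-node command on the PDSH path.
    deepspeed_launch.append(f"--enable_each_rank_log={args.enable_each_rank_log}")
print(" ".join(deepspeed_launch))
```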
From 2f32966b1cd874aa4373177c8f8c4214ad57d020 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Tue, 17 Dec 2024 11:53:47 -0800
Subject: [PATCH 5/6] Update transformers ops unit tests to use
 `required_torch_version` (#6884)

---
 .../ops/transformer/inference/test_bias_geglu.py   |  2 --
 .../ops/transformer/inference/test_bias_gelu.py    |  2 --
 .../ops/transformer/inference/test_bias_relu.py    |  2 --
 tests/unit/ops/transformer/inference/test_gelu.py  | 14 +++++---------
 .../unit/ops/transformer/inference/test_matmul.py  |  1 -
 .../unit/ops/transformer/inference/test_softmax.py |  2 --
 6 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/tests/unit/ops/transformer/inference/test_bias_geglu.py b/tests/unit/ops/transformer/inference/test_bias_geglu.py
index 05de4fbb4cf8..c995d2a8c46d 100644
--- a/tests/unit/ops/transformer/inference/test_bias_geglu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_geglu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_geglu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally

diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py
index b69030e87ace..e3a3bad63961 100644
--- a/tests/unit/ops/transformer/inference/test_bias_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -16,8 +16,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_gelu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally and using the tanh approximation

diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py
index 57134665b241..69078f9f7646 100644
--- a/tests/unit/ops/transformer/inference/test_bias_relu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_relu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_relu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally

diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py
index 5f820ef3b579..a58abfdb100c 100644
--- a/tests/unit/ops/transformer/inference/test_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_gelu.py
@@ -9,12 +9,11 @@
 from deepspeed.ops.op_builder import InferenceBuilder
 from deepspeed.ops.transformer import DeepSpeedInferenceConfig
 from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
+from deepspeed.utils.torch import required_torch_version
 
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
@@ -23,14 +22,11 @@ def allclose(x, y):
 
 
 def version_appropriate_gelu(activations):
-    global torch_minor_version
-    if torch_minor_version is None:
-        torch_minor_version = int(torch.__version__.split('.')[1])
-    # If torch version = 1.12
-    if torch_minor_version < 12:
-        return torch.nn.functional.gelu(activations)
-    else:
+    # gelu behavior changes (correctly) in torch 1.12
+    if required_torch_version(min_version=1.12):
         return torch.nn.functional.gelu(activations, approximate='tanh')
+    else:
+        return torch.nn.functional.gelu(activations)
 
 
 def run_gelu_reference(activations):

diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py
index 559aa2c60afe..2ab195ee0115 100644
--- a/tests/unit/ops/transformer/inference/test_matmul.py
+++ b/tests/unit/ops/transformer/inference/test_matmul.py
@@ -12,7 +12,6 @@
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
 inference_module = None
-torch_minor_version = None
 
 
 def allclose(x, y):

diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py
index e582be1b926a..83785ac38ebb 100644
--- a/tests/unit/ops/transformer/inference/test_softmax.py
+++ b/tests/unit/ops/transformer/inference/test_softmax.py
@@ -11,8 +11,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
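The common thread in the diffs above is replacing ad-hoc parsing of `torch.__version__` (and the stray `torch_minor_version` globals) with the shared `required_torch_version` helper from `deepspeed.utils.torch`. A small sketch of typical use follows; the version bound is illustrative, and the `max_version` remark is an assumption about the helper rather than something shown in this patch.

```python
# Sketch of version-gated behavior using the shared helper (bound is illustrative).
import torch
from deepspeed.utils.torch import required_torch_version


def reference_gelu(activations: torch.Tensor) -> torch.Tensor:
    # torch >= 1.12 exposes the tanh-approximate gelu used as the reference here.
    if required_torch_version(min_version=1.12):
        return torch.nn.functional.gelu(activations, approximate="tanh")
    return torch.nn.functional.gelu(activations)
```

The helper is assumed to also accept a `max_version` bound, which makes it suitable for `pytest.mark.skipif`-style guards when a test only supports a window of torch releases.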
From 4cd1d97460b677563d57f07a293724bdc02e0ef5 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Tue, 17 Dec 2024 17:30:52 -0800
Subject: [PATCH 6/6] Don't error out when cpu accelerator doesn't have torch
 (as default for whl building) (#6886)

This fixes a bug introduced in #6845, which breaks the `no-torch` workflow
that we require in order to do releases, where we do not require torch to be
in the environment when building an sdist. This adds to the CPU accelerator
the same logic the CUDA accelerator already had, where we don't require torch
to be installed to build the whl.
---
 .github/workflows/no-torch.yml | 1 +
 accelerator/cpu_accelerator.py | 8 +++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/no-torch.yml b/.github/workflows/no-torch.yml
index 1a13c0f3f4f1..5b89a6f36787 100644
--- a/.github/workflows/no-torch.yml
+++ b/.github/workflows/no-torch.yml
@@ -4,6 +4,7 @@ on:
   workflow_dispatch:
   pull_request:
     paths:
+      - 'accelerator/**'
      - '.github/workflows/no-torch.yml'
      - 'op_builder/**'
   schedule:

diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index 0e49bd9f6458..de711f73144e 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -3,9 +3,15 @@
 
 # DeepSpeed Team
 
-import torch
 from .abstract_accelerator import DeepSpeedAccelerator
 
+# During setup stage torch may not be installed, pass on no torch will
+# allow op builder related API to be executed.
+try:
+    import torch
+except ImportError as e:
+    pass
+
 try:
     import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
     oneccl_imported_p = True