diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml
index 51bc60c2c2ae..78a51905834b 100644
--- a/.github/workflows/cpu-torch-latest.yml
+++ b/.github/workflows/cpu-torch-latest.yml
@@ -42,7 +42,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6c3f168b3
+          # git checkout 6c3f168b3
           git rev-parse --short HEAD
           pip install .
diff --git a/.github/workflows/no-torch.yml b/.github/workflows/no-torch.yml
index 1a13c0f3f4f1..5b89a6f36787 100644
--- a/.github/workflows/no-torch.yml
+++ b/.github/workflows/no-torch.yml
@@ -4,6 +4,7 @@ on:
   workflow_dispatch:
   pull_request:
     paths:
+      - 'accelerator/**'
      - '.github/workflows/no-torch.yml'
      - 'op_builder/**'
   schedule:
diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml
index 2d69d0b94cb5..a1ba4937d164 100644
--- a/.github/workflows/nv-torch-latest-v100.yml
+++ b/.github/workflows/nv-torch-latest-v100.yml
@@ -38,7 +38,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6c3f168b3
+          # git checkout 6c3f168b3
           git rev-parse --short HEAD
           pip install .
diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml
index c2d10a454f94..0a9570a1ceaa 100644
--- a/.github/workflows/nv-torch-nightly-v100.yml
+++ b/.github/workflows/nv-torch-nightly-v100.yml
@@ -37,7 +37,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6c3f168b3
+          # git checkout 6c3f168b3
           git rev-parse --short HEAD
           pip install .
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index 0e49bd9f6458..de711f73144e 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -3,9 +3,15 @@

 # DeepSpeed Team

-import torch
 from .abstract_accelerator import DeepSpeedAccelerator

+# Torch may not be installed during the setup stage; ignoring a missing torch
+# here still allows the op builder related APIs to be executed.
+try:
+    import torch
+except ImportError as e:
+    pass
+
 try:
     import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
     oneccl_imported_p = True
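The guarded import added to cpu_accelerator.py is the usual optional-dependency pattern: importing the module must not fail when torch is absent (for example while building wheels), while torch-dependent paths may still raise at call time. A minimal standalone sketch of that pattern, assuming nothing about DeepSpeed's internals (the helper names below are illustrative, not real DeepSpeed APIs):

    # Sketch only: optional torch import so builder-style queries keep working.
    try:
        import torch  # may be absent during setup / wheel builds
    except ImportError:
        torch = None

    def list_op_builders():
        # Builder metadata queries that must not require torch.
        return ["CPUAdamBuilder", "FusedAdamBuilder"]

    def default_dtype():
        # Runtime-only path; the guard above keeps module import itself safe.
        if torch is None:
            raise RuntimeError("torch is required at runtime but is not installed")
        return torch.float32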
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index c7c7684fff79..42ffebbc4386 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
     tp_size: int = 1
     """ Number of devices to split the model across using tensor parallelism. """

+    tp_grain_size: int = 64
+    "Desired granularity of MLP/lm_head tensor-parallel shards. DNN libraries favor tensor sizes that are multiples of a power of 2; 64 is picked as the default."
+
     mpu: object = None
     """ A model parallelism unit object that implements
diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 74d20a6d53e5..fe2fa1b476be 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -104,6 +104,8 @@ def get_cmd(self, environment, active_resources):
             deepspeed_launch.append("--no_local_rank")
         if self.args.save_pid:
             deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if self.args.enable_each_rank_log:
+            deepspeed_launch.append(f"--enable_each_rank_log={self.args.enable_each_rank_log}")
         if self.args.elastic_training:
             deepspeed_launch.append("--enable_elastic_training")
             deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 7afe6ca903fb..e59f84bc8453 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -17,7 +17,7 @@
 from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d
 from deepspeed import comm as dist

-from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size

 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         if hasattr(model_config, 'num_attention_heads'):
             set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

+        # 4.4 set tp_grain_size
+        set_tp_grain_size(config.tensor_parallel.tp_grain_size)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()

diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 57be0c793856..3e6fc2b63ef1 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -22,6 +22,11 @@ def set_n_embd(num):
     n_embd = num


+def set_tp_grain_size(num):
+    global tp_grain_size
+    tp_grain_size = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        if total_size >= 64:
-            grain_size = total_size // 64
-            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
+        if total_size >= tp_grain_size:
+            grain_size = total_size // tp_grain_size
+            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
         else:
             return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
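To make the new grain-size sharding concrete: get_shard_size now rounds each rank's slice of an MLP/lm_head dimension to a multiple of tp_grain_size instead of the hard-coded 64. A self-contained sketch of that arithmetic, with the kv-heads branch omitted and tp_grain_size passed as an argument rather than read from the module global:

    def shard_size(total_size, mp_size, rank, tp_grain_size=64):
        # Split total_size across mp_size ranks in multiples of tp_grain_size.
        if total_size >= tp_grain_size:
            grains = total_size // tp_grain_size
            return (grains // mp_size + (1 if rank < grains % mp_size else 0)) * tp_grain_size
        # Small tensors fall back to a plain near-even split.
        return total_size // mp_size + (1 if rank < total_size % mp_size else 0)

    # Example: an 11008-wide MLP split over 3 ranks on 64-column boundaries.
    assert [shard_size(11008, 3, r) for r in range(3)] == [3712, 3648, 3648]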
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 0aad018528d3..5f023d87f375 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3738,6 +3738,11 @@
         assert self.zero_optimization_stage(
         ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."

+        opt_offload_config = self.zero_offload_optimizer()
+        assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states."
+        param_offload_config = self.zero_offload_param()
+        assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."
+
         assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."

         if device == OffloadDeviceEnum.none:
diff --git a/tests/unit/ops/transformer/inference/test_bias_geglu.py b/tests/unit/ops/transformer/inference/test_bias_geglu.py
index 05de4fbb4cf8..c995d2a8c46d 100644
--- a/tests/unit/ops/transformer/inference/test_bias_geglu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_geglu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)

-torch_minor_version = None
-

 def run_bias_geglu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py
index b69030e87ace..e3a3bad63961 100644
--- a/tests/unit/ops/transformer/inference/test_bias_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -16,8 +16,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)

-torch_minor_version = None
-

 def run_bias_gelu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally and using the tanh approximation
diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py
index 57134665b241..69078f9f7646 100644
--- a/tests/unit/ops/transformer/inference/test_bias_relu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_relu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)

-torch_minor_version = None
-

 def run_bias_relu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py
index 5f820ef3b579..a58abfdb100c 100644
--- a/tests/unit/ops/transformer/inference/test_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_gelu.py
@@ -9,12 +9,11 @@
 from deepspeed.ops.op_builder import InferenceBuilder
 from deepspeed.ops.transformer import DeepSpeedInferenceConfig
 from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
+from deepspeed.utils.torch import required_torch_version

 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)

-torch_minor_version = None
-

 def allclose(x, y):
     assert x.dtype == y.dtype
@@ -23,14 +22,11 @@


 def version_appropriate_gelu(activations):
-    global torch_minor_version
-    if torch_minor_version is None:
-        torch_minor_version = int(torch.__version__.split('.')[1])
-    # If torch version = 1.12
-    if torch_minor_version < 12:
-        return torch.nn.functional.gelu(activations)
-    else:
+    # gelu behavior changes (correctly) in torch 1.12
+    if required_torch_version(min_version=1.12):
         return torch.nn.functional.gelu(activations, approximate='tanh')
+    else:
+        return torch.nn.functional.gelu(activations)


 def run_gelu_reference(activations):
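The test clean-ups above drop the hand-rolled torch_minor_version parsing in favor of deepspeed.utils.torch.required_torch_version, which the updated test_gelu.py now imports. For reference, a simplified stand-in for that helper (not the DeepSpeed implementation) showing the same version gate:

    from packaging.version import Version
    import torch

    def meets_min_torch(min_version):
        # Compare only the release part so builds like "2.1.0+cu121" parse cleanly.
        return Version(torch.__version__.split("+")[0]) >= Version(str(min_version))

    def reference_gelu(x):
        # torch >= 1.12 exposes the tanh approximation used by the fused kernel.
        if meets_min_torch("1.12"):
            return torch.nn.functional.gelu(x, approximate="tanh")
        return torch.nn.functional.gelu(x)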
diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py
index 559aa2c60afe..2ab195ee0115 100644
--- a/tests/unit/ops/transformer/inference/test_matmul.py
+++ b/tests/unit/ops/transformer/inference/test_matmul.py
@@ -12,7 +12,6 @@
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)

 inference_module = None
-torch_minor_version = None


 def allclose(x, y):
diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py
index e582be1b926a..83785ac38ebb 100644
--- a/tests/unit/ops/transformer/inference/test_softmax.py
+++ b/tests/unit/ops/transformer/inference/test_softmax.py
@@ -11,8 +11,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)

-torch_minor_version = None
-

 def allclose(x, y):
     assert x.dtype == y.dtype
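With tp_grain_size exposed on DeepSpeedTPConfig, the granularity can be tuned when building an inference engine. A usage sketch, assuming the standard deepspeed.init_inference entry point and that the nested tensor_parallel dict maps onto DeepSpeedTPConfig (the model choice is illustrative):

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

    # Keep every MLP/lm_head shard a multiple of 128 columns instead of the default 64,
    # e.g. to match a DNN library's preferred GEMM tile size.
    engine = deepspeed.init_inference(
        model,
        tensor_parallel={"tp_size": 2, "tp_grain_size": 128},
        dtype=torch.float16,
    )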