diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 94571eb101bb..cf8756fbd528 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -10,6 +10,10 @@ on: required: false default: 'master' type: string + pull_request: + paths: + - "deepspeed/runtime/zero/stage_1_and_2.py" + - "deepspeed/runtime/zero/stage3.py" concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 768d5ea34e5e..690f19292299 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC): def __init__(self): self._name = None self._communication_backend_name = None + self._compile_backend = None @abc.abstractmethod def is_synchronized_device(self): @@ -295,3 +296,11 @@ def visible_devices_envs(self): @abc.abstractmethod def set_visible_devices_envs(self, current_env, local_accelerator_ids): ... + + @abc.abstractmethod + def get_compile_backend(self): + ... + + @abc.abstractmethod + def set_compile_backend(self, backend): + ... diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 237e7f51dcb4..bd11d034f312 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cpu' + self._compile_backend = "inductor" if oneccl_imported_p: self._communication_backend_name = 'ccl' else: @@ -330,3 +331,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 2fc0cfd94125..60d66b6cdbab 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -25,6 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cuda' self._communication_backend_name = 'nccl' + self._compile_backend = "inductor" if pynvml is None: self._init_pynvml() @@ -367,3 +368,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 326efc8fa01b..114f367e879d 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'hpu' self._communication_backend_name = 'hccl' + self._compile_backend = "hpu_backend" try: import habana_frameworks.torch.hpu as hpu hpu.setDeterministic(True) @@ -301,3 +302,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index ff70b860d7c7..5fc9b1c8cfb6 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = "mps" self._communication_backend_name = None + self._compile_backend = "inductor" def is_synchronized_device(self): return False @@ -267,3 +268,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 5d891ecb707d..b0e0ff948e52 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -20,6 +20,7 @@ def __init__(self): super().__init__() self._name = 'npu' self._communication_backend_name = 'hccl' + self._compile_backend = "inductor" # dict that holds class name <--> class type mapping i.e. # 'AsyncIOBuilder': # this dict will be filled at init stage @@ -285,3 +286,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }") diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index c59f60077d2f..9c4a9c903f96 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -14,6 +14,7 @@ class XPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'xpu' self._communication_backend_name = 'ccl' + self._compile_backend = "inductor" self.aligned_tensors = [] def is_synchronized_device(self): @@ -296,3 +297,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/build_win.bat b/build_win.bat index ec8c8a362a78..af5c5103fa4b 100644 --- a/build_win.bat +++ b/build_win.bat @@ -1,6 +1,10 @@ @echo off set DS_BUILD_AIO=0 +set DS_BUILD_CUTLASS_OPS=0 +set DS_BUILD_EVOFORMER_ATTN=0 +set DS_BUILD_FP_QUANTIZER=0 +set DS_BUILD_RAGGED_DEVICE_OPS=0 set DS_BUILD_SPARSE_ATTN=0 echo Administrative permissions required. Detecting permissions... diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 8fc962c4f2a7..bdec8a55fcbc 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -34,6 +34,7 @@ def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) +@compiler.disable def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: return map(lambda pair: pair[1], get_all_parameters(module, recurse)) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 68cab13c4a93..c6ff216edfcb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -15,7 +15,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -2027,7 +2027,7 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - scaled_global_grad_norm = torch.norm(torch.stack(norm_groups)) + scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.loss_scale diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py index d1830534f6ea..62af25ac3ba4 100644 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -31,11 +31,9 @@ def base_config(): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' return config_dict diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 7568c27e3ed2..9890ea708eec 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -51,12 +51,10 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' if offload_device == OffloadDeviceEnum.cpu: config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} elif offload_device == OffloadDeviceEnum.nvme: diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py index 601adae58884..cee8d3b23f6b 100644 --- a/tests/unit/runtime/compile/test_load_config.py +++ b/tests/unit/runtime/compile/test_load_config.py @@ -47,12 +47,10 @@ def base_config(): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' return config_dict