diff --git a/.github/workflows/no-torch.yml b/.github/workflows/no-torch.yml
new file mode 100644
index 0000000000000..794c02b3c4dbf
--- /dev/null
+++ b/.github/workflows/no-torch.yml
@@ -0,0 +1,45 @@
+name: no-torch
+
+on:
+  workflow_dispatch:
+  pull_request:
+    paths:
+      - '.github/workflows/no-torch.yml'
+  schedule:
+    - cron: "0 0 * * *"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+  issues: write
+
+jobs:
+  unit-tests:
+    runs-on: ubuntu-22.04
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - id: setup-venv
+        uses: ./.github/workflows/setup-venv
+
+      - name: Python environment
+        run: |
+          pip uninstall torch --yes
+          pip list
+
+      - name: Build deepspeed
+        run: |
+          DS_BUILD_STRING=" " python setup.py sdist
+
+      - name: Open GitHub issue if nightly CI fails
+        if: ${{ failure() && (github.event_name == 'schedule') }}
+        uses: JasonEtco/create-an-issue@v2
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          filename: .github/ISSUE_TEMPLATE/ci_failure_report.md
+          update_existing: true
diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
index 3a35ad3145a0d..15fd516acaae5 100644
--- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
+++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
@@ -20,7 +20,7 @@ deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size,
                                                const bool single_submit,
                                                const bool overlap_events,
                                                const int num_threads)
-    : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads)
+    : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1)
 {
     _init_cuFile(block_size, queue_depth, num_threads);
 }
diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py
index cdf4c030f5d65..8753cf9f70ed5 100644
--- a/deepspeed/comm/ccl.py
+++ b/deepspeed/comm/ccl.py
@@ -15,7 +15,7 @@
 
 def build_ccl_op():
     builder = get_accelerator().create_op_builder("CCLCommBuilder")
-    if builder is None or NotImplementedBuilder:
+    if builder is None or isinstance(builder, NotImplementedBuilder):
        return None
     ccl_cpp_module = builder.load()
     print(f'DeepSpeed {builder.absolute_name()} built successfully')
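For context on the `ccl.py` hunk: `NotImplementedBuilder` is a class object, so the old condition `builder is None or NotImplementedBuilder` is always truthy and `build_ccl_op()` returned `None` even when a real builder was created. A minimal sketch of the difference, using stand-in classes rather than the real DeepSpeed builders:

```python
# Stand-in classes; the real builders come from the accelerator's op builder registry.
class NotImplementedBuilder:
    pass


class CCLCommBuilder:
    pass


builder = CCLCommBuilder()

# Old condition: a bare class object is truthy, so this is true for every builder
# and build_ccl_op() always bailed out with None.
print(bool(builder is None or NotImplementedBuilder))  # True

# Fixed condition: only bail out when the builder really is a NotImplementedBuilder.
print(builder is None or isinstance(builder, NotImplementedBuilder))  # False
```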
diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md
index f9a4cfdc68b49..ce9e3ee9a8922 100644
--- a/docs/_tutorials/getting-started.md
+++ b/docs/_tutorials/getting-started.md
@@ -11,6 +11,7 @@ tags: getting-started
 * To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/cli/jobs/deepspeed)
 * DeepSpeed has direct integrations with [HuggingFace Transformers](https://github.com/huggingface/transformers) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). HuggingFace Transformers users can now easily accelerate their models with DeepSpeed through a simple ``--deepspeed`` flag + config file [See more details](https://huggingface.co/docs/transformers/main_classes/deepspeed). PyTorch Lightning provides easy access to DeepSpeed through the Lightning Trainer [See more details](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html?highlight=deepspeed#deepspeed).
 * DeepSpeed on AMD can be used via our [ROCm images](https://hub.docker.com/r/deepspeed/rocm501/tags), e.g., `docker pull deepspeed/rocm501:ds060_pytorch110`.
+* DeepSpeed also supports Intel Xeon CPU, Intel Data Center Max Series XPU, Intel Gaudi HPU, Huawei Ascend NPU, etc. Please refer to the [accelerator setup guide](/tutorials/accelerator-setup-guide/).
@@ -226,6 +227,36 @@ deepspeed --include="worker-2:0,1" \
     <client_entry.py> <client args> \
     --deepspeed --deepspeed_config ds_config.json
 ```
 
+### Launching without passwordless SSH
+
+DeepSpeed now supports launching training jobs without the need for passwordless SSH. This mode is
+particularly useful in cloud environments such as Kubernetes, where flexible container orchestration
+is possible, and setting up a leader-worker architecture with passwordless SSH adds unnecessary
+complexity.
+
+To use this mode, you need to run the DeepSpeed command separately on all nodes. The command should
+be structured as follows:
+
+```bash
+deepspeed --hostfile=myhostfile --no_ssh --node_rank=<n> \
+    --master_addr=<addr> --master_port=<port> \
+    <client_entry.py> <client args> \
+    --deepspeed --deepspeed_config ds_config.json
+```
+
+- `--hostfile=myhostfile`: Specifies the hostfile that contains information about the nodes and GPUs.
+- `--no_ssh`: Enables the no-SSH mode.
+- `--node_rank=<n>`: Specifies the rank of the node. This should be a unique integer from 0 to n - 1.
+- `--master_addr=<addr>`: The address of the leader node (rank 0).
+- `--master_port=<port>`: The port of the leader node.
+
+In this setup, the hostnames in the hostfile do not need to be reachable via passwordless SSH.
+However, the hostfile is still required for the launcher to collect information about the environment,
+such as the number of nodes and the number of GPUs per node.
+
+Each node must be launched with a unique `node_rank`, and all nodes must be provided with the address
+and port of the leader node (rank 0). This mode causes the launcher to act similarly to the `torchrun`
+launcher, as described in the [PyTorch documentation](https://pytorch.org/docs/stable/elastic/run.html).
+
 ## Multi-Node Environment Variables
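To make the per-node invocation above concrete, here is a hedged sketch of a wrapper that an orchestrator (for example a Kubernetes Job) could run on every node. The `NODE_RANK`, `MASTER_ADDR` and `MASTER_PORT` environment variables and the `train.py` entry script are assumptions for illustration, not part of DeepSpeed:

```python
# Hypothetical per-node wrapper for the --no_ssh launch mode; every node runs the
# same script, and only NODE_RANK differs between nodes (injected by the orchestrator).
import os
import subprocess

cmd = [
    "deepspeed",
    "--hostfile=myhostfile",
    "--no_ssh",
    f"--node_rank={os.environ['NODE_RANK']}",                    # unique per node: 0 .. n-1
    f"--master_addr={os.environ['MASTER_ADDR']}",                # address of the rank-0 node
    f"--master_port={os.environ.get('MASTER_PORT', '29500')}",   # port of the rank-0 node
    "train.py",                                                  # stand-in for the client entry script
    "--deepspeed",
    "--deepspeed_config",
    "ds_config.json",
]
subprocess.run(cmd, check=True)
```

Only the node rank changes from node to node; the hostfile and the rest of the command are identical everywhere.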
diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py
index 40cf504c2c834..75ee54c09bf60 100644
--- a/op_builder/fp_quantizer.py
+++ b/op_builder/fp_quantizer.py
@@ -49,23 +49,27 @@ def is_compatible(self, verbose=False):
             import triton
         except ImportError:
             if verbose:
-                self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels")
+                self.warning(
+                    f"please install triton==2.3.0, 2.3.1 or 3.0.0 if you want to use the FP Quantizer Kernels")
             return False
 
-        # triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released
+        # triton 2.3.0, 2.3.1 and 3.0.0 are ok.
+        allowed_versions = ("2.3", "3.0")
         if pkg_version:
-            allowed = pkg_version.parse("2.3")
+            allowed = (pkg_version.parse(v) for v in allowed_versions)
             installed_triton = pkg_version.parse(triton.__version__)
-            triton_mismatch = installed_triton.major != allowed.major or installed_triton.minor != allowed.minor
+            triton_mismatch = all(installed_triton.major != a.major or installed_triton.minor != a.minor
+                                  for a in allowed)
         else:
             installed_triton = triton.__version__
             major, minor, _ = installed_triton.split(".")
-            triton_mismatch = major != "2" or minor != "3"
+            allowed = (v.split(".") for v in allowed_versions)
+            triton_mismatch = all(major != v[0] or minor != v[1] for v in allowed)
 
         if triton_mismatch:
             if verbose:
                 self.warning(
-                    f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels"
+                    f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0, 2.3.1 and 3.0.0 are known to be compatible with these kernels"
                 )
             return False
 
diff --git a/op_builder/gds.py b/op_builder/gds.py
index 01c2d5a245d1f..727ebdf483722 100644
--- a/op_builder/gds.py
+++ b/op_builder/gds.py
@@ -36,6 +36,11 @@ def extra_ldflags(self):
         return super().extra_ldflags() + ['-lcufile']
 
     def is_compatible(self, verbose=False):
+        if self.is_rocm_pytorch():
+            if verbose:
+                self.warning(f'{self.NAME} is not compatible with ROCM')
+            return False
+
         try:
             import torch.utils.cpp_extension
         except ImportError:
diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py
index dfab28aa7477c..51e80e7f9e62f 100644
--- a/tests/unit/alexnet_model.py
+++ b/tests/unit/alexnet_model.py
@@ -101,12 +101,14 @@ def cifar_trainset(fp16=False):
     dist.barrier()
     if local_rank != 0:
         dist.barrier()
-    data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
-    trainset = torchvision.datasets.CIFAR10(root=os.path.join(data_root, "cifar10-data"),
-                                            train=True,
-                                            download=True,
-                                            transform=transform)
+    if os.getenv("CIFAR10_DATASET_PATH"):
+        data_root = os.getenv("CIFAR10_DATASET_PATH")
+        download = False
+    else:
+        data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data")
+        download = True
+    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform)
     if local_rank == 0:
         dist.barrier()
     return trainset
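The `cifar_trainset` hunk above makes the CIFAR-10 download optional when `CIFAR10_DATASET_PATH` points at a pre-staged copy. A small sketch of how a CI host might stage the dataset once; the `/data/cifar10` cache directory is an assumption:

```python
# Hypothetical one-time staging step; afterwards, export CIFAR10_DATASET_PATH=/data/cifar10
# so cifar_trainset() constructs the dataset with download=False.
import os
import torchvision

data_root = "/data/cifar10"  # assumed local cache location
os.makedirs(data_root, exist_ok=True)
torchvision.datasets.CIFAR10(root=data_root, train=True, download=True)
print(f"Staged CIFAR-10 under {data_root}; set CIFAR10_DATASET_PATH={data_root} for the tests.")
```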
diff --git a/tests/unit/inference/test_checkpoint_sharding.py b/tests/unit/inference/test_checkpoint_sharding.py
index 5bae9a151a27c..f1e37ee26536e 100644
--- a/tests/unit/inference/test_checkpoint_sharding.py
+++ b/tests/unit/inference/test_checkpoint_sharding.py
@@ -14,6 +14,7 @@
 from huggingface_hub import snapshot_download
 from transformers.utils import is_offline_mode
 from deepspeed.ops.op_builder import InferenceBuilder
+from deepspeed.accelerator import get_accelerator
 
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
@@ -44,6 +45,8 @@ def model_name(request):
 
 @pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
 def dtype(request):
+    if request.param not in get_accelerator().supported_dtypes():
+        pytest.skip(f"{request.param} not supported by {get_accelerator().device_name()}.")
     return request.param
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index eadf670d93283..581a2ce433edc 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -298,6 +298,12 @@ def verify_injection(module):
     verify_injection(model)
 
 
+# Used to Get Device name
+def getDeviceId(local_rank):
+    device = torch.device(f"{get_accelerator().device_name(local_rank)}")
+    return device
+
+
 # Verify that test is valid
 def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
     model, task = model_w_task
@@ -484,8 +490,8 @@ def test(
             pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.")
 
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
-
-        pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=local_rank, framework="pt")
+        device = getDeviceId(local_rank)
+        pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=device, framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
         pipe.model = deepspeed.init_inference(pipe.model,
                                               mp_size=self.world_size,
diff --git a/tests/unit/inference/test_model_profiling.py b/tests/unit/inference/test_model_profiling.py
index 23e49f89025b3..319055d0ea55a 100644
--- a/tests/unit/inference/test_model_profiling.py
+++ b/tests/unit/inference/test_model_profiling.py
@@ -16,6 +16,9 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
 
+if torch.half not in get_accelerator().supported_dtypes():
+    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
+
 
 @pytest.mark.inference
 @pytest.mark.parametrize("use_cuda_events", [True, False])
diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py
index 9c7b428c0e68c..9cfcae809f090 100644
--- a/tests/unit/ops/transformer/inference/inference_test_utils.py
+++ b/tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -26,12 +26,7 @@ def get_tolerances():
 
 
 def get_dtypes():
     global DTYPES
     if DTYPES is None:
-        DTYPES = [torch.float16, torch.float32]
-        try:
-            if get_accelerator().is_bf16_supported():
-                DTYPES.append(torch.bfloat16)
-        except (AssertionError, AttributeError):
-            pass
+        DTYPES = get_accelerator().supported_dtypes()
     return DTYPES
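The test hunks above share one pattern: derive the device and the usable dtypes from `get_accelerator()` instead of hard-coding CUDA and fp16. A minimal sketch of that pattern in isolation, assuming a single process and one visible device:

```python
# Minimal sketch of the accelerator-agnostic pattern used by the tests above.
import torch
from deepspeed.accelerator import get_accelerator

local_rank = 0
# device_name(i) yields e.g. "cuda:0", "xpu:0" or "hpu:0" depending on the active accelerator.
device = torch.device(get_accelerator().device_name(local_rank))

# Use fp16 only when the accelerator reports support for it.
dtype = torch.half if torch.half in get_accelerator().supported_dtypes() else torch.float32
x = torch.ones(4, 4, dtype=dtype, device=device)
print(device, x.dtype)
```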