Merge branch 'master' into torch_compile_micro_offset_fix
tjruwase authored Sep 3, 2024
2 parents 9e3339d + 9b7fc54 commit 9c5bd48
Showing 11 changed files with 115 additions and 21 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/no-torch.yml
@@ -0,0 +1,45 @@
name: no-torch

on:
workflow_dispatch:
pull_request:
paths:
- '.github/workflows/no-torch.yml'
schedule:
- cron: "0 0 * * *"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: read
issues: write

jobs:
unit-tests:
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v4

- id: setup-venv
uses: ./.github/workflows/setup-venv

- name: Python environment
run: |
pip uninstall torch --yes
pip list
- name: Build deepspeed
run: |
DS_BUILD_STRING=" " python setup.py sdist
- name: Open GitHub issue if nightly CI fails
if: ${{ failure() && (github.event_name == 'schedule') }}
uses: JasonEtco/create-an-issue@v2
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
filename: .github/ISSUE_TEMPLATE/ci_failure_report.md
update_existing: true
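
The two build steps above translate directly into a local smoke test. Below is a minimal sketch, assuming a disposable virtualenv at the DeepSpeed repository root; it mirrors the workflow commands but is not part of the CI definition:

```python
import os
import subprocess
import sys

# Remove torch from the current environment, then confirm the sdist still builds,
# exactly as the "Python environment" and "Build deepspeed" steps do in CI.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "torch", "--yes"], check=True)
subprocess.run([sys.executable, "setup.py", "sdist"],
               env={**os.environ, "DS_BUILD_STRING": " "},  # same invocation as the workflow
               check=True)
```
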
2 changes: 1 addition & 1 deletion csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
@@ -20,7 +20,7 @@ deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size,
const bool single_submit,
const bool overlap_events,
const int num_threads)
: deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads)
: deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1)
{
_init_cuFile(block_size, queue_depth, num_threads);
}
2 changes: 1 addition & 1 deletion deepspeed/comm/ccl.py
@@ -15,7 +15,7 @@

def build_ccl_op():
builder = get_accelerator().create_op_builder("CCLCommBuilder")
if builder is None or NotImplementedBuilder:
if builder is None or isinstance(builder, NotImplementedBuilder):
return None
ccl_cpp_module = builder.load()
print(f'DeepSpeed {builder.absolute_name()} built successfully')
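
A note on the one-line fix above: a bare class reference such as `NotImplementedBuilder` is always truthy in Python, so the old condition was true for every builder and `build_ccl_op()` returned `None` unconditionally; only the `isinstance` form distinguishes a usable builder from the not-implemented stub. A self-contained sketch of the difference (the classes below are stand-ins, not the real DeepSpeed builders):

```python
class NotImplementedBuilder:  # stand-in for DeepSpeed's "not implemented" stub
    pass


class CCLCommBuilder:  # stand-in for a real, usable builder
    pass


builder = CCLCommBuilder()

# Old check: the bare class object is truthy, so this is True for *any* builder.
print(bool(builder is None or NotImplementedBuilder))                       # True

# Fixed check: only None or an actual NotImplementedBuilder instance is rejected.
print(bool(builder is None or isinstance(builder, NotImplementedBuilder)))  # False
```
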
31 changes: 31 additions & 0 deletions docs/_tutorials/getting-started.md
@@ -11,6 +11,7 @@ tags: getting-started
* To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/cli/jobs/deepspeed)
* DeepSpeed has direct integrations with [HuggingFace Transformers](https://github.com/huggingface/transformers) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). HuggingFace Transformers users can now easily accelerate their models with DeepSpeed through a simple ``--deepspeed`` flag + config file [See more details](https://huggingface.co/docs/transformers/main_classes/deepspeed). PyTorch Lightning provides easy access to DeepSpeed through the Lightning Trainer [See more details](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html?highlight=deepspeed#deepspeed).
* DeepSpeed on AMD can be used via our [ROCm images](https://hub.docker.com/r/deepspeed/rocm501/tags), e.g., `docker pull deepspeed/rocm501:ds060_pytorch110`.
* DeepSpeed also supports Intel Xeon CPUs, Intel Data Center Max Series XPUs, Intel Gaudi HPUs, Huawei Ascend NPUs, and more; please refer to the [accelerator setup guide](/tutorials/accelerator-setup-guide/).



@@ -226,6 +227,36 @@ deepspeed --include="worker-2:0,1" \
<client_entry.py> <client args> \
--deepspeed --deepspeed_config ds_config.json
```
### Launching without passwordless SSH

DeepSpeed now supports launching training jobs without the need for passwordless SSH. This mode is
particularly useful in cloud environments such as Kubernetes, where flexible container orchestration
is possible, and setting up a leader-worker architecture with passwordless SSH adds unnecessary
complexity.

To use this mode, you need to run the DeepSpeed command separately on all nodes. The command should
be structured as follows:

```bash
deepspeed --hostfile=myhostfile --no_ssh --node_rank=<n> \
--master_addr=<addr> --master_port=<port> \
<client_entry.py> <client args> \
--deepspeed --deepspeed_config ds_config.json
```

- `--hostfile=myhostfile`: Specifies the hostfile that contains information about the nodes and GPUs.
- `--no_ssh`: Enables the no-SSH mode.
- `--node_rank=<n>`: Specifies the rank of the node. This should be a unique integer from 0 to n - 1.
- `--master_addr=<addr>`: The address of the leader node (rank 0).
- `--master_port=<port>`: The port of the leader node.

In this setup, the hostnames in the hostfile do not need to be reachable via passwordless SSH.
However, the hostfile is still required for the launcher to collect information about the environment,
such as the number of nodes and the number of GPUs per node.

Each node must be launched with a unique `node_rank`, and all nodes must be provided with the address
and port of the leader node (rank 0). This mode causes the launcher to act similarly to the `torchrun`
launcher, as described in the [PyTorch documentation](https://pytorch.org/docs/stable/elastic/run.html).
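
As an illustration, the sketch below derives each node's command from a hostfile. It is not part of DeepSpeed; the `"<hostname> slots=<num_gpus>"` hostfile format noted in the comments, the `train.py` entry script, and port `29500` are assumptions for this example only.

```python
# Print the command each node should run for a no-SSH launch.
def print_no_ssh_commands(hostfile_path="myhostfile", master_port=29500):
    hosts = []
    with open(hostfile_path) as f:
        for line in f:
            line = line.split("#", 1)[0].strip()  # drop comments and blank lines
            if line:
                hosts.append(line.split()[0])     # keep the hostname, ignore "slots=N"

    master_addr = hosts[0]  # the rank-0 node acts as the leader
    for rank, host in enumerate(hosts):
        print(f"# on {host}:")
        print(f"deepspeed --hostfile={hostfile_path} --no_ssh --node_rank={rank} "
              f"--master_addr={master_addr} --master_port={master_port} "
              f"train.py --deepspeed --deepspeed_config ds_config.json")


if __name__ == "__main__":
    print_no_ssh_commands()
```

Running the printed command on its corresponding node, one per node, reproduces the manual procedure described above.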

## Multi-Node Environment Variables

16 changes: 10 additions & 6 deletions op_builder/fp_quantizer.py
@@ -49,23 +49,27 @@ def is_compatible(self, verbose=False):
import triton
except ImportError:
if verbose:
self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels")
self.warning(
f"please install triton==2.3.0, 2.3.1 or 3.0.0 if you want to use the FP Quantizer Kernels")
return False

# triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released
# triton 2.3.{0,1} and 3.0.0 are ok.
allowed_versions = ("2.3", "3.0")
if pkg_version:
allowed = pkg_version.parse("2.3")
allowed = (pkg_version.parse(v) for v in allowed_versions)
installed_triton = pkg_version.parse(triton.__version__)
triton_mismatch = installed_triton.major != allowed.major or installed_triton.minor != allowed.minor
triton_mismatch = all(installed_triton.major != a.major or installed_triton.minor != a.minor
for a in allowed)
else:
installed_triton = triton.__version__
major, minor, _ = installed_triton.split(".")
triton_mismatch = major != "2" or minor != "3"
allowed = (v.split(".") for v in allowed_versions)
triton_mismatch = all(major != v[0] or minor != v[1] for v in allowed)

if triton_mismatch:
if verbose:
self.warning(
f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels"
f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.{0,1} and 3.0.0 are known to be compatible with these kernels"
)
return False

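
For reference, the version gate above reduces to a major/minor comparison against an allow-list. A standalone sketch of the same logic using `packaging.version` (the probe versions at the bottom are examples, not additional supported releases):

```python
from packaging import version as pkg_version

ALLOWED_VERSIONS = ("2.3", "3.0")  # triton 2.3.x and 3.0.x


def triton_version_mismatch(installed: str) -> bool:
    """Return True when the installed major.minor matches none of the allowed series."""
    installed_v = pkg_version.parse(installed)
    allowed = (pkg_version.parse(v) for v in ALLOWED_VERSIONS)
    return all(installed_v.major != a.major or installed_v.minor != a.minor for a in allowed)


print(triton_version_mismatch("2.3.1"))  # False -> compatible
print(triton_version_mismatch("3.0.0"))  # False -> compatible
print(triton_version_mismatch("2.2.0"))  # True  -> is_compatible() warns and returns False
```
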
5 changes: 5 additions & 0 deletions op_builder/gds.py
@@ -36,6 +36,11 @@ def extra_ldflags(self):
return super().extra_ldflags() + ['-lcufile']

def is_compatible(self, verbose=False):
if self.is_rocm_pytorch():
if verbose:
self.warning(f'{self.NAME} is not compatible with ROCM')
return False

try:
import torch.utils.cpp_extension
except ImportError:
12 changes: 7 additions & 5 deletions tests/unit/alexnet_model.py
@@ -101,12 +101,14 @@ def cifar_trainset(fp16=False):
dist.barrier()
if local_rank != 0:
dist.barrier()

data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
trainset = torchvision.datasets.CIFAR10(root=os.path.join(data_root, "cifar10-data"),
train=True,
download=True,
transform=transform)
if os.getenv("CIFAR10_DATASET_PATH"):
data_root = os.getenv("CIFAR10_DATASET_PATH")
download = False
else:
data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data")
download = True
trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform)
if local_rank == 0:
dist.barrier()
return trainset
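
The change above lets offline or air-gapped environments point the unit tests at an existing CIFAR-10 copy instead of downloading it. A small sketch of the intended usage; the path is a placeholder for illustration:

```python
import os

# Point the CIFAR-10 helper at a pre-downloaded dataset so the tests need no network
# access. "/data/cifar10" is a placeholder; it must already contain the extracted
# torchvision CIFAR-10 files, because download is disabled on this path.
os.environ["CIFAR10_DATASET_PATH"] = "/data/cifar10"

# Without the variable, the helper falls back to TEST_DATA_DIR (default "/tmp") and
# downloads the dataset into "<TEST_DATA_DIR>/cifar10-data".
```
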
3 changes: 3 additions & 0 deletions tests/unit/inference/test_checkpoint_sharding.py
@@ -14,6 +14,7 @@
from huggingface_hub import snapshot_download
from transformers.utils import is_offline_mode
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
@@ -44,6 +45,8 @@ def model_name(request):

@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
if request.param not in get_accelerator().supported_dtypes():
pytest.skip(f"{request.param} not supported by {get_accelerator().device_name()}.")
return request.param


10 changes: 8 additions & 2 deletions tests/unit/inference/test_inference.py
@@ -298,6 +298,12 @@ def verify_injection(module):
verify_injection(model)


# Used to Get Device name
def getDeviceId(local_rank):
device = torch.device(f"{get_accelerator().device_name(local_rank)}")
return device


# Verify that test is valid
def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
model, task = model_w_task
@@ -484,8 +490,8 @@ def test(
pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.")

local_rank = int(os.getenv("LOCAL_RANK", "0"))

pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=local_rank, framework="pt")
device = getDeviceId(local_rank)
pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=device, framework="pt")
bs_output = pipe(query, **inf_kwargs)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=self.world_size,
3 changes: 3 additions & 0 deletions tests/unit/inference/test_model_profiling.py
@@ -16,6 +16,9 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


@pytest.mark.inference
@pytest.mark.parametrize("use_cuda_events", [True, False])
7 changes: 1 addition & 6 deletions tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -26,12 +26,7 @@ def get_tolerances():
def get_dtypes():
global DTYPES
if DTYPES is None:
DTYPES = [torch.float16, torch.float32]
try:
if get_accelerator().is_bf16_supported():
DTYPES.append(torch.bfloat16)
except (AssertionError, AttributeError):
pass
DTYPES = get_accelerator().supported_dtypes()
return DTYPES


