Merge branch 'master' into loadams/amd-57
loadams authored May 13, 2024
2 parents 2858e93 + 4696afd commit 0651cc3
Showing 48 changed files with 1,304 additions and 203 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-torch-latest.yml
@@ -50,5 +50,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
4 changes: 4 additions & 0 deletions .github/workflows/nv-ds-chat.yml
@@ -10,6 +10,10 @@ on:
required: false
default: 'master'
type: string
pull_request:
paths:
- "deepspeed/runtime/zero/stage_1_and_2.py"
- "deepspeed/runtime/zero/stage3.py"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
4 changes: 2 additions & 2 deletions .github/workflows/nv-torch-latest-v100.yml
@@ -55,5 +55,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.2" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.2" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8"
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -3,7 +3,7 @@ include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
recursive-include requirements *.txt
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc
recursive-include op_builder *.py
recursive-include benchmarks *.py
recursive-include accelerator *.py
13 changes: 7 additions & 6 deletions README.md
@@ -16,7 +16,7 @@
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html)
* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp)
* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)]
@@ -159,11 +159,12 @@ dynamically link them at runtime.
## Contributed HW support
* DeepSpeed now support various HW accelerators.

| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated |
| ----------- | -------- | ---------------- | --------------------- | ------------------ |
| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes |
| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated |
|-------------|-------------------------------------|------------------| --------------------- |--------------------|
| Huawei | Huawei Ascend NPU | npu | Yes | No |
| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes |

## PyPI
We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases.
17 changes: 17 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC):
def __init__(self):
self._name = None
self._communication_backend_name = None
self._compile_backend = None

@abc.abstractmethod
def is_synchronized_device(self):
@@ -287,3 +288,19 @@ def build_extension(self):
@abc.abstractmethod
def export_envs(self):
...

@abc.abstractmethod
def visible_devices_envs(self):
...

@abc.abstractmethod
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
...

@abc.abstractmethod
def get_compile_backend(self):
...

@abc.abstractmethod
def set_compile_backend(self, backend):
...
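The four new abstract methods make every accelerator declare (1) which environment variable(s) control device visibility and (2) its preferred torch.compile backend. The following is a minimal sketch (not part of this commit) of what a concrete subclass has to provide; the class name and env var are purely illustrative, and the many other abstract methods are omitted for brevity:

```python
import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator


class ToyAccelerator(DeepSpeedAccelerator):
    """Illustrative subclass; all other abstract methods are omitted for brevity."""

    def __init__(self):
        super().__init__()
        self._name = 'toy'
        self._communication_backend_name = 'gloo'
        self._compile_backend = 'inductor'  # default torch.compile backend

    def visible_devices_envs(self):
        # Env var(s) a launcher sets to restrict which devices this process sees.
        return ['TOY_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend not in supported_backends:
            raise ValueError(f"{backend} not supported by {self._name}. "
                             f"Supported Backends are {supported_backends}")
        self._compile_backend = backend
```

The concrete implementations added below (CPU, CUDA, HPU, MPS, NPU, XPU) all follow this same pattern.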
20 changes: 20 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'cpu'
self._compile_backend = "inductor"
if oneccl_imported_p:
self._communication_backend_name = 'ccl'
else:
@@ -322,3 +323,22 @@ def build_extension(self):

def export_envs(self):
return []

# TODO: cpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES
def visible_devices_envs(self):
return ['CUDA_VISIBLE_DEVICES']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
19 changes: 19 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -25,6 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
self._compile_backend = "inductor"
if pynvml is None:
self._init_pynvml()

@@ -360,3 +361,21 @@ def build_extension(self):

def export_envs(self):
return ['NCCL']

def visible_devices_envs(self):
return ['CUDA_VISIBLE_DEVICES']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
19 changes: 19 additions & 0 deletions accelerator/hpu_accelerator.py
@@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'hpu'
self._communication_backend_name = 'hccl'
self._compile_backend = "hpu_backend"
try:
import habana_frameworks.torch.hpu as hpu
hpu.setDeterministic(True)
@@ -294,3 +295,21 @@ def build_extension(self):

def export_envs(self):
return []

def visible_devices_envs(self):
return ['HABANA_VISIBLE_MODULES']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
21 changes: 21 additions & 0 deletions accelerator/mps_accelerator.py
@@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = "mps"
self._communication_backend_name = None
self._compile_backend = "inductor"

def is_synchronized_device(self):
return False
@@ -258,3 +259,23 @@ def build_extension(self):

def export_envs(self):
return []

# TODO: mpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES
def visible_devices_envs(self):
# TODO: could not find visible devices env for mps
return ['CUDA_VISIBLE_DEVICES']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
19 changes: 19 additions & 0 deletions accelerator/npu_accelerator.py
@@ -20,6 +20,7 @@ def __init__(self):
super().__init__()
self._name = 'npu'
self._communication_backend_name = 'hccl'
self._compile_backend = "inductor"
# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
# this dict will be filled at init stage
@@ -278,3 +279,21 @@ def build_extension(self):

def export_envs(self):
return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']

def visible_devices_envs(self):
return ['ASCEND_RT_VISIBLE_DEVICES']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }")
19 changes: 19 additions & 0 deletions accelerator/xpu_accelerator.py
@@ -14,6 +14,7 @@ class XPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'xpu'
self._communication_backend_name = 'ccl'
self._compile_backend = "inductor"
self.aligned_tensors = []

def is_synchronized_device(self):
@@ -289,3 +290,21 @@ def build_extension(self):

def export_envs(self):
return []

def visible_devices_envs(self):
return ['ZE_AFFINITY_MASK']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
4 changes: 4 additions & 0 deletions build_win.bat
@@ -1,6 +1,10 @@
@echo off

set DS_BUILD_AIO=0
set DS_BUILD_CUTLASS_OPS=0
set DS_BUILD_EVOFORMER_ATTN=0
set DS_BUILD_FP_QUANTIZER=0
set DS_BUILD_RAGGED_DEVICE_OPS=0
set DS_BUILD_SPARSE_ATTN=0

echo Administrative permissions required. Detecting permissions...
77 changes: 56 additions & 21 deletions csrc/adam/multi_tensor_adam.cu
@@ -30,7 +30,7 @@ typedef enum : int {

using MATH_T = float;

template <typename T>
template <typename T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
@@ -48,13 +48,13 @@ struct AdamFunctor {
// if(*noop_gmem == 1)
// return;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;

int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
@@ -71,7 +71,8 @@
n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
for (index_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
@@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) { break; }
}

// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
if (requires_64bit_indexing) {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
(int64_t)chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int64_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
} else {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int32_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
}

AT_CUDA_CHECK(cudaGetLastError());
}
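This kernel change templates AdamFunctor on an index type and falls back to 64-bit indexing whenever any parameter tensor has at least INT_MAX elements, so very large tensors no longer overflow 32-bit offsets. The host-side decision is equivalent to this small Python sketch (illustrative only, using torch tensors in place of the C++ tensor lists):

```python
import torch

INT_MAX = 2**31 - 1

def requires_64bit_indexing(tensor_lists):
    # Mirrors the loop added in multi_tensor_adam_cuda: if any tensor in any of
    # the p/g/m1/m2 lists reaches INT_MAX elements, 32-bit offsets could overflow,
    # so the int64_t-indexed AdamFunctor instantiation must be used instead.
    return any(t.numel() >= INT_MAX
               for tensors in tensor_lists
               for t in tensors)

small = [[torch.zeros(1024)], [torch.zeros(1024)]]
print(requires_64bit_indexing(small))   # False -> AdamFunctor<scalar_t, int32_t>
```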
