diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml
index 9c1ad02f75a6..5727ff2e1cde 100644
--- a/.github/workflows/cpu-torch-latest.yml
+++ b/.github/workflows/cpu-torch-latest.yml
@@ -50,5 +50,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
- TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2"
- TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2"
+ TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
+ TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml
index 94571eb101bb..cf8756fbd528 100644
--- a/.github/workflows/nv-ds-chat.yml
+++ b/.github/workflows/nv-ds-chat.yml
@@ -10,6 +10,10 @@ on:
required: false
default: 'master'
type: string
+ pull_request:
+ paths:
+ - "deepspeed/runtime/zero/stage_1_and_2.py"
+ - "deepspeed/runtime/zero/stage3.py"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml
index 3ca8ac43dfa4..3109f6060944 100644
--- a/.github/workflows/nv-torch-latest-v100.yml
+++ b/.github/workflows/nv-torch-latest-v100.yml
@@ -55,5 +55,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
- pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.2" --cuda_ver="11.8"
- pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.2" --cuda_ver="11.8"
+ pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8"
+ pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8"
diff --git a/MANIFEST.in b/MANIFEST.in
index ab79573ef96c..85e00695d648 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -3,7 +3,7 @@ include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
recursive-include requirements *.txt
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
-recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
+recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc
recursive-include op_builder *.py
recursive-include benchmarks *.py
recursive-include accelerator *.py
diff --git a/README.md b/README.md
index a1335caa4949..f9d81eddfdae 100755
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
-* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
+* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html)
* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp)
* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)]
@@ -159,11 +159,12 @@ dynamically link them at runtime.
## Contributed HW support
* DeepSpeed now support various HW accelerators.
-| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated |
-| ----------- | -------- | ---------------- | --------------------- | ------------------ |
-| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
-| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
-| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes |
+| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated |
+|-------------|-------------------------------------|------------------| --------------------- |--------------------|
+| Huawei | Huawei Ascend NPU | npu | Yes | No |
+| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
+| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
+| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes |
## PyPI
We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases.
diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py
index 3c5d799e293e..690f19292299 100644
--- a/accelerator/abstract_accelerator.py
+++ b/accelerator/abstract_accelerator.py
@@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC):
def __init__(self):
self._name = None
self._communication_backend_name = None
+ self._compile_backend = None
@abc.abstractmethod
def is_synchronized_device(self):
@@ -287,3 +288,19 @@ def build_extension(self):
@abc.abstractmethod
def export_envs(self):
...
+
+ @abc.abstractmethod
+ def visible_devices_envs(self):
+ ...
+
+ @abc.abstractmethod
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ ...
+
+ @abc.abstractmethod
+ def get_compile_backend(self):
+ ...
+
+ @abc.abstractmethod
+ def set_compile_backend(self, backend):
+ ...
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index a0171723cfb8..bd11d034f312 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cpu'
+ self._compile_backend = "inductor"
if oneccl_imported_p:
self._communication_backend_name = 'ccl'
else:
@@ -322,3 +323,22 @@ def build_extension(self):
def export_envs(self):
return []
+
+ # TODO: cpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES
+ def visible_devices_envs(self):
+ return ['CUDA_VISIBLE_DEVICES']
+
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ for env in self.visible_devices_envs():
+ current_env[env] = ",".join(map(str, local_accelerator_ids))
+
+ def get_compile_backend(self):
+ return self._compile_backend
+
+ def set_compile_backend(self, backend):
+ supported_backends = torch._dynamo.list_backends(exclude_tags=())
+ if backend in supported_backends:
+ self._compile_backend = backend
+ else:
+ raise ValueError(
+ f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py
index 3d5e9c168c16..60d66b6cdbab 100644
--- a/accelerator/cuda_accelerator.py
+++ b/accelerator/cuda_accelerator.py
@@ -25,6 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
+ self._compile_backend = "inductor"
if pynvml is None:
self._init_pynvml()
@@ -360,3 +361,21 @@ def build_extension(self):
def export_envs(self):
return ['NCCL']
+
+ def visible_devices_envs(self):
+ return ['CUDA_VISIBLE_DEVICES']
+
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ for env in self.visible_devices_envs():
+ current_env[env] = ",".join(map(str, local_accelerator_ids))
+
+ def get_compile_backend(self):
+ return self._compile_backend
+
+ def set_compile_backend(self, backend):
+ supported_backends = torch._dynamo.list_backends(exclude_tags=())
+ if backend in supported_backends:
+ self._compile_backend = backend
+ else:
+ raise ValueError(
+ f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index 30b115e8b1ab..114f367e879d 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'hpu'
self._communication_backend_name = 'hccl'
+ self._compile_backend = "hpu_backend"
try:
import habana_frameworks.torch.hpu as hpu
hpu.setDeterministic(True)
@@ -294,3 +295,21 @@ def build_extension(self):
def export_envs(self):
return []
+
+ def visible_devices_envs(self):
+ return ['HABANA_VISIBLE_MODULES']
+
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ for env in self.visible_devices_envs():
+ current_env[env] = ",".join(map(str, local_accelerator_ids))
+
+ def get_compile_backend(self):
+ return self._compile_backend
+
+ def set_compile_backend(self, backend):
+ supported_backends = torch._dynamo.list_backends(exclude_tags=())
+ if backend in supported_backends:
+ self._compile_backend = backend
+ else:
+ raise ValueError(
+ f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py
index 972b33caece1..5fc9b1c8cfb6 100644
--- a/accelerator/mps_accelerator.py
+++ b/accelerator/mps_accelerator.py
@@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = "mps"
self._communication_backend_name = None
+ self._compile_backend = "inductor"
def is_synchronized_device(self):
return False
@@ -258,3 +259,23 @@ def build_extension(self):
def export_envs(self):
return []
+
+ # TODO: mpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES
+ def visible_devices_envs(self):
+ # TODO: could not find visible devices env for mps
+ return ['CUDA_VISIBLE_DEVICES']
+
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ for env in self.visible_devices_envs():
+ current_env[env] = ",".join(map(str, local_accelerator_ids))
+
+ def get_compile_backend(self):
+ return self._compile_backend
+
+ def set_compile_backend(self, backend):
+ supported_backends = torch._dynamo.list_backends(exclude_tags=())
+ if backend in supported_backends:
+ self._compile_backend = backend
+ else:
+ raise ValueError(
+ f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py
index 472157e32c02..b0e0ff948e52 100644
--- a/accelerator/npu_accelerator.py
+++ b/accelerator/npu_accelerator.py
@@ -20,6 +20,7 @@ def __init__(self):
super().__init__()
self._name = 'npu'
self._communication_backend_name = 'hccl'
+ self._compile_backend = "inductor"
# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder':
# this dict will be filled at init stage
@@ -278,3 +279,21 @@ def build_extension(self):
def export_envs(self):
return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']
+
+ def visible_devices_envs(self):
+ return ['ASCEND_RT_VISIBLE_DEVICES']
+
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ for env in self.visible_devices_envs():
+ current_env[env] = ",".join(map(str, local_accelerator_ids))
+
+ def get_compile_backend(self):
+ return self._compile_backend
+
+ def set_compile_backend(self, backend):
+ supported_backends = torch._dynamo.list_backends(exclude_tags=())
+ if backend in supported_backends:
+ self._compile_backend = backend
+ else:
+ raise ValueError(
+ f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }")
diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py
index 3f65263946ab..9c4a9c903f96 100644
--- a/accelerator/xpu_accelerator.py
+++ b/accelerator/xpu_accelerator.py
@@ -14,6 +14,7 @@ class XPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'xpu'
self._communication_backend_name = 'ccl'
+ self._compile_backend = "inductor"
self.aligned_tensors = []
def is_synchronized_device(self):
@@ -289,3 +290,21 @@ def build_extension(self):
def export_envs(self):
return []
+
+ def visible_devices_envs(self):
+ return ['ZE_AFFINITY_MASK']
+
+ def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+ for env in self.visible_devices_envs():
+ current_env[env] = ",".join(map(str, local_accelerator_ids))
+
+ def get_compile_backend(self):
+ return self._compile_backend
+
+ def set_compile_backend(self, backend):
+ supported_backends = torch._dynamo.list_backends(exclude_tags=())
+ if backend in supported_backends:
+ self._compile_backend = backend
+ else:
+ raise ValueError(
+ f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
diff --git a/build_win.bat b/build_win.bat
index ec8c8a362a78..af5c5103fa4b 100644
--- a/build_win.bat
+++ b/build_win.bat
@@ -1,6 +1,10 @@
@echo off
set DS_BUILD_AIO=0
+set DS_BUILD_CUTLASS_OPS=0
+set DS_BUILD_EVOFORMER_ATTN=0
+set DS_BUILD_FP_QUANTIZER=0
+set DS_BUILD_RAGGED_DEVICE_OPS=0
set DS_BUILD_SPARSE_ATTN=0
echo Administrative permissions required. Detecting permissions...
diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu
index d6b9b2f70710..a1fc7d15aec9 100644
--- a/csrc/adam/multi_tensor_adam.cu
+++ b/csrc/adam/multi_tensor_adam.cu
@@ -30,7 +30,7 @@ typedef enum : int {
using MATH_T = float;
-template
+template
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
@@ -48,13 +48,13 @@ struct AdamFunctor {
// if(*noop_gmem == 1)
// return;
- int tensor_loc = tl.block_to_tensor[blockIdx.x];
+ index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
- int chunk_idx = tl.block_to_chunk[blockIdx.x];
- int n = tl.sizes[tensor_loc];
+ index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
+ index_t n = tl.sizes[tensor_loc];
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
@@ -71,7 +71,8 @@ struct AdamFunctor {
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
- for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+ for (index_t i_start = 0; i_start < n && i_start < chunk_size;
+ i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
@@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
bias_correction2 = 1 - std::pow(beta2, step);
}
+ size_t max_size = 0;
+ bool requires_64bit_indexing = false;
+ for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
+ for (auto it2 = it->begin(); it2 != it->end(); it2++) {
+ if (it2->numel() > max_size) {
+ max_size = it2->numel();
+ if (max_size >= INT_MAX) {
+ requires_64bit_indexing = true;
+ break;
+ }
+ }
+ }
+ if (requires_64bit_indexing) { break; }
+ }
+
// Assume single type across p,g,m1,m2 now
- DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
- 0,
- "adam",
- multi_tensor_apply<4>(BLOCK_SIZE,
- chunk_size,
- noop_flag,
- tensor_lists,
- AdamFunctor(),
- beta1,
- beta2,
- bias_correction1,
- bias_correction2,
- epsilon,
- lr,
- (adamMode_t)mode,
- weight_decay);)
+ if (requires_64bit_indexing) {
+ DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+ 0,
+ "adam",
+ multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
+ (int64_t)chunk_size,
+ noop_flag,
+ tensor_lists,
+ AdamFunctor(),
+ beta1,
+ beta2,
+ bias_correction1,
+ bias_correction2,
+ epsilon,
+ lr,
+ (adamMode_t)mode,
+ weight_decay);)
+ } else {
+ DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+ 0,
+ "adam",
+ multi_tensor_apply<4>(BLOCK_SIZE,
+ chunk_size,
+ noop_flag,
+ tensor_lists,
+ AdamFunctor(),
+ beta1,
+ beta2,
+ bias_correction1,
+ bias_correction2,
+ epsilon,
+ lr,
+ (adamMode_t)mode,
+ weight_decay);)
+ }
AT_CUDA_CHECK(cudaGetLastError());
}
diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh
index 12f41cb49c6b..342376c141be 100644
--- a/csrc/adam/multi_tensor_apply.cuh
+++ b/csrc/adam/multi_tensor_apply.cuh
@@ -35,7 +35,7 @@ struct TensorListMetadata {
};
template
-__global__ void multi_tensor_apply_kernel(int chunk_size,
+__global__ void multi_tensor_apply_kernel(int64_t chunk_size,
volatile int* noop_flag,
T tl,
U callable,
@@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
}
template
-void multi_tensor_apply(int block_size,
- int chunk_size,
+void multi_tensor_apply(int64_t block_size,
+ int64_t chunk_size,
const at::Tensor& noop_flag,
const std::vector>& tensor_lists,
T callable,
@@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
- int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+ auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
- for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
+ for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
diff --git a/csrc/fp_quantizer/includes/quantize.h b/csrc/fp_quantizer/includes/quantize.h
index 2204c1ba74fc..507252d6e722 100644
--- a/csrc/fp_quantizer/includes/quantize.h
+++ b/csrc/fp_quantizer/includes/quantize.h
@@ -113,3 +113,14 @@ void launch_dequantization(uint8_t* val,
int q_mantisa_bits,
int q_exponent_bits,
cudaStream_t stream);
+
+template
+void launch_selective_dequantization(uint8_t* val,
+ T* q_val,
+ int32_t* indexes,
+ int num_groups,
+ int group_size,
+ int num_indexes,
+ int q_mantisa_bits,
+ int q_exponent_bits,
+ cudaStream_t stream);
diff --git a/csrc/fp_quantizer/quantize.cpp b/csrc/fp_quantizer/quantize.cpp
index 4a88ff767636..ec631c576e27 100644
--- a/csrc/fp_quantizer/quantize.cpp
+++ b/csrc/fp_quantizer/quantize.cpp
@@ -78,8 +78,39 @@ void dequantize(torch::Tensor& val,
#endif
}
+#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa) \
+ if (val.options().dtype() == torch::T_TYPE) { \
+ launch_selective_dequantization((uint8_t*)val_q.data_ptr(), \
+ (C_TYPE*)val.data_ptr(), \
+ (int32_t*)indexes.data_ptr(), \
+ num_groups, \
+ group_size, \
+ num_indexes, \
+ q_mantisa_bits, \
+ q_exponent_bits, \
+ at::cuda::getCurrentCUDAStream()); \
+ return; \
+ }
+void selective_dequantize(torch::Tensor& val,
+ torch::Tensor& val_q,
+ torch::Tensor& indexes,
+ int group_size,
+ int q_mantisa_bits,
+ int q_exponent_bits)
+{
+ int total_elems = at::numel(val);
+ int num_indexes = indexes.size(0);
+ int num_groups = total_elems / group_size;
+
+ DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10);
+#ifdef BF16_AVAILABLE
+ DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7);
+#endif
+}
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("quantize", &quantize, "quantize function");
m.def("dequantize", &dequantize, "dequantize function");
+ m.def("selective_dequantize", &selective_dequantize, "selective dequantize function");
}
diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu
index 5f0b58f124f0..5ada6894747f 100644
--- a/csrc/fp_quantizer/quantize.cu
+++ b/csrc/fp_quantizer/quantize.cu
@@ -270,6 +270,7 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
mem_access::load_global(
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
+
} else {
mem_access::load_global(int8_data,
load_base_ptr);
@@ -393,3 +394,137 @@ void launch_dequantization(uint8_t* val,
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);
+
+template
+__global__ void apply_selective_dequantization(uint8_t* val,
+ T* q_val,
+ int32_t* indexes,
+ int group_size,
+ int total_num_elements)
+{
+ int index = indexes[blockIdx.x];
+ constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
+ int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
+ int input_index = index * total_num_elements + tidx;
+ constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
+ constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
+ constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
+ constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
+ constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
+ const uint32_t g_index = (input_index / group_size);
+ const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
+ const uint8_t* load_base_ptr =
+ val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8;
+
+ int mantisa_mask = ((1 << q_mantisa_bits) - 1);
+ mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
+
+ T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements;
+ float scale;
+
+ uint8_t* scale_as_int8 = reinterpret_cast(&scale);
+ if (quantized_bits == 6) {
+ mem_access::load_global(
+ scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
+ mem_access::load_global(
+ scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
+ val + g_index * (group_size_bytes + 4) + group_size_bytes +
+ quantization::quanitzed_access_granularity_6bits);
+ } else
+ mem_access::load_global(
+ scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
+
+ if (tidx < total_num_elements) {
+ uint64_t q_buf_in;
+ uint64_t q_buf_in1;
+ uint8_t* int8_data = reinterpret_cast(&q_buf_in);
+ uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1);
+ if (quantized_bits == 6) {
+ mem_access::load_global(
+ int8_data, load_base_ptr);
+ mem_access::load_global(
+ int8_data + quantization::quanitzed_access_granularity_6bits,
+ load_base_ptr + quantization::quanitzed_access_granularity_6bits);
+ mem_access::load_global(
+ int8_data + quantization::quanitzed_access_granularity_6bits * 2,
+ load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
+ } else {
+ mem_access::load_global(int8_data,
+ load_base_ptr);
+ if (quantized_bits > 4) {
+ mem_access::load_global(
+ int8_data + quantization::quanitzed_access_granularity,
+ load_base_ptr + quantization::quanitzed_access_granularity);
+ if (quantized_bits == 12) {
+ mem_access::load_global(
+ int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
+ }
+ }
+ }
+ T store_buf[vector_size];
+ uint16_t* q_buf = reinterpret_cast(store_buf);
+#pragma unroll
+ for (int j = 0; j < vector_size; j++) {
+ uint16_t new_data;
+ if (j < 5 || quantized_bits != 12) {
+ new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
+ } else {
+ if (j == 5) {
+ new_data = (uint16_t)(q_buf_in1);
+ new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
+ } else
+ new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
+ }
+
+ uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
+ uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
+ uint16_t dst_mantisa = (new_data & _mantisa_mask);
+
+ if (dst_exponent != (1 << q_exponent_bits) - 1)
+ dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
+ (1 << (q_exponent_bits - 1)) - 1;
+
+ q_buf[j] =
+ ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
+ (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
+ float up_cast = conversion::to(store_buf[j]);
+ store_buf[j] = conversion::to(up_cast * scale);
+ }
+ mem_access::store_global(store_base_ptr, store_buf);
+ }
+}
+
+template
+void launch_selective_dequantization(uint8_t* val,
+ T* q_val,
+ int32_t* indexes,
+ int num_groups,
+ int group_size,
+ int num_indexes,
+ int q_mantisa_bits,
+ int q_exponent_bits,
+ cudaStream_t stream)
+{
+ int total_elements_per_index = (num_groups / num_indexes) * group_size;
+ int blocks = (total_elements_per_index - 1) /
+ (quantization::threads * (quantization::access_granularity / sizeof(T))) +
+ 1;
+ const dim3 grid(num_indexes, blocks);
+ const dim3 block(quantization::threads);
+ DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
+ apply_selective_dequantization
+ <<>>(val, q_val, indexes, group_size, total_elements_per_index);
+ });
+}
+#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \
+ template void launch_selective_dequantization( \
+ uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t);
+// fp8(E4M3)
+#ifdef BF16_AVAILABLE
+INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7);
+#endif
+INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10);
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index d5eca81c804f..b1a8276589b6 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -10,7 +10,7 @@
import argparse
import glob
import itertools
-import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
@@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
cnt = 0
+def dp_index_to_str(dp_index):
+ return f"{dp_index:0>2d}"
+
+
def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):
global cnt # temp hack
@@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
os.makedirs(param_base_path, exist_ok=True)
cnt += 1
- counter = f"{dp_index:0>2d}"
- path = os.path.join(param_base_path, f"{state_name}.{counter}")
+ path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")
#print(f"{param_name}: {offset}: {numel} => {path}")
@@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
- paths = sorted(list(glob.glob(f"{prefix_path}.*")))
+ paths = glob.glob(f"{prefix_path}.*")
+
if len(paths) == 0:
continue
+ pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
+ dp_indices = set()
+ for p in paths:
+ m = pattern.match(p)
+ if m:
+ dp_indices.add(int(m.group(1)))
+ else:
+ raise ValueError(f"Cannot parse dp_rank from {p}")
+
+ paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
shards = [torch.load(p) for p in paths]
if state == "step":
@@ -278,27 +292,18 @@ def get_matched_sub_params_pattern(name_):
return unmatched_patterns
-def _get_chunks(l, n):
- for i in range(0, len(l), n):
- yield l[i:i + n]
-
-
def _do_parallel_work(do_work, work_chunks, num_workers):
+ results = []
if num_workers > 1:
- pool = multiprocessing.Pool(num_workers)
- results = []
- for batch in tqdm.tqdm(work_chunks):
- res = pool.map(do_work, batch)
- results.extend(res)
- pool.close()
- pool.join()
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
+ future_list = [executor.submit(do_work, work) for work in work_chunks]
+ for f in tqdm.tqdm(future_list):
+ results.append(f.result())
else:
# No parallel pass for unit testing
# We can't create child processes in tests
- results = []
- for batch in tqdm.tqdm(work_chunks):
- res = [do_work(x) for x in batch]
- results.extend(res)
+ for work in tqdm.tqdm(work_chunks):
+ results.append(do_work(work))
return results
@@ -307,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
#pprint(f'{_3d_range_list=}')
- work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
- #pprint(f'{work_chunks=}')
- # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
- _do_parallel_work(do_work, work_chunks, args.num_extract_workers)
+ _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
- work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
- #pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
- unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
+ unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)
# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py
index ffb9cbc18fbd..079182a3b45b 100755
--- a/deepspeed/launcher/launch.py
+++ b/deepspeed/launcher/launch.py
@@ -22,6 +22,7 @@
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
+from deepspeed.accelerator import get_accelerator
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import DLTS_POD_ENV_PATH
from ..utils import logger, get_numactl_cmd
@@ -146,8 +147,8 @@ def main():
node_list = list(world_info.keys())
args.nnodes = len(node_list)
local_node = node_list[args.node_rank]
- local_gpu_ids = world_info[local_node]
- num_local_procs = len(local_gpu_ids)
+ local_accelerator_ids = world_info[local_node]
+ num_local_procs = len(local_accelerator_ids)
logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}")
global_rank_mapping = defaultdict(list)
@@ -161,8 +162,10 @@ def main():
curr_global_rank += 1
logger.info(f"global_rank_mapping={global_rank_mapping}")
logger.info(f"dist_world_size={dist_world_size}")
- current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids))
- logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}")
+
+ get_accelerator().set_visible_devices_envs(current_env, local_accelerator_ids)
+ for env in get_accelerator().visible_devices_envs():
+ logger.info(f"Setting {env}={current_env[env]}")
# set PyTorch distributed related environmental variables
current_env["MASTER_ADDR"] = args.master_addr
diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py
new file mode 100644
index 000000000000..a27f1c3eaee7
--- /dev/null
+++ b/deepspeed/linear/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .optimized_linear import OptimizedLinear
+from .config import LoRAConfig, QuantizationConfig
diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py
new file mode 100644
index 000000000000..ae9050a3c92b
--- /dev/null
+++ b/deepspeed/linear/config.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from dataclasses import dataclass
+
+
+@dataclass
+class LoRAConfig:
+ """
+ Configuration settings for LoRAOptimizedLinear.
+
+ Attributes:
+ lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
+ lora_alpha (float): LoRA scaling factor, default is 16.
+ base_weight_sharding (int): The degree to which the base weights are sharded,
+ should typically be set to the data-parallel world size to maximize the memory
+ reduction benefits. Defaults to 1, which means this feature is disabled.
+ """
+ lora_r: int = 64
+ lora_alpha: float = 16.
+ base_weight_sharding: int = 1
+
+
+@dataclass
+class QuantizationConfig:
+ """
+ Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear,
+ and QuantizedParameter
+
+ Attributes:
+ q_bits (int): The number of bits used for quantization. Default is 8.
+ mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
+ group_size (int): The size of the group used for quantization. Default is 512.
+ """
+ q_bits: int = 8
+ mantissa_bits: int = 3
+ group_size: int = 512
diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py
new file mode 100644
index 000000000000..e982785a8122
--- /dev/null
+++ b/deepspeed/linear/optimized_linear.py
@@ -0,0 +1,150 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import is_dataclass
+from deepspeed.accelerator import get_accelerator
+import deepspeed.comm as dist
+
+from .config import LoRAConfig, QuantizationConfig
+from .quantization import QuantizedParameter, QuantizedLinear
+
+
+class OptimizedLinear(nn.Module):
+ """
+ Optimized version of nn.Linear that adds features such as:
+ * LoRA w. base weight sharding
+ * FP [6,8,12] quantization
+
+ Arguments:
+ input_dim: Required: size of each input sample
+ output_dim: Required: size of each output sample
+ bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
+ lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
+ quantization_config: Optional: QuantizationConfig defining quantization features
+ dtype: Optional: parameter dtype, only supports bfloat16 currently
+
+ Returns:
+ Returns a new nn.Module depending on the input config. Either native
+ torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
+ """
+
+ def __new__(self,
+ input_dim: int,
+ output_dim: int,
+ bias: bool = False,
+ lora_config: LoRAConfig = None,
+ quantization_config: QuantizationConfig = None,
+ dtype=torch.bfloat16):
+
+ if quantization_config is not None and not is_dataclass(quantization_config):
+ raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
+ if lora_config is not None and not is_dataclass(lora_config):
+ raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
+ if lora_config is None and quantization_config is None:
+ # Everything disabled, fall back to normal nn.Linear
+ self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)
+
+ elif lora_config:
+ # lora enabled, quantization may or may not be
+ self = LoRAOptimizedLinear(input_dim=input_dim,
+ output_dim=output_dim,
+ bias=bias,
+ lora_config=lora_config,
+ quantization_config=quantization_config,
+ dtype=dtype)
+
+ elif quantization_config:
+ # only quantization enabled, no lora
+ self = QuantizedLinear(input_dim=input_dim,
+ output_dim=output_dim,
+ bias=bias,
+ quantization_config=quantization_config,
+ dtype=dtype)
+ return self
+
+
+class LoRAOptimizedLinear(nn.Module):
+
+ def __init__(self,
+ input_dim: int,
+ output_dim: int,
+ bias: bool = False,
+ lora_config: LoRAConfig = None,
+ quantization_config: QuantizationConfig = None,
+ device=None,
+ dtype=torch.bfloat16):
+ super().__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.bias = bias
+ self.lora_config = lora_config
+ self.quantization_config = quantization_config
+ device = get_accelerator().current_device_name() if device is None else device
+ assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"
+
+ self.zero_shards = self.lora_config.base_weight_sharding
+ self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
+ w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
+ torch.nn.init.xavier_uniform_(w)
+
+ if self.quantization_config is not None:
+ assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
+ self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
+ else:
+ self.base_weight = w
+
+ self.base_weight.requires_grad = False
+
+ # Use RS lora for now.
+ self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
+ # Keeping lora weights in bf16 precision for ease of training.
+ self.lora_weight_1 = nn.Linear(self.input_dim,
+ self.lora_config.lora_r,
+ bias=self.bias,
+ device=device,
+ dtype=dtype)
+ self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
+ self.output_dim,
+ bias=self.bias,
+ device=device,
+ dtype=dtype)
+ self.lora_weight_1.weight.requires_grad = True
+ self.lora_weight_2.weight.requires_grad = True
+
+ def full_weight(self):
+ # This assumes weights are evenly sharded across gpus. which might not be correct.
+ # in that case, we should flatten before all_gather.
+ local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
+ QuantizedParameter) else self.base_weight
+ tensor_list = [
+ torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
+ for _ in range(self.zero_shards)
+ ]
+ dist.all_gather(tensor_list, local_weight)
+ weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
+ return weight
+
+ def linear_without_F_linear(self, input, weight):
+ output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
+ output = output.view(*input.shape[:-1], weight.shape[1])
+ return output
+
+ def forward(self, input_tensor):
+ # Gather the sharded base weight
+ if self.zero_shards > 1:
+ with torch.no_grad():
+ base_weight = self.full_weight()
+ elif self.quantization_config:
+ base_weight = self.base_weight.dequantized()
+ else:
+ base_weight = self.base_weight
+
+ base_weight_output = F.linear(input_tensor, base_weight)
+ lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
+ return base_weight_output + self.lora_scaling_factor * lora_output
diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py
new file mode 100644
index 000000000000..f5343af45fb8
--- /dev/null
+++ b/deepspeed/linear/quantization.py
@@ -0,0 +1,137 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize
+from .config import QuantizationConfig
+
+
+class QuantizedParameter(nn.Parameter):
+ """
+ Quantized parameter class that implements weight quantization. Weights
+ are stored in quantized form on GPUs, and can be dequantized on-the-fly when
+ needed by the model. The weights are actually quantized during any `.to(device)`.
+
+ Arguments:
+ data (Tensor): parameter tensor.
+ requires_grad (bool, optional): if the parameter requires gradient. Defaults
+ to False and is not supported to be True. Argument provided only for interface
+ compatibility with torch.nn.Parameter.
+ quantization_config (QuantizationConfig, optional):
+ quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer
+ that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also
+ required since the quantizer is stashed in the Parameter itself, some models
+ may clone the Parameter by passing an attribute __dict__. For an example, see
+ tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone
+ """
+
+ def __new__(
+ cls,
+ data: Optional[torch.Tensor] = None,
+ requires_grad: bool = False, # quantized weights must be frozen
+ quantization_config: QuantizationConfig = None,
+ quantizer: Quantizer = None,
+ ):
+ if requires_grad:
+ raise ValueError(f"requires_grad=True is not supported with QuantizedParameter")
+ if data is None:
+ data = torch.empty(0)
+ self = torch.Tensor._make_subclass(cls, data, requires_grad)
+ self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
+ if quantizer is not None:
+ self.quantizer = quantizer
+ else:
+ # if FPQuantizerBuilder is not compatible in this env this init will fail
+ self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
+ self._ensure_quantized(self)
+ return self
+
+ def _ensure_quantized(self, tensor: torch.Tensor):
+ # If the tensor is on the accelerator and is not quantized, then quantize it in-place.
+ if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
+ with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
+ tensor.data = self.quantizer.quantize(tensor.data,
+ q_bits=self.quantization_config.q_bits,
+ q_mantisa_bits=self.quantization_config.mantissa_bits)
+ assert tensor.dtype == torch.int8
+
+ def dequantized(self) -> torch.Tensor:
+ """
+ Return a tensor containing the dequantized weights of this parameter.
+ """
+ if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
+ with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
+ return self.quantizer.dequantize(self.data,
+ q_bits=self.quantization_config.q_bits,
+ q_mantisa_bits=self.quantization_config.mantissa_bits)
+ return self.data
+
+ def __getstate__(self):
+ state = self.__dict__
+ state["data"] = self.data
+ state["quantization_config"] = self.quantization_config
+ state["requires_grad"] = self.requires_grad
+ return state
+
+ def __setstate__(self, state):
+ self.quantizer = state["quantizer"]
+ self.quantization_config = state["quantization_config"]
+ self.data = state["data"]
+ self.requires_grad = state["requires_grad"]
+
+ def __deepcopy__(self, memo):
+ new_instance = type(self).__new__(type(self))
+ state = self.__getstate__()
+ new_instance.__setstate__(state)
+ new_instance.quantizer = copy.deepcopy(state["quantizer"])
+ new_instance.quantization_config = copy.deepcopy(state["quantization_config"])
+ new_instance.data = copy.deepcopy(state["data"])
+ return new_instance
+
+ def __copy__(self):
+ new_instance = type(self).__new__(type(self))
+ state = self.__getstate__()
+ new_instance.__setstate__(state)
+ return new_instance
+
+ def cuda(self, device=None, non_blocking=False):
+ return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
+
+ def to(self, *args, **kwargs):
+ """
+ Move the parameter to the given device. Then, if the device is a cuda device,
+ quantize it.
+ """
+ tensor = super().to(*args, **kwargs)
+ self._ensure_quantized(tensor)
+ return tensor
+
+
+class QuantizedLinear(nn.Linear):
+ """
+ Linear layer that implements weight quantization. Parameters
+ are stored via `QuantizedParameter` and are dequantized on-the-fly during any
+ forward pass.
+ """
+
+ def __init__(self,
+ input_dim: int,
+ output_dim: int,
+ bias: bool = False,
+ quantization_config: QuantizationConfig = None,
+ dtype=torch.bfloat16):
+ super().__init__(input_dim, output_dim, bias=bias, dtype=dtype)
+ assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype"
+ self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.dequantized(), self.bias)
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 88f7086518e8..4944e1954e37 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
-from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
+from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
@@ -133,7 +133,8 @@ def is_load_module(module):
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
- "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
+ "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
+ "Phi3RMSNorm"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names
@@ -328,6 +329,10 @@ def _replace(self, child, name, conv_linear_layer):
# For mixtral-7x8b, need to skip MoE gate linear replace.
if name == "block_sparse_moe.gate":
return child
+ # for phi3.
+ if 'gate_up_proj' in name:
+ weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
+ return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]
diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py
index cf087c16da8a..33d36fbfae54 100644
--- a/deepspeed/module_inject/fusedqkv_utils.py
+++ b/deepspeed/module_inject/fusedqkv_utils.py
@@ -4,7 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
-from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
+from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads
def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
@@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
"FalconDecoderLayer": 'bloomtype',
"GPTBigCodeBlock": 'bigcodetype',
"DecoderLayer": 'glmtype',
+ "Phi3DecoderLayer": "phi3type"
}
def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size):
split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
return torch.cat((split_q[gpu_index], kv), dim=0)
+ def _phi3_type_transpose(input, mp_size):
+ num_kv_heads = get_num_kv_heads()
+ num_heads = get_num_attention_heads()
+ hidden_size = input.shape[1]
+ head_dim = hidden_size // num_heads
+ q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
+ q = input[:q_pos]
+ k = input[q_pos:q_pos + num_kv_heads * head_dim]
+ v = input[q_pos + num_kv_heads * head_dim:]
+ split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
+ split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
+ split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
+ return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)
+
def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
# suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
@@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
return _qwen_type_transpose(src, mp_size, module)
elif fused_qkv_type == 'bigcodetype':
return _bigcode_type_transpose(src, mp_size)
+ elif fused_qkv_type == 'phi3type':
+ return _phi3_type_transpose(src, mp_size)
raise ValueError("unknown fused_qkv_type")
@@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
return _bloom_type_transpose(src, mp_size)
+
+
+# For phi3 with chunk mlp, adjust the weight order.
+def shard_chunk_mlp(
+ weight,
+ bias,
+ rank,
+ world_size,
+):
+ weight_gate, weight_states = weight.chunk(2, dim=0)
+ total_size = weight_gate.shape[0]
+ split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
+ split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
+ shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)
+ if bias is not None:
+ bias_gate, bias_states = bias.chunk(2, dim=0)
+ split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
+ split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
+ return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)
+
+ return shard_weight, None
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index e1703562d180..3029a79698dc 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -16,7 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
from .load_checkpoint import load_model_with_checkpoint
import time
@@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
# 4.2 set n_embd
set_n_embd(n_embd)
+ # 4.3 set attention_heads
+ if hasattr(model_config, 'num_attention_heads'):
+ set_num_attention_heads(getattr(model_config, 'num_attention_heads'))
+
# 5. Set linear policies
_autotp.update_linear_policies()
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 79c19b5f1272..6758c7a657f6 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
num_kv_heads = num
+def set_num_attention_heads(num):
+ global num_attention_heads
+ num_attention_heads = num
+
+
def set_n_embd(num):
global n_embd
n_embd = num
@@ -22,6 +27,11 @@ def get_num_kv_heads():
return num_kv_heads
+def get_num_attention_heads():
+ global num_attention_heads
+ return num_attention_heads
+
+
def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py
index 5575f3567185..995bbae4aeaf 100644
--- a/deepspeed/ops/fp_quantizer/__init__.py
+++ b/deepspeed/ops/fp_quantizer/__init__.py
@@ -3,4 +3,4 @@
# DeepSpeed Team
-from .quantize import FP_Quantize
+from .quantize import FP_Quantize, Quantizer
diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py
index 5dc3c190ae5d..f8435bda16c1 100644
--- a/deepspeed/ops/fp_quantizer/quantize.py
+++ b/deepspeed/ops/fp_quantizer/quantize.py
@@ -4,20 +4,47 @@
# DeepSpeed Team
import torch
+import abc
+from abc import ABC
from deepspeed.ops.op_builder import FPQuantizerBuilder
fp_quant_module = None
-class FP_Quantize:
+class Quantizer(ABC):
+ """
+ Abstract Quantizer class that implmenents quantize/dequantize methods.
+
+ Arguments:
+ group_size (int, optional): number of values or elements that are grouped
+ together for the quantization process.
+ """
+
+ def __init__(self, group_size=512) -> None:
+ self.group_size = group_size
+
+ @abc.abstractmethod
+ def quantize(self,
+ input,
+ q_bits=8,
+ q_mantisa_bits=3,
+ stochastic_mode=False,
+ return_meta_tensor=False) -> torch.Tensor:
+ ...
+
+ @abc.abstractmethod
+ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
+ ...
+
+
+class FP_Quantize(Quantizer):
def __init__(self, group_size=512) -> None:
global fp_quant_module
+ super().__init__(group_size=group_size)
if fp_quant_module is None:
fp_quant_module = FPQuantizerBuilder().load()
-
- self.group_size = group_size
self.orig_dtype = None
def quantize(self,
@@ -77,3 +104,38 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
return fp_out
+
+ def selective_dequantize(self,
+ input_q,
+ indexes,
+ fp_out=None,
+ q_bits=8,
+ q_mantisa_bits=3,
+ scale=None) -> torch.Tensor:
+ assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
+ "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
+ assert (self.orig_dtype is not None), \
+ "[De-quantization Error]: you need to call quantize before dequantizing!"
+ fp_out = torch.empty(
+ (indexes.shape[0],
+ *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out
+ if q_bits == 8:
+ pass
+ elif q_bits == 12:
+ q_mantisa_bits = 4
+ elif q_bits == 6:
+ q_mantisa_bits = 2
+ elif q_bits == 4:
+ q_mantisa_bits = 1
+ else:
+ assert (0), \
+ f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
+
+ if scale is not None:
+ assert input_q.numel() == fp_out.numel(), \
+ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
+ input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
+
+ fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
+ q_bits - q_mantisa_bits - 1)
+ return fp_out
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index f970e582b354..1f3365b20f4e 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
# clear gradients
if clear_lp_grads:
- lp.grad._zero()
+ lp.grad.zero_()
@torch.no_grad()
def _update_hp_grads_func(self, clear_lp_grads=False):
@@ -441,11 +441,20 @@ def clear_hp_grads(self):
self.fp32_groups_has_gradients[i] = [False] * len(group)
def clear_lp_grads(self):
+
+ # using zero_() fixed memory address for graph replay
+ set_to_none = False if self.graph_harvesting else True
+ zero_grads_list = []
for group in self.bf16_groups:
for param in group:
- if param.grad is not None:
- # Using zero_() fixed memory address for graph replay
- param.grad.zero_()
+ if set_to_none:
+ param.grad = None
+ elif param.grad is not None:
+ if param.grad.grad_fn is not None:
+ param.grad.detach_()
+ zero_grads_list.append(param.grad)
+ if not set_to_none and len(zero_grads_list) > 0:
+ torch._foreach_zero_(zero_grads_list)
def state_dict(self):
state_dict = {}
diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index b5e4e33425d0..66fe29fbbea2 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values):
return field_value
-class CompiledModuleWrapper(torch.nn.Module):
-
- def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
- super().__init__()
-
- assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."
-
- modules = self.__dict__.get('_modules')
- modules['wrapped'] = module
- self.__dict__['wrapped'] = module
- self._is_compiled = False
- self._backend = get_backend_fn(compile_config.backend)
- self._compile_kwargs = compile_config.kwargs
- self._compiler_fn = None
-
- def __getattr__(self, name):
- return getattr(self.__dict__['wrapped'], name)
-
- def set_backend(self, backend: Union[str, Callable]):
- """Set the backend for torch.compile.
-
- Args:
- backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
- You can directly pass a function that works as a backend.
- See also `backend` field in `CompileConfig` for more details.
- """
- self._backend = get_backend_fn(backend)
-
- def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
- """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
- You can also pass a backend name with "backend" key to change the backend.
-
- Args:
- kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
- """
-
- if "backend" in kwargs:
- raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
- self._compile_kwargs.update(kwargs)
-
- def set_compiler_fn(self, compiler_fn: Callable) -> None:
- """Set a function to be used for compiling the module.
- This function should take a torch.nn.Module as input and return a compiled module.
- Note that other compile options are ignored when a compiler_fn is set.
-
- Example:
- ```python
- def my_compiler_fn(module: torch.nn.Module):
- ...
- return torch.compile(module, ...)
-
- engine.set_compiler_fn(my_compiler_fn)
- ```
- """
- self._compiler_fn = compiler_fn
-
- def forward(self, *args, **kwargs) -> Any:
- if not self.is_compiled:
- if self._compiler_fn is None:
- self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
- else:
- self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
- self._is_compiled = True
-
- return self.__dict__['wrapped'](*args, **kwargs)
-
- @property
- def is_compiled(self) -> bool:
- return self._is_compiled
-
- @property
- def backend(self) -> Union[str, Callable]:
- return self._backend
-
- @property
- def torch_compile_kwargs(self) -> Dict[str, Any]:
- return self._compile_kwargs
-
- @property
- def compiler_fn(self) -> Union[Callable, None]:
- return self._compiler_fn
+def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):
+
+ class wrapper(mod.__class__):
+
+ def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
+ self.__dict__ = module.__dict__.copy()
+
+ assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."
+
+ self.__dict__['wrapped'] = module
+ self._is_compiled = False
+ self._backend = get_backend_fn(compile_config.backend)
+ self._compile_kwargs = compile_config.kwargs
+ self._compiler_fn = None
+
+ def set_backend(self, backend: Union[str, Callable]):
+ """Set the backend for torch.compile.
+
+ Args:
+ backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
+ You can directly pass a function that works as a backend.
+ See also `backend` field in `CompileConfig` for more details.
+ """
+ self._backend = get_backend_fn(backend)
+
+ def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
+ """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
+ You can also pass a backend name with "backend" key to change the backend.
+
+ Args:
+ kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
+ """
+
+ if "backend" in kwargs:
+ raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
+ self._compile_kwargs.update(kwargs)
+
+ def set_compiler_fn(self, compiler_fn: Callable) -> None:
+ """Set a function to be used for compiling the module.
+ This function should take a torch.nn.Module as input and return a compiled module.
+ Note that other compile options are ignored when a compiler_fn is set.
+
+ Example:
+ ```python
+ def my_compiler_fn(module: torch.nn.Module):
+ ...
+ return torch.compile(module, ...)
+
+ engine.set_compiler_fn(my_compiler_fn)
+ ```
+ """
+ self._compiler_fn = compiler_fn
+
+ def forward(self, *args, **kwargs) -> Any:
+ if not self.is_compiled:
+ if self._compiler_fn is None:
+ self.__dict__['wrapped'] = torch.compile(self.wrapped,
+ backend=self._backend,
+ **self._compile_kwargs)
+ else:
+ self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
+ self._is_compiled = True
+
+ return self.__dict__['wrapped'](*args, **kwargs)
+
+ @property
+ def is_compiled(self) -> bool:
+ return self._is_compiled
+
+ @property
+ def backend(self) -> Union[str, Callable]:
+ return self._backend
+
+ @property
+ def torch_compile_kwargs(self) -> Dict[str, Any]:
+ return self._compile_kwargs
+
+ @property
+ def compiler_fn(self) -> Union[Callable, None]:
+ return self._compiler_fn
+
+ return wrapper(mod, compile_config)
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 9a2b943b0992..34263444c1b7 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -469,13 +469,6 @@ def __getattr__(self, name):
return getattr(self, name)
elif name in dir(_module):
return getattr(_module, name)
- elif isinstance(_module, CompiledModuleWrapper):
- try:
- return getattr(_module, name)
- except AttributeError:
- raise AttributeError(
- f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'"
- )
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index bf1693307ea7..49093bb73c8f 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group):
group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
grad_flat_st_idx = grad_flat_en_idx
- return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())
+ return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name())
def step(self, closure=None):
"""
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 1dda7f1aad32..be8fe1a368c6 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine):
def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
- assert isinstance(self.module, PipelineModule) \
- or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \
- "model must base PipelineModule"
+ assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
assert self.zero_optimization_stage(
) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 7744b2ee8b98..2c01c3475a70 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group):
# This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'.
# Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group))
- scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float)
+ scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=group)
all_groups_norm = scaled_norm_tensor.item()
#print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
@@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No
# # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
# # for mask_idx in grad_norm_mask[idx]:
# # mask_tensor_[mask_idx[0]:mask_idx[1]] = True
- cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
+ cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(),
dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
- mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
+ mask_tensor = torch.zeros(p.shape[0] + 1,
+ device=get_accelerator().current_device_name(),
+ dtype=p.dtype)
mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 8fc962c4f2a7..bdec8a55fcbc 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -34,6 +34,7 @@ def get_all_parameters(sub_module, recurse=False):
return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
+@compiler.disable
def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
return map(lambda pair: pair[1], get_all_parameters(module, recurse))
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 68cab13c4a93..13ca29c9fceb 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -15,7 +15,7 @@
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -1409,7 +1409,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
norm_is_nan = total_norm.isnan()
inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
- err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+ err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
return total_norm
@@ -2027,7 +2027,7 @@ def step(self, closure=None):
return
norm_groups = self._get_norm_groups()
- scaled_global_grad_norm = torch.norm(torch.stack(norm_groups))
+ scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py
index d77228317ddb..5acb121668e3 100644
--- a/op_builder/hpu/fused_adam.py
+++ b/op_builder/hpu/fused_adam.py
@@ -4,10 +4,88 @@
# DeepSpeed Team
-from .builder import CPUOpBuilder
+try:
+ # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
+ # if successful this also means we're doing a local install and not JIT compile path
+ from op_builder import __deepspeed__ # noqa: F401 # type: ignore
+ from op_builder.builder import OpBuilder
+except ImportError:
+ from deepspeed.ops.op_builder.builder import OpBuilder
+try:
+ import torch
+ import math
+except ImportError as e:
+ pass
-class FusedAdamBuilder(CPUOpBuilder):
+
+class HPUFusedAdam:
+ htcore = None
+ is_lazy_mode = None
+
+ @staticmethod
+ def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode,
+ bias_correction, weight_decay, *args):
+
+ if HPUFusedAdam.htcore is None:
+ from habana_frameworks.torch import core as htcore
+ from habana_frameworks.torch.utils.internal import is_lazy
+ HPUFusedAdam.htcore = htcore
+ HPUFusedAdam.is_lazy_mode = is_lazy()
+
+ htcore = HPUFusedAdam.htcore
+
+ htcore.step_closure._mark_step_if_lazy()
+ step_size = lr
+ if bias_correction:
+ bias_correction1 = 1.0 - pow(beta1, step)
+ bias_correction2 = 1.0 - pow(beta2, step)
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
+
+ neg_step = -step_size
+ neg_step_t = (torch.tensor([neg_step], dtype=torch.float,
+ requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device,
+ non_blocking=True))
+
+ weight_decay = weight_decay if adam_w_mode else 0
+
+ # since lr is fed into the kernel as tensor, perform the scalar multiplication of wd here
+ # NOTE: TODO if lr is updated every step, then we need to convert it as tensor and
+ # perform weight decay unconditonally.
+ modified_wd = 1.0 - weight_decay * lr
+
+ if HPUFusedAdam.is_lazy_mode:
+ torch.ops.hpu.optimizer_adamw(
+ tensor_lists[0],
+ tensor_lists[1],
+ tensor_lists[2],
+ tensor_lists[3],
+ neg_step_t,
+ beta1,
+ beta2,
+ epsilon,
+ modified_wd,
+ )
+ else:
+ modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to(
+ tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True))
+ torch.ops.hpu.optimizer_adamw(
+ tensor_lists[0],
+ tensor_lists[1],
+ tensor_lists[2],
+ tensor_lists[3],
+ neg_step_t,
+ beta1,
+ beta2,
+ epsilon,
+ modified_wd_t,
+ modified_wd != 1.0,
+ )
+
+ htcore.step_closure._mark_step_if_lazy()
+
+
+class FusedAdamBuilder(OpBuilder):
BUILD_VAR = "DS_BUILD_FUSED_ADAM"
NAME = "fused_adam"
@@ -18,12 +96,10 @@ def absolute_name(self):
return f'deepspeed.ops.adam.{self.NAME}_op'
def sources(self):
- return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
-
- def cxx_args(self):
- args = super().cxx_args()
- args += ['-DENABLE_BFLOAT16']
- return args
+ return []
def include_paths(self):
- return ['csrc/includes']
+ return []
+
+ def load(self, verbose=True):
+ return HPUFusedAdam
diff --git a/setup.py b/setup.py
index f1367b850e02..839941b989c9 100755
--- a/setup.py
+++ b/setup.py
@@ -219,9 +219,9 @@ def create_dir_symlink(src, dest):
if sys.platform == "win32":
# This creates a symbolic links on Windows.
# It needs Administrator privilege to create symlinks on Windows.
- create_dir_symlink('..\\..\\csrc', '.\\deepspeed\\ops\\csrc')
- create_dir_symlink('..\\..\\op_builder', '.\\deepspeed\\ops\\op_builder')
- create_dir_symlink('..\\accelerator', '.\\deepspeed\\accelerator')
+ create_dir_symlink('.\\deepspeed\\ops\\csrc', '..\\..\\csrc')
+ create_dir_symlink('.\\deepspeed\\ops\\op_builder', '..\\..\\op_builder')
+ create_dir_symlink('.\\deepspeed\\accelerator', '..\\accelerator')
egg_info.manifest_maker.template = 'MANIFEST_win.in'
# Parse the DeepSpeed version string from version.txt.
diff --git a/tests/unit/linear/test_linear.py b/tests/unit/linear/test_linear.py
new file mode 100644
index 000000000000..ccd26b4cd726
--- /dev/null
+++ b/tests/unit/linear/test_linear.py
@@ -0,0 +1,128 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+import deepspeed.comm as dist
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig
+from unit.common import DistributedTest
+
+from deepspeed.ops.op_builder import FPQuantizerBuilder
+
+if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
+ pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)
+
+
+class TestBasicLinear(DistributedTest):
+ world_size = 2
+
+ def test(self):
+ lora_config = None
+ quantization_config = None
+
+ input_features = 64 # Number of input features
+ output_features = 64 # Number of output features
+ batch_size = 1 # Number of samples in a batch
+
+ linear_layer = OptimizedLinear(input_dim=input_features,
+ output_dim=output_features,
+ lora_config=lora_config,
+ quantization_config=quantization_config,
+ dtype=torch.bfloat16)
+
+ dummy_input = torch.rand(batch_size, input_features, dtype=torch.bfloat16)
+ output = linear_layer(dummy_input)
+ assert output.shape == (batch_size, output_features)
+
+
+@pytest.mark.parametrize("base_weight_sharding", [1, 2])
+class TestLoRALinear(DistributedTest):
+ world_size = 2
+
+ def test(self, base_weight_sharding):
+ rank = dist.get_rank()
+ lora_config = None
+ quantization_config = None
+
+ input_features = 64 # Number of input features
+ output_features = 64 # Number of output features
+ batch_size = 5 # Number of samples in a batch
+
+ lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding)
+
+ linear_layer = OptimizedLinear(input_dim=input_features,
+ output_dim=output_features,
+ lora_config=lora_config,
+ quantization_config=quantization_config,
+ dtype=torch.bfloat16)
+ device = get_accelerator().current_device_name()
+ linear_layer = linear_layer.to(device)
+ if rank == 0:
+ for n, p in linear_layer.named_parameters():
+ print(f"{n}, {p.shape}")
+
+ dummy_input = torch.rand(batch_size, input_features, device=device, dtype=torch.bfloat16)
+
+ output = linear_layer(dummy_input)
+ assert output.shape == (batch_size, output_features)
+
+
+@pytest.mark.parametrize("q_bits", [8, 6])
+class TestQuantLinear(DistributedTest):
+ world_size = 2
+
+ def test(self, q_bits):
+ rank = dist.get_rank()
+ lora_config = None
+
+ input_features = 64 # Number of input features
+ output_features = 64 # Number of output features
+ batch_size = 5 # Number of samples in a batch
+
+ lora_config = None
+ quantization_config = QuantizationConfig(q_bits=q_bits)
+
+ linear_layer = OptimizedLinear(input_dim=input_features,
+ output_dim=output_features,
+ lora_config=lora_config,
+ quantization_config=quantization_config,
+ dtype=torch.bfloat16)
+ device = get_accelerator().current_device_name()
+ linear_layer = linear_layer.to(device)
+ dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16)
+
+ output = linear_layer(dummy_input)
+ assert output.shape == (batch_size, output_features)
+
+
+@pytest.mark.parametrize("base_weight_sharding", [1, 2], ids=['bws1', 'bws2'])
+@pytest.mark.parametrize("q_bits", [8, 6], ids=['qbit8', 'qbit6'])
+class TestOptimizedLinear(DistributedTest):
+ world_size = 2
+
+ def test(self, base_weight_sharding, q_bits):
+ rank = dist.get_rank()
+ lora_config = None
+
+ input_features = 64 # Number of input features
+ output_features = 64 # Number of output features
+ batch_size = 5 # Number of samples in a batch
+
+ lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding)
+ quantization_config = QuantizationConfig(q_bits=q_bits)
+
+ linear_layer = OptimizedLinear(input_dim=input_features,
+ output_dim=output_features,
+ lora_config=lora_config,
+ quantization_config=quantization_config,
+ dtype=torch.bfloat16)
+ device = get_accelerator().current_device_name()
+ linear_layer = linear_layer.to(device)
+ dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16)
+ output = linear_layer(dummy_input)
+ assert output.shape == (batch_size, output_features)
diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py
new file mode 100644
index 000000000000..9479b3cba8a0
--- /dev/null
+++ b/tests/unit/linear/test_quant_param.py
@@ -0,0 +1,58 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.linear.quantization import QuantizedParameter
+from deepspeed.linear.config import QuantizationConfig
+
+from deepspeed.ops.op_builder import FPQuantizerBuilder
+
+from unit.common import DistributedTest
+
+if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
+ pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)
+
+
+class TestQuantParam(DistributedTest):
+ world_size = 1
+
+ @pytest.mark.parametrize('dtype', [torch.half, torch.float])
+ def test_unsupported_dtypes(self, dtype):
+ device = get_accelerator().current_device_name()
+ data = torch.rand(5, 5, device='cpu', dtype=dtype)
+ qp = QuantizedParameter(data)
+ with pytest.raises(AssertionError):
+ qp.to(device)
+
+ def test_requires_grad(self):
+ data = torch.rand(5, 5, dtype=torch.bfloat16)
+ with pytest.raises(ValueError):
+ QuantizedParameter(data, requires_grad=True)
+
+ def test_move_to_accelerator(self):
+ device = get_accelerator().current_device()
+ data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
+ qp = QuantizedParameter(data)
+ assert qp.device == torch.device('cpu')
+ qp = qp.to(get_accelerator().current_device_name())
+ assert qp.device == torch.device(device)
+ assert qp.dtype == torch.int8
+
+ def test_hf_clone(self):
+ device = get_accelerator().current_device_name()
+ data = torch.rand(5, 5, device=device, dtype=torch.bfloat16)
+
+ quantization_config = QuantizationConfig(q_bits=6)
+ qp = QuantizedParameter(data, quantization_config=quantization_config)
+
+ # should be able to clone parameter via dict, HF expects this to work
+ qp_copy = QuantizedParameter(qp.data, **qp.__dict__)
+
+ assert all(qp.data == qp_copy.data)
+ assert qp.quantization_config == qp_copy.quantization_config
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index d39f9fe3d651..fdff9430a4e6 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -177,7 +177,7 @@ class TestTopk(DistributedTest):
world_size = 2
def test(self):
- device = get_accelerator().current_device()
+ device = get_accelerator().current_device_name()
if dist.get_rank() == 0:
logits = torch.rand(2, 2, device=device)
elif dist.get_rank() == 1:
diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py
index 101f4cd69811..bed8bd7e3bcc 100644
--- a/tests/unit/ops/fp_quantizer/test_fp_quant.py
+++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py
@@ -61,6 +61,35 @@ def test_fp_quant_meta(dtype):
assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"
+@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
+def test_fp_quant_selective(dtype):
+ group_size = 128
+ q_bits = 8
+ exp_bits = 4
+ man_bits = 3
+
+ fpq = FP_Quantize(group_size=group_size)
+ indexes = torch.zeros(2, dtype=torch.int32, device='cuda')
+ indexes[0] = 1
+ indexes[1] = 3
+ for i in range(10):
+ x = torch.rand(4, 1024, dtype=dtype, device='cuda')
+
+ x = x.reshape(4, 1, x.shape[-1])
+ ds_x = x.clone()
+ x_quantized = fpq.quantize(ds_x, q_bits=q_bits)
+ x_dequantized = fpq.selective_dequantize(x_quantized, indexes, q_bits=q_bits)
+
+ qtorch_out = qtorch_quantize(x.index_select(0, indexes),
+ exp_bits=exp_bits,
+ man_bits=man_bits,
+ group_size=group_size)
+ qtorch_error = (qtorch_out - x.index_select(0, indexes)).abs().sum() / x.numel()
+ ds_error = (x_dequantized - x.index_select(0, indexes)).abs().sum() / x.numel()
+
+ assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"
+
+
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"])
def test_fp_quant(dtype, q_bits):
diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py
index d1830534f6ea..62af25ac3ba4 100644
--- a/tests/unit/runtime/compile/test_compile_wrapper.py
+++ b/tests/unit/runtime/compile/test_compile_wrapper.py
@@ -31,11 +31,9 @@ def base_config():
},
"compile": {
"enabled": True,
- "backend": "inductor"
+ "backend": get_accelerator().get_compile_backend()
}
}
- if get_accelerator().device_name() == 'hpu':
- config_dict['compile']['backend'] = 'hpu_backend'
return config_dict
diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py
index 7568c27e3ed2..a0736b0f5425 100644
--- a/tests/unit/runtime/compile/test_compile_zero.py
+++ b/tests/unit/runtime/compile/test_compile_zero.py
@@ -12,7 +12,7 @@
from unit.runtime.compile.util import compare_loss
from unit.common import DistributedTest
-from unit.util import bf16_required_version_check
+from unit.util import bf16_required_version_check, skip_on_arch
pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1),
reason="Compile tests requires Pytorch version 2.1 or above")
@@ -26,9 +26,11 @@ class TestZeRO(DistributedTest):
@pytest.mark.parametrize('zero_stage', [1, 2, 3])
@pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
+ if dtype == torch.bfloat16:
+ skip_on_arch(min_arch=8)
if dtype == torch.bfloat16 and not bf16_required_version_check():
pytest.skip(
- " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
+ "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly"
)
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU does not support this test yet")
@@ -51,12 +53,10 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
},
"compile": {
"enabled": True,
- "backend": "inductor"
+ "backend": get_accelerator().get_compile_backend()
}
}
- if get_accelerator().device_name() == 'hpu':
- config_dict['compile']['backend'] = 'hpu_backend'
if offload_device == OffloadDeviceEnum.cpu:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
elif offload_device == OffloadDeviceEnum.nvme:
diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py
index 601adae58884..cee8d3b23f6b 100644
--- a/tests/unit/runtime/compile/test_load_config.py
+++ b/tests/unit/runtime/compile/test_load_config.py
@@ -47,12 +47,10 @@ def base_config():
},
"compile": {
"enabled": True,
- "backend": "inductor"
+ "backend": get_accelerator().get_compile_backend()
}
}
- if get_accelerator().device_name() == 'hpu':
- config_dict['compile']['backend'] = 'hpu_backend'
return config_dict
diff --git a/version.txt b/version.txt
index e867cc2a66a8..ac4a79626c87 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.14.2
+0.14.3