Merge branch 'master' into features/fp6_compile_err
mrwyattii authored Apr 15, 2024
2 parents 88bd992 + 54c0687 commit 9870312
Showing 13 changed files with 186 additions and 134 deletions.
168 changes: 68 additions & 100 deletions csrc/fp_quantizer/quantize.cu
@@ -219,119 +219,100 @@ __global__ void apply_quantization(T* val,
 }
 
 template <typename T,
-          int unroll,
           int q_mantisa_bits,
           int total_q_bits = 16,
           int _mantisa_bits = 3,
           int _exponent_bits = 4>
-__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
+__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
 {
-    int tidx = threadIdx.x;
-    int wid = tidx >> 5;
-    int lane = tidx & 0x1f;
-    int gid = blockIdx.x * quantization::warps + wid;
+    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
+    int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
 
     constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
     constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
     constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
     constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
     constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
 
-    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
-    constexpr uint32_t load_stride = vector_size * hw_warp_size;
-    const uint32_t thread_offset = lane * vector_size;
-    const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
-    const uint32_t base_load_offset =
-        gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset;  // 4-byte scale offset
-    const uint32_t base_store_offset = gid * group_size + thread_offset;
-    const uint8_t* load_base_ptr = val + base_load_offset;
+    const uint32_t g_index = (tidx / group_size);
+    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
+    const uint8_t* load_base_ptr =
+        val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;
 
     int mantisa_mask = ((1 << q_mantisa_bits) - 1);
     mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
 
-    T* store_base_ptr = q_val + base_store_offset;
-    float scale;  //= q_scale[gid];
+    T* store_base_ptr = q_val + tidx;
+    float scale;
 
     uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
     if (quantized_bits == 6) {
         mem_access::load_global<quantization::quanitzed_access_granularity>(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
         mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
             scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
+            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                 quantization::quanitzed_access_granularity_6bits);
     } else
         mem_access::load_global<quantization::quanitzed_access_granularity>(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
-
-#pragma unroll
-    for (int i = 0; i < unroll; i++) {
-        if (i * load_stride + thread_offset < group_size) {
-            uint64_t q_buf_in;
-            uint64_t q_buf_in1;
-            uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
-            uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
-            uint32_t loading_offset = i * load_stride * quantized_bits / 8;
-            if (quantized_bits == 6) {
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data, load_base_ptr + loading_offset);
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data + quantization::quanitzed_access_granularity_6bits,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits);
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data + quantization::quanitzed_access_granularity_6bits * 2,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits * 2);
-            } else {
-                mem_access::load_global<quantization::quanitzed_access_granularity>(
-                    int8_data, load_base_ptr + loading_offset);
-                if (quantized_bits > 4) {
-                    mem_access::load_global<quantization::quanitzed_access_granularity>(
-                        int8_data + quantization::quanitzed_access_granularity,
-                        load_base_ptr + loading_offset +
-                            quantization::quanitzed_access_granularity);
-                    if (quantized_bits == 12) {
-                        mem_access::load_global<quantization::quanitzed_access_granularity>(
-                            int8_data1,
-                            load_base_ptr + loading_offset +
-                                quantization::quanitzed_access_granularity * 2);
-                    }
-                }
-            }
-            T store_buf[vector_size];
-            uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
+
+    if (tidx < total_num_elements) {
+        uint64_t q_buf_in;
+        uint64_t q_buf_in1;
+        uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
+        uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
+        if (quantized_bits == 6) {
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data, load_base_ptr);
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data + quantization::quanitzed_access_granularity_6bits,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
+        } else {
+            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
+                                                                                load_base_ptr);
+            if (quantized_bits > 4) {
+                mem_access::load_global<quantization::quanitzed_access_granularity>(
+                    int8_data + quantization::quanitzed_access_granularity,
+                    load_base_ptr + quantization::quanitzed_access_granularity);
+                if (quantized_bits == 12) {
+                    mem_access::load_global<quantization::quanitzed_access_granularity>(
+                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
+                }
+            }
+        }
+        T store_buf[vector_size];
+        uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
 #pragma unroll
-            for (int j = 0; j < vector_size; j++) {
-                uint16_t new_data;
-                if (j < 5 || quantized_bits != 12) {
-                    new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
-                } else {
-                    if (j == 5) {
-                        new_data = (uint16_t)(q_buf_in1);
-                        new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
-                    } else
-                        new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
-                }
-
-                uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
-                uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
-                uint16_t dst_mantisa = (new_data & _mantisa_mask);
-
-                if (dst_exponent != (1 << q_exponent_bits) - 1)
-                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
-                                   (1 << (q_exponent_bits - 1)) - 1;
-
-                q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
-                            (dst_exponent << q_mantisa_bits) |
-                            (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
-                float up_cast = conversion::to<float>(store_buf[j]);
-                store_buf[j] = conversion::to<T>(up_cast * scale);
-            }
-            mem_access::store_global<quantization::access_granularity>(
-                store_base_ptr + i * load_stride, store_buf);
-        }
+        for (int j = 0; j < vector_size; j++) {
+            uint16_t new_data;
+            if (j < 5 || quantized_bits != 12) {
+                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
+            } else {
+                if (j == 5) {
+                    new_data = (uint16_t)(q_buf_in1);
+                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
+                } else
+                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
+            }
+
+            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
+            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
+            uint16_t dst_mantisa = (new_data & _mantisa_mask);
+
+            if (dst_exponent != (1 << q_exponent_bits) - 1)
+                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
+                               (1 << (q_exponent_bits - 1)) - 1;
+
+            q_buf[j] =
+                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
+                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
+            float up_cast = conversion::to<float>(store_buf[j]);
+            store_buf[j] = conversion::to<T>(up_cast * scale);
        }
+        mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
     }
 }

@@ -386,12 +367,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
 #endif
 INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
 
-#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT)                                                  \
-    case COUNT:                                                                                  \
-        apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
-            <<<grid, block, 0, stream>>>(val, q_val, group_size);                                \
-        break;
-
 template <typename T, int mantisa>
 void launch_dequantization(uint8_t* val,
                            T* q_val,
@@ -401,21 +376,14 @@ void launch_dequantization(uint8_t* val,
                            int q_exponent_bits,
                            cudaStream_t stream)
 {
-    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
+    int blocks = ((num_groups * group_size) - 1) /
+                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
+                 1;
+    const dim3 grid(blocks);
     const dim3 block(quantization::threads);
 
-    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
-    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
-
     DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
-        switch (copy_unroll) {
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
-        }
+        apply_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
+            <<<grid, block, 0, stream>>>(val, q_val, group_size, (num_groups * group_size));
     });
 }
 #define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
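A quick sanity check of the new launch shape (a sketch, not part of the commit): each thread now dequantizes one vector of access_granularity / sizeof(T) elements, so the grid covers all num_groups * group_size elements instead of assigning one warp per group. The concrete values for quantization::threads, the access granularity, and the element size below are assumptions for illustration only.

# Illustrative only: mirrors the grid computation in launch_dequantization.
# threads, access_granularity and dtype_bytes are assumed values, not taken from the repo.
def dequant_grid_size(num_groups, group_size, threads=1024, access_granularity=16, dtype_bytes=2):
    vector_size = access_granularity // dtype_bytes      # elements handled per thread
    total_num_elements = num_groups * group_size
    # same arithmetic as: ((num_groups * group_size) - 1) / (threads * vector_size) + 1
    return (total_num_elements - 1) // (threads * vector_size) + 1

print(dequant_grid_size(num_groups=256, group_size=512))  # 16 blocks of 1024 threads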
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/__init__.py
@@ -15,6 +15,6 @@
 
 from .zero_checkpoint import ZeROCheckpoint
 
-from .universal_checkpoint import enable_universal_checkpoint
+from .universal_checkpoint import enable_universal_checkpoint, SubparamShape
 
 from .constants import *
4 changes: 4 additions & 0 deletions deepspeed/checkpoint/constants.py
@@ -74,10 +74,14 @@
 # Similarly, load_hp_checkpoint_state has to take the needed actions when loading from universal.
 PARAM_N_SUB_PARAMS = "param_n_sub_params"
 
+SUB_PARAM_SHAPE = "sub_param_shape"
+
 # Regex list of parameters that require special handling
 VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
 PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
 PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
 PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
 TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
 PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
+PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params'
+SUB_PARAMS_SHAPE = 'sub_params_shape'
38 changes: 38 additions & 0 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -6,6 +6,7 @@
 # DeepSpeed Team
 
 from functools import partial
+from itertools import chain
 import argparse
 import glob
 import itertools
@@ -28,6 +29,7 @@
     PARAM,
     CAT_DIM,
     PARAM_N_SUB_PARAMS,
+    SUB_PARAM_SHAPE,
     VOCAB_TENSOR,
     UNIVERSAL_CHECKPOINT_INFO,
     VOCABULARY_PARAMETER_PATTERNS,
@@ -36,6 +38,8 @@
     PARAMETER_TO_AVERAGE_PATTERNS,
     PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
     PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
+    PARAMETER_WITH_SUB_PARAMS,
+    SubparamShape,
 )
 
 
@@ -180,8 +184,11 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
     parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
     vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
     parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, [])
+    parameter_with_sub_params = universal_checkpoint_info.get(PARAMETER_WITH_SUB_PARAMS, [])
 
     unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
                              vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)
+    unmatched_patterns.update(chain.from_iterable(SubparamShape(**s).patterns for s in parameter_with_sub_params))
 
     def get_matched_pattern(patterns_, name_):
         matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
@@ -192,6 +199,17 @@ def get_matched_pattern(patterns_, name_):
             return pattern_
         return None
 
+    def get_matched_sub_params_pattern(name_):
+        for subparam_shape_dict in parameter_with_sub_params:
+            subparam_shape = SubparamShape(**subparam_shape_dict)
+            for pattern_ in subparam_shape.patterns:
+                if re.match(pattern_, name_):
+                    unmatched_patterns.discard(pattern_)
+                    return subparam_shape
+        return None
+
+    matched_sub_params_shape = get_matched_sub_params_pattern(name)
+
     step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
     if step_merged:
         _save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])
@@ -219,6 +237,26 @@ def get_matched_pattern(patterns_, name_):
         param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
         ckpt_dict[CAT_DIM] = cat_dim
         ckpt_dict[PARAM_N_SUB_PARAMS] = 2
+    elif matched_sub_params_shape:
+        merged_chunks = []
+        partition_dim = matched_sub_params_shape.partition_dim
+
+        sub_dim_sizes = matched_sub_params_shape.shape[partition_dim]
+        if not isinstance(sub_dim_sizes, tuple):
+            sub_dim_sizes = (sub_dim_sizes, )
+
+        partition_shape = [sum(d) if isinstance(d, tuple) else d for d in matched_sub_params_shape.shape]
+        partition_shape = [d // tp_degree if i == partition_dim else d for i, d in enumerate(partition_shape)]
+        slices = [s.view(partition_shape) for s in slices]
+
+        offset = 0
+        for sub_dim_size in sub_dim_sizes:
+            part_sub_dim_size = sub_dim_size // tp_degree
+            merged_chunks.append(
+                torch.cat([s.narrow(partition_dim, offset, part_sub_dim_size) for s in slices], dim=partition_dim))
+            offset += part_sub_dim_size
+        param = torch.cat(merged_chunks, dim=partition_dim)
+        ckpt_dict[SUB_PARAM_SHAPE] = matched_sub_params_shape
     else:
         cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
         # print(f"merge {name} with CAT DIM: {cat_dim}")
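For readers of the checkpoint changes, the snippet below is a hypothetical sketch (regex, sizes, and model names invented) of how a model might declare a fused parameter through the new PARAMETER_WITH_SUB_PARAMS key, and how the resulting SubparamShape is constructed; it is not taken from the commit.

# Hypothetical config: a fused QKV-style weight whose dim 0 packs three sub-params.
# The regex and dimensions are assumptions for illustration.
from deepspeed.checkpoint import SubparamShape

universal_checkpoint_info = {
    'parameter_with_sub_params': [{
        'patterns': [r'.*attention\.query_key_value\.weight'],
        'shape': ((4096, 1024, 1024), 4096),  # dim 0 = (q, k, v) sizes, dim 1 = hidden
        'partition_dim': 0,
    }],
}

subparam_shape = SubparamShape(**universal_checkpoint_info['parameter_with_sub_params'][0])
print(subparam_shape.patterns, subparam_shape.partition_dim)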
33 changes: 31 additions & 2 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -7,7 +7,16 @@
 import re
 import torch
 import types
-from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)
+from typing import List, Tuple, Union
+from dataclasses import dataclass
+from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)
+
+
+@dataclass
+class SubparamShape:
+    patterns: List[str]
+    shape: Tuple[Union[Tuple[int], int]]
+    partition_dim: int
 
 
 def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -76,12 +85,32 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
     # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
 
+    sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
     # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
     # special case is when a single parameter is effectively a container for multiple sub parameters
     # (more details at PARAM_N_SUB_PARAMS definition)
     chunk_dim = ckpt_dict.get(CAT_DIM, 0)
     n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
-    if n_sub_params > 1:
+    if sub_param_shape:
+        partition_dim = sub_param_shape.partition_dim
+        sub_dim_sizes = sub_param_shape.shape[partition_dim]
+        if not isinstance(sub_dim_sizes, tuple):
+            sub_dim_sizes = (sub_dim_sizes, )
+
+        partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
+        full_hp_param = full_hp_param.view(partition_shape)
+
+        offset = 0
+        merged_chunks = []
+        for sub_dim_size in sub_dim_sizes:
+            sub_params_tp_slice = full_hp_param.narrow(partition_dim,
+                                                       offset, sub_dim_size).chunk(tp_world_size,
                                                                                    dim=partition_dim)[tp_rank]
+            merged_chunks.append(sub_params_tp_slice)
+            offset += sub_dim_size
+        tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)
+
+    elif n_sub_params > 1:
         sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
         sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
         tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
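On the load path, the new branch slices each packed sub-dimension for the local tensor-parallel rank before re-concatenating. A tiny self-contained sketch of that narrow-plus-chunk pattern follows; all shapes are made up and only meant to show the bookkeeping.

# Toy example of the narrow + chunk slicing used above; sizes are arbitrary.
import torch

full_hp_param = torch.arange(6 * 3).reshape(6, 3).float()  # dim 0 packs sub-params of size 4 and 2
sub_dim_sizes, partition_dim = (4, 2), 0
tp_world_size, tp_rank = 2, 0

offset, merged_chunks = 0, []
for sub_dim_size in sub_dim_sizes:
    tp_slice = full_hp_param.narrow(partition_dim, offset,
                                    sub_dim_size).chunk(tp_world_size, dim=partition_dim)[tp_rank]
    merged_chunks.append(tp_slice)
    offset += sub_dim_size

tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)
print(tp_hp_slice.shape)  # torch.Size([3, 3]): 2 rows from the first sub-param, 1 from the second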
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -15,7 +15,7 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -2027,7 +2027,7 @@ def step(self, closure=None):
             return
 
         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+        scaled_global_grad_norm = torch.norm(torch.stack(norm_groups))
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
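A small numerical check on the stage3 change (illustrative only, and assuming get_global_norm returns the square root of the sum of squared norms): the replacement computes the same scalar, the L2 norm over the per-group norms, but yields a tensor rather than a Python float.

# Illustrative check; the norm values are made up.
import math
import torch

norm_groups = [torch.tensor(3.0), torch.tensor(4.0)]

old_value = math.sqrt(sum(float(n)**2 for n in norm_groups))  # assumed get_global_norm behaviour
new_value = torch.norm(torch.stack(norm_groups))              # expression used in this commit

print(old_value, new_value.item())  # 5.0 5.0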