diff --git a/CMakeLists.txt b/CMakeLists.txt index ac94aac57..1fc3e2932 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -198,6 +198,7 @@ set(SOURCES src/ops/awq/gemv.cc src/ops/awq/gemv_cpu.cc src/ops/sum.cc + src/ops/int4gemm_cpu.cc src/padder.cc src/profiler.cc src/random.cc @@ -324,7 +325,7 @@ if(WITH_MKL) endif() # Find MKL libraries. - find_library(MKL_CORE_LIBRARY NAMES mkl_core HINTS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64) + find_library(MKL_CORE_LIBRARY NAMES mkl_core PATHS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64) if(MKL_CORE_LIBRARY) get_filename_component(MKL_LIBRARY_DIR ${MKL_CORE_LIBRARY} DIRECTORY) message(STATUS "Found MKL library directory: ${MKL_LIBRARY_DIR}") @@ -605,6 +606,7 @@ if (WITH_CUDA) src/ops/awq/gemm_gpu.cu src/ops/awq/gemv_gpu.cu src/ops/awq/dequantize_gpu.cu + src/ops/int4gemm_gpu.cu ) set_source_files_properties( diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index 3985b3feb..eb1f09db9 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -62,6 +62,7 @@ namespace ctranslate2 { const StorageView& _embeddings; const DataType _output_type; const StorageView* _qscale; + const StorageView* _qzero; }; // This enum order should remain fixed. diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 32e4f8403..f2ab02db2 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -14,7 +14,8 @@ namespace ctranslate2 { enum class QUANTIZATION_TYPE { CT2, AWQ_GEMM, - AWQ_GEMV + AWQ_GEMV, + HQQ_4BIT, }; static const size_t current_binary_version = 6; @@ -113,11 +114,13 @@ namespace ctranslate2 { } // If the model contains variables, they will be moved to the new device. - void set_device(const Device device, const int index = 0); + void set_device(const Device device, const int index = 0, const bool format_lower_bit = false); // Copy the model to another device. std::shared_ptr copy_to(Device device, int device_index = 0) const; + template + T get_config_if_exists(const std::string& name) const; const StorageView* get_variable_if_exists(const std::string& name) const; const StorageView& get_variable(const std::string& name) const; std::unordered_map get_variables() const; diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index 3c4efbb02..41cb78fe2 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -18,7 +18,8 @@ namespace ctranslate2 { bool trans_b = false, bool a_is_packed = false, bool b_is_packed = false, - const ActivationType* activation_type = nullptr); + const ActivationType* activation_type = nullptr, + const int _group_size = 0); void operator()(const StorageView& a, const StorageView& b, @@ -26,6 +27,15 @@ namespace ctranslate2 { const StorageView* a_shift_compensation = nullptr, const StorageView* bias = nullptr) const; + void operator()(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c, + const StorageView* bias = nullptr) const; + + StorageView convert_to_int4pack(const StorageView& input, + int32_t innerKTiles); + // Return the packed representation of b, if implemented by the GEMM backend. 
static StorageView pack_b_input(const StorageView& b, const bool transpose, @@ -49,13 +59,26 @@ namespace ctranslate2 { bool _trans_b; bool _a_is_packed; bool _b_is_packed; + const int _group_size; template void compute(const StorageView& a, const StorageView& b, StorageView& c, const StorageView* a_shift_compensation) const; - }; +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + template + void compute(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c) const; + + template + void convert_weight_to_int4pack(const StorageView& a, + StorageView& b, + int32_t innerKTiles); +#endif + }; } } diff --git a/include/ctranslate2/types.h b/include/ctranslate2/types.h index a6dc0e1fb..8fa47b7ff 100644 --- a/include/ctranslate2/types.h +++ b/include/ctranslate2/types.h @@ -36,6 +36,7 @@ namespace ctranslate2 { INT16, FLOAT16, BFLOAT16, + INT32_BFLOAT16, }; ComputeType str_to_compute_type(const std::string& compute_type); diff --git a/python/ctranslate2/converters/converter.py b/python/ctranslate2/converters/converter.py index ecede044a..f8ea1671b 100644 --- a/python/ctranslate2/converters/converter.py +++ b/python/ctranslate2/converters/converter.py @@ -5,7 +5,7 @@ from typing import Optional -from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, ModelSpec +from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, DEVICE, ModelSpec class Converter(abc.ABC): @@ -30,6 +30,18 @@ def declare_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParse choices=ACCEPTED_MODEL_TYPES, help="Weight quantization type.", ) + parser.add_argument( + "--group_size", + type=int, + default=None, + help="Group size used in quantization lower bit", + ) + parser.add_argument( + "--device", + default=None, + choices=DEVICE, + help="Device where the quantization runs on. Available only with hqq quantization", + ) parser.add_argument( "--force", action="store_true", @@ -51,6 +63,8 @@ def convert_from_args(self, args: argparse.Namespace) -> str: args.output_dir, vmap=args.vocab_mapping, quantization=args.quantization, + group_size=args.group_size, + device=args.device, force=args.force, ) @@ -59,6 +73,8 @@ def convert( output_dir: str, vmap: Optional[str] = None, quantization: Optional[str] = None, + group_size: Optional[int] = None, + device: Optional[str] = None, force: bool = False, ) -> str: """Converts the model to the CTranslate2 format. @@ -69,6 +85,10 @@ def convert( in the converted model directory. quantization: Weight quantization scheme (possible values are: int8, int8_float32, int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + group_size: group size used by the quantization in lower bit (possible values + are: 32, 64, 128...) + device: Device where the compute of the scales and zero in quantization lower bit + runs (possible values are: cuda, cpu) force: Override the output directory if it already exists. Returns: @@ -95,7 +115,7 @@ def convert( model_spec.register_vocabulary_mapping(vmap) model_spec.validate() - model_spec.optimize(quantization=quantization) + model_spec.optimize(quantization=quantization, group_size=group_size, device=device) # Create model directory. 
if os.path.exists(output_dir): diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index b1162839c..d16ed88eb 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -29,6 +29,7 @@ class Quantization(enum.IntEnum): CT2 = 0 AWQ_GEMM = 1 AWQ_GEMV = 2 + HQQ_INT4 = 3 class LayerNormSpec(model_spec.LayerSpec): @@ -62,4 +63,5 @@ class EmbeddingsSpec(model_spec.LayerSpec): def __init__(self): self.weight = None self.weight_scale = model_spec.OPTIONAL + self.weight_zero = model_spec.OPTIONAL self.multiply_by_sqrt_depth = model_spec.OPTIONAL diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 41710ff41..5a086191a 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -11,6 +11,11 @@ import struct from typing import Dict, List, Optional +try: + from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer + hqq_is_available = True +except ImportError: + hqq_is_available = False import numpy as np @@ -21,10 +26,15 @@ except ImportError: torch_is_available = False +from ctranslate2.specs import ( + common_spec, +) + OPTIONAL = "__optional" CURRENT_BINARY_VERSION = 6 ACCEPTED_MODEL_TYPES = ( + "hqq_int4", "int8", "int8_float32", "int8_float16", @@ -37,6 +47,10 @@ SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor") +DEVICE = ( + "cuda", + "cpu", +) def _join_scope(scope, name): if not scope: @@ -188,7 +202,50 @@ def _alias_variables(self): setattr(spec, attr_name, other_name) break - def _quantize(self, quantization): + def _hqq_quants_to_torch_quants(self, w_q, scales, zeros, shape, nbits=4): + max_int = 2**nbits - 1 + min_int = 0 + dump = 2 ** (nbits - 1) + + # HQQ -> torch logic + new_zeros = (scales * dump) - zeros * scales + + min_val = new_zeros - scales * dump + + # group_quantize_tensor_from_qparams + w_r = (w_q - zeros) * scales + + w_q = ( + w_r.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape(shape) + .contiguous() + ) + + # group_dequantize_tensor_from_qparams + # W_r = W_q*scales + min_val + + scales = scales.contiguous().reshape(shape[0], -1) + new_zeros = new_zeros.contiguous().reshape(shape[0], -1) + scale_and_zero = ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + new_zeros.reshape(new_zeros.size(0), new_zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + + return w_q, scale_and_zero + + def _quantize(self, quantization, group_size, device): """Possibly quantizes the variable of the layer.""" if quantization is not None and quantization not in ACCEPTED_MODEL_TYPES: raise ValueError( @@ -202,6 +259,7 @@ def _quantize(spec, name, value): key = _split_scope(name)[-1] scale = None + zero = None is_quantizable = hasattr(spec, "%s_scale" % key) is_convertible = value.dtype in ("float32", "float16", "bfloat16") @@ -244,11 +302,37 @@ def _quantize(spec, name, value): value = NumpyVariable(value) elif quantization in ("float16", "bfloat16", "float32"): value = value.to(quantization) + elif quantization in ( + "hqq_int4", + ) and value.shape != 3 and hqq_is_available: + if 'embeddings' in name: + value = value.to("bfloat16") + else: + quant_config = BaseQuantizeConfig(nbits=4, group_size=group_size, + quant_zero=False, quant_scale=False, axis=1) + hqq_linear = HQQLinear(None, quant_config=quant_config, + compute_dtype=torch.bfloat16, device=device) + 
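+                        # HQQ quantizes the bf16 weight below; W_q and its (scale, zero) metadata are then
+                        # unpacked and converted by _hqq_quants_to_torch_quants into the int4 layout and the
+                        # fused scale-and-zero tensor consumed by the int4 GEMM kernel.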
hqq_linear.quantize(value.to("bfloat16").tensor, **hqq_linear.quant_config) + + value = hqq_linear.W_q.cpu() + scale = hqq_linear.meta['scale'].cpu() + zero = hqq_linear.meta['zero'].cpu() + old_shape = hqq_linear.meta['shape'] + value = Quantizer.unpack[hqq_linear.meta["packing"]](value) + value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape) + + scale = scale.cpu() + value = value.cpu() + scale = PyTorchVariable(scale) + value = PyTorchVariable(value) + del hqq_linear.W_q + del hqq_linear.meta['scale'] + del hqq_linear.meta['zero'] elif is_convertible: if quantization in ("float16", "int8_float16"): value = value.to("float16") - elif quantization in ("bfloat16", "int8_bfloat16"): + elif quantization in ("bfloat16", "int8_bfloat16", "hqq_int4"): value = value.to("bfloat16") elif quantization in ("float32", "int16", "int8_float32"): value = value.to("float32") @@ -259,7 +343,9 @@ def _quantize(spec, name, value): self._visit(_quantize) - def optimize(self, quantization: Optional[str] = None) -> None: + def optimize(self, quantization: Optional[str] = None, + group_size: Optional[int] = None, + device: Optional[str] = None) -> None: """Recursively applies some optimizations to this layer: * Alias variables with the same shape and value. @@ -268,9 +354,11 @@ def optimize(self, quantization: Optional[str] = None) -> None: Arguments: quantization: Weight quantization scheme (possible values are: int8, int8_float32, int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + group_size: Group size of quantization lower bits + device: device where the quantization runs on (support hqq only) """ self._alias_variables() - self._quantize(quantization) + self._quantize(quantization, group_size, device) def _visit(self, fn): """Recursively visits this layer and its children.""" @@ -414,6 +502,34 @@ def _write_string(string): _write_string(variable_name) + def _save_quantization_type(self, quantization, group_size): + if quantization == 'hqq_int4': + self._config.add_attribute("quantization_type", common_spec.Quantization.HQQ_INT4) + elif quantization is not None: + self._config.add_attribute("quantization_type", common_spec.Quantization.CT2) + + if group_size is not None: + self._config.add_attribute("quantization_group_size", group_size) + + def optimize(self, quantization: Optional[str] = None, + group_size: Optional[int] = None, + device: Optional[str] = None) -> None: + """Recursively applies some optimizations to its layer: + + * Alias variables with the same shape and value. + * Quantize weights. + + Arguments: + quantization: Weight quantization scheme (possible values are: int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). 
+ group_size: Group size of quantization lower bits + device: device where the quantization runs on (support hqq only) + """ + self._save_quantization_type(quantization, group_size) + self._alias_variables() + self._quantize(quantization, group_size, device) + + def _flatten_vocabularies(vocabularies): for name, vocabulary in vocabularies.items(): if len(vocabulary) == 1: diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index abb812c8b..954819cfa 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -233,7 +233,9 @@ def __init__( if quant_type is not None: self._config["quantization_type"] = quant_type + if quant_bits is not None: self._config["quantization_bits"] = quant_bits + if quant_group_size is not None: self._config["quantization_group_size"] = quant_group_size @property diff --git a/src/layers/common.cc b/src/layers/common.cc index 86fb66a7d..bf1c7b226 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -290,7 +290,8 @@ namespace ctranslate2 { /*trans_b=*/true, /*a_is_packed=*/false, _packed_weight, - _quantized_gemm ? nullptr : activation_type) + _quantized_gemm ? nullptr : activation_type, + (model.get_config_if_exists("quantization_group_size"))) , _quantize_op(model.use_global_int16_scale() ? ops::Quantize::ScaleType::GLOBAL : ops::Quantize::ScaleType::PER_LAYER, @@ -427,6 +428,8 @@ namespace ctranslate2 { throw std::invalid_argument("Dense forward: invalid quantized type," "support only ct2 and awq quantization"); } + } else if (_quant_method == models::QUANTIZATION_TYPE::HQQ_4BIT) { + _gemm_op(input, *weight, *qscale, output, bias); } else { _gemm_op(input, *weight, output, nullptr, bias); } diff --git a/src/models/model.cc b/src/models/model.cc index b8e1c2d8f..c84f085e3 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -50,6 +50,15 @@ namespace ctranslate2 { + std::to_string(position)); } + static void format_weight(const std::string name, StorageView& weight, int device_index) { + if (!ends_with(name, "embeddings/weight") && ends_with(name, "weight")) { + ops::Gemm gemm_ops; + StorageView tmp = gemm_ops.convert_to_int4pack(weight, 8); + weight = std::move(tmp); + } + synchronize_device(weight.device(), device_index); + } + template T consume(std::istream& in) { const std::streampos position = in.tellg(); @@ -87,19 +96,23 @@ namespace ctranslate2 { } template - static void move_variables_to_device(VariablesCollection& variables, const Device device) { + static void move_variables_to_device(VariablesCollection& variables, const Device device, + bool format_lower_bit, const int device_index) { for (auto& pair : variables) { StorageView& variable = *pair.second; if (variable.is_scalar() || variable.device() == device) continue; variable = variable.to(device); + if (format_lower_bit) + format_weight(pair.first, variable, device_index); } } template static void move_variables(VariablesCollection& variables, const Device src_device, const int src_device_index, - const Device dst_device, const int dst_device_index) { + const Device dst_device, const int dst_device_index, + const bool format_lower_bit) { if (variables.empty()) return; if (src_device == dst_device && src_device_index == dst_device_index) @@ -108,13 +121,13 @@ namespace ctranslate2 { // Move variables back to the CPU device. 
if (src_device != Device::CPU && dst_device == Device::CPU) { ScopedDeviceSetter scoped_device_setter(src_device, src_device_index); - move_variables_to_device(variables, Device::CPU); + move_variables_to_device(variables, Device::CPU, format_lower_bit, 0); } // Move variables to the destination device. if (src_device == Device::CPU && dst_device != Device::CPU) { ScopedDeviceSetter scoped_device_setter(dst_device, dst_device_index); - move_variables_to_device(variables, dst_device); + move_variables_to_device(variables, dst_device, format_lower_bit, dst_device_index); } synchronize_device(src_device, src_device_index); // Wait for asynchronous deallocations. @@ -167,8 +180,8 @@ namespace ctranslate2 { return 1; } - void Model::set_device(const Device device, const int index) { - move_variables(_variable_index, _device, _device_index, device, index); + void Model::set_device(const Device device, const int index, const bool format_lower_bit) { + move_variables(_variable_index, _device, _device_index, device, index, format_lower_bit); _device = device; _device_index = index; } @@ -231,6 +244,15 @@ namespace ctranslate2 { } } + template + T Model::get_config_if_exists(const std::string &name) const { + T value = 0; + if (config.contains(name)) { + value = config[name]; + } + return value; + } + const StorageView* Model::get_variable_if_exists(const std::string& name) const { auto it = _variable_index.find(name); if (it == _variable_index.end()) @@ -640,8 +662,9 @@ namespace ctranslate2 { QUANTIZATION_TYPE quantization_type = QUANTIZATION_TYPE::CT2; if (model->config.contains("quantization_type")) - model->set_quant_method(model->config["quantization_type"]); + quantization_type =model->config["quantization_type"]; + model->set_quant_method(quantization_type); for (uint32_t i = 0; i < num_variables; ++i) { auto name = consume(model_file); const size_t rank = consume(model_file); @@ -759,13 +782,16 @@ namespace ctranslate2 { case QUANTIZATION_TYPE::AWQ_GEMV: model->set_compute_type(ComputeType::FLOAT16, device, device_index, false); break; + case QUANTIZATION_TYPE::HQQ_4BIT: + model->set_compute_type(compute_type, device, device_index, false); + break; default: throw std::invalid_argument("Quantization type is not supported"); break; } // Move variables to the target device. - model->set_device(device, device_index); + model->set_device(device, device_index, quantization_type == QUANTIZATION_TYPE::HQQ_4BIT); // Register variable aliases. 
if (binary_version >= 3) { @@ -895,6 +921,10 @@ namespace ctranslate2 { return models; } +#define DECLARE_IMPL(T) \ + template T \ + Model::get_config_if_exists(const std::string& name) const; + DECLARE_IMPL(int) } } diff --git a/src/ops/dequantize.cc b/src/ops/dequantize.cc index 4463cf948..96d50f899 100644 --- a/src/ops/dequantize.cc +++ b/src/ops/dequantize.cc @@ -1,4 +1,4 @@ -#include "ctranslate2/ops/dequantize.h" + #include "ctranslate2/ops/dequantize.h" #include "dispatch.h" diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu index 241b3acdb..ef3adc11d 100644 --- a/src/ops/dequantize_gpu.cu +++ b/src/ops/dequantize_gpu.cu @@ -7,6 +7,10 @@ namespace ctranslate2 { template struct dequantize_func { + __device__ __forceinline__ + OutT operator()(float scale, InT x, float zero) const { + return __fdividef(__fsub_rn(static_cast(x), zero) , scale); + } __device__ __forceinline__ OutT operator()(float scale, InT x) const { return __fdividef(static_cast(x), scale); @@ -26,7 +30,6 @@ namespace ctranslate2 { cuda::repeat_vec_depth(depth)); } - template __global__ void dequantize_gemm_output_kernel(const int32_t* c, const float* a_scales, diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index e6ff87f9d..465d3af4a 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -25,7 +25,8 @@ namespace ctranslate2 { bool trans_b, bool a_is_packed, bool b_is_packed, - const ActivationType* activation_type) + const ActivationType* activation_type, + const int group_size) : _alpha(alpha) , _beta(beta) , _trans_a(trans_a) @@ -33,6 +34,7 @@ namespace ctranslate2 { , _a_is_packed(a_is_packed) , _b_is_packed(b_is_packed) , _activation_type(activation_type) + , _group_size(group_size) { } @@ -69,6 +71,24 @@ namespace ctranslate2 { apply_bias_and_activation(c, bias, _activation_type); } + void Gemm::operator()(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c, + const StorageView* bias) const { + PROFILE("Gemm"); +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + dim_t batch_size = a.dim(0); + dim_t time = a.dim(1); + DEVICE_DISPATCH(a.device(), (compute(a, b, scaleAndZero, c))); + c.reshape({batch_size, time, -1}); + + apply_bias_and_activation(c, bias, _activation_type); +#else + throw std::runtime_error("int4mm is supported only GPU Arch >= 800"); +#endif + } + template void Gemm::compute(const StorageView& a, const StorageView& b, @@ -177,5 +197,15 @@ namespace ctranslate2 { return compensation; } + StorageView Gemm::convert_to_int4pack(const StorageView& input, + int32_t innerKTiles) { + StorageView output(input.device(), input.dtype()); +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800) + DEVICE_DISPATCH(input.device(), (convert_weight_to_int4pack(input, output, innerKTiles))); +#else + throw std::runtime_error("convert weight to int4pack is supported only GPU Arch >= 800"); +#endif + return output; + } } } diff --git a/src/ops/int4gemm_cpu.cc b/src/ops/int4gemm_cpu.cc new file mode 100644 index 000000000..bf24b21f0 --- /dev/null +++ b/src/ops/int4gemm_cpu.cc @@ -0,0 +1,31 @@ +#include + +namespace ctranslate2 { + namespace ops { + template + void Gemm::compute(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c) const { + // todo + throw std::runtime_error("int4mm is not supported for CPU"); + } + + template <> + void Gemm::convert_weight_to_int4pack(const StorageView& a, + StorageView& b, + int32_t innerKTiles) { + // todo + throw std::runtime_error("convert_weight_to_int4pack is not supported 
for CPU"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Gemm::compute(const StorageView& a, \ + const StorageView& b, \ + const StorageView& scaleAndZero, \ + StorageView& c) const; + + DECLARE_IMPL(bfloat16_t) + } +} \ No newline at end of file diff --git a/src/ops/int4gemm_gpu.cu b/src/ops/int4gemm_gpu.cu new file mode 100644 index 000000000..5f564031c --- /dev/null +++ b/src/ops/int4gemm_gpu.cu @@ -0,0 +1,974 @@ +#include +#include "cuda/helpers.h" + +namespace ctranslate2 { + namespace ops { +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) + template + constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return (a / b); + } + + template + constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + // Overflow safe variant of (a + b - 1) / b + const uint64_t blocks = a / b + (a % b != 0); + return blocks; + } + + template + constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return divDown(a, b) * b; + } + + template + constexpr __host__ __device__ bool isEvenDivisor(U a, V b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return (a % V(b) == 0) && ((a / V(b)) >= 1); + } + + template + constexpr __host__ __device__ T pow(T n, int power) { + return (power > 0 ? n * pow(n, power - 1) : 1); + } + + template + constexpr __host__ __device__ T pow2(int power) { + return pow(2, power); + } + + static_assert(pow2(8) == 256, "pow2"); + + template + constexpr __host__ __device__ int log2(T n, int p = 0) { + return (n <= 1) ? p : log2(n / 2, p + 1); + } + + static_assert(log2(2) == 1, "log2"); + static_assert(log2(3) == 1, "log2"); + static_assert(log2(4) == 2, "log2"); + + template + constexpr __host__ __device__ bool isPowerOf2(T v) { + static_assert(std::is_integral::value, ""); + return (v && !(v & (v - 1))); + } + + static_assert(isPowerOf2(2048), "isPowerOf2"); + static_assert(!isPowerOf2(3333), "isPowerOf2"); + + template + constexpr __host__ __device__ T nextHighestPowerOf2(T v) { + static_assert(std::is_integral::value, ""); + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); + } + + static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + + static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + + static_assert( + nextHighestPowerOf2(1536000000u) == 2147483648u, + "nextHighestPowerOf2"); + static_assert( + nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + + template + constexpr __host__ __device__ T nextLowestPowerOf2(T v) { + static_assert(std::is_integral::value, ""); + return (isPowerOf2(v) ? 
v / (T)2 : ((T)1 << (log2(v)))); + } + + static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2"); + + static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2"); + + inline __host__ __device__ bool isPointerAligned(const void* p, int align) { + return reinterpret_cast(p) % align == 0; + } + +// Returns the increment needed to aligned the pointer to the next highest +// aligned address + template + inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { + static_assert(isPowerOf2(Align), ""); + const uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1)); + return diff == 0 ? 0 : uint32_t(Align) - diff; + } + + constexpr int32_t kWarpSize = 32; + // f16 vector types + + struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; + }; + + struct __align__(16) bf16x2x4_u32 { + uint32_t vals[4]; + }; + + // from + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. 
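+    // The lop3 instructions above build bf16 bit patterns (v | 0x4300), i.e. the value 128 + v for each
+    // 4-bit v in [0, 15]. The fma below computes x * 1.0 + (-136), mapping this to v - 8, so the result
+    // is a signed int4 value in [-8, 7] represented as bf16.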
+#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + + return result; + } + + + + enum class KReductionType { + // No k-reduction is needed between blocks as the number of k-tiles processed + // per block are exact and we can directly write the output + None, + }; + + // Loads the A matrix in 16-bit standard m x k row major layout, and writes + // the C matrix in 16-bit standard m x n row major layout: + // + // size [m][k] + template + struct ALayout_RM { + static constexpr int32_t kMTileSize = 16; + static constexpr int32_t kNTileSize = 8; + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + const void* A, + int32_t m, + int32_t k, + int32_t mTiles, + int32_t mTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad]) { + const auto mLane = mTile * kMTileSize + (laneId / 4); + const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; + + // access + // [mTile * kMTileSize + (laneId / 4)] + // [kTileStart * kKTileSize + (laneId % 4) * 2] + auto aPtr = reinterpret_cast(A) + mLane * k + kLane; + + auto aPtrPlus8Rows = aPtr + 8 * k; + + bool m0InBounds = mLane < m; + bool m1InBounds = (mLane + 8) < m; + +#pragma unroll + for (int i = 0; i < KTilesToLoad; ++i) { + out[i].vals[0] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize) + : uint32_t(0); + out[i].vals[1] = m1InBounds + ? *reinterpret_cast(aPtrPlus8Rows + i * kKTileSize) + : uint32_t(0); + + out[i].vals[2] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize + 8) + : uint32_t(0); + out[i].vals[3] = m1InBounds ? *reinterpret_cast( + aPtrPlus8Rows + i * kKTileSize + 8) + : uint32_t(0); + } + } + + static __device__ void store( + void* C, + int32_t m, + int32_t n, + int32_t mOutTiles, + int32_t mTile, + int32_t nOutTiles, + int32_t nTile, + int32_t laneId, + const float4& out) { + static_assert(ReduceType == KReductionType::None, ""); + + if constexpr (ReduceType == KReductionType::None) { + // sum.x / sum.y are written at + // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // sum.z / sum.w are written at + // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // i.e., same columns, different row. + const int outRow = mTile * kMTileSize + (laneId / 4); + const int outCol = nTile * kNTileSize + (laneId % 4) * 2; + + // Pointer where sum.x / sum.y is written + auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; + + auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); + auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); + + if (outRow < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01; + } + + // sum.z, sum.w at +8 rows from cPtr + if (outRow + 8 < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; + } + } + } + }; + + template + struct BLayout_TC_int4 { + static constexpr int32_t kInnerKTiles = InnerKTiles; + static constexpr int32_t kMTileSize = 16; + static constexpr int32_t kNTileSize = 8; + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] + // n / 8: n-tiles (n8) + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16) + // 32: value per warp lane + // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. 
+ // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest + // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a + // uint32x4 (128 bits) + const void* __restrict__ B, + // size [k / qGroupSize][n][2] + // Contains the scale and zero point of each of the quantized int4 values + // within B + // v_reconstructed = (bf16(B_int4_val) * scale) - zero + const void* __restrict__ quantizationInfo, + int32_t n, + int32_t k, + int32_t nTiles, + int32_t nTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) { + // offset [nTile][kTileStart / InnerKTiles][laneId][0] + auto bPtr = reinterpret_cast(B) + + (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) * + kWarpSize) + + laneId) * + (InnerKTiles / 2); + + int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2]; + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { + auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2); + + if constexpr (InnerKTiles == 2) { + b_int4[i][0] = bPtrCur[0]; + } + + if constexpr (InnerKTiles == 4) { + // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]) + // : "l"(bPtrCur)); + + int2 load8 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load8.x; + b_int4[i][1] = load8.y; + } + + if constexpr (InnerKTiles == 8) { + // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), + // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur)); + + int4 load16 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load16.x; + b_int4[i][1] = load16.y; + b_int4[i][2] = load16.z; + b_int4[i][3] = load16.w; + } + } + + // Load needed info for dequantization + + static_assert(isPowerOf2(QGroupSize), ""); + static_assert(isEvenDivisor(QGroupSize, kKTileSize), ""); + // smallest quantization group size is 32 (2 k-tiles are packed in an int32) + static_assert(QGroupSize >= kKTileSize * 2, ""); + constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize); + // a q-group could be larger than what we are handling in a single warp + constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1 + ? 1 + : (KTilesToLoad / kKTilesPerQGroup); + + __nv_bfloat162 qScaleAndZero[kNumQGroups]; + { + int32_t laneN = nTile * kNTileSize + (laneId / 4); + int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; + + int32_t n = nTiles * kNTileSize; + + // offset [qScale_kGroup][qScale_n][0] + auto qInfoPtr = reinterpret_cast(quantizationInfo) + + (groupStart * n + laneN) * 2; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScaleAndZero[i] = + *reinterpret_cast(qInfoPtr + i * n * 2); + } + } + + // + // De-quantize int4 values to bf16. Values are dequantized as truly int4 + // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + { + // FIXME: does this negatively affect register counts, or will nvcc + // move this expansion (and data loads above) closer to the point of use? 
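+      // Each group's scale and zero are broadcast into both halves of a __nv_bfloat162 so the __hfma2
+      // below applies v * scale + zero to two dequantized values per instruction.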
+ __nv_bfloat162 qScale[kNumQGroups]; + __nv_bfloat162 qZero[kNumQGroups]; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x); + qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y); + } + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { +#pragma unroll + for (int j = 0; j < InnerKTiles / 2; ++j) { + bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]); + + int curKTile = i * InnerKTiles + j * 2; + int curQGroup = (curKTile * kKTileSize) / QGroupSize; + + // The dequantized values in `v` for a given lane have the same n + // dimension (the B tensor core layout has all values in the same + // thread along the same n) but different k dimension, but all are + // guaranteed to occur within the same quantization group, so we need + // only load a single scale + zero to cover what this lane has +#pragma unroll + for (int k = 0; k < 4; ++k) { + v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]); + } + + // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and + // can't be used as a 32-bit asm register argument for `mma` + static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), ""); + std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32)); + } + } + } + } + }; + + template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerIteration> + __global__ + __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( + // Data for the A matrix, loaded as per ALayout + const void* const __restrict__ A, + + // Data for the B matrix, loaded as per BLayout + const void* const __restrict__ B, + + // Optional quantization data for dequantizing B, loaded as per BLayout + const void* const __restrict__ B_quantizationInfo, + + // Output data for the C matrix, stored as per CLayout + void* __restrict__ C, + + // The size of the matrix multiplication + int32_t m, + int32_t n, + int32_t k, + + // The size of the matrix multiplication, in multiples of our TC tile size + int32_t mTiles, + int32_t nTiles, + int32_t kTiles) { + constexpr int32_t kMTileSize = 16; + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + static_assert( + ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && + ALayout::kKTileSize == kKTileSize, + ""); + + static_assert( + BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize && + BLayout::kKTileSize == kKTileSize, + ""); + + static_assert( + CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize && + CLayout::kKTileSize == kKTileSize, + ""); + + constexpr int kInnerKTiles = BLayout::kInnerKTiles; + + // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads + static_assert( + kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, ""); + + // We always process at least kInnerKTiles k-tiles back to back in a warp + static_assert( + KTilesPerIteration >= kInnerKTiles && + isEvenDivisor(KTilesPerIteration, kInnerKTiles), + ""); + + auto warpId = threadIdx.y; + auto laneId = threadIdx.x; + + int32_t mTile = blockIdx.z; + int32_t nTile = blockIdx.y; + + float4 c{0.0f, 0.0f, 0.0f, 0.0f}; + + // First, handle whole multiples of KTilesPerIteration + auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); + + // Each warp handles a set of KTilesPerIteration under the above limit + for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration; + kTileBase < kTilesLimit; + kTileBase += Warps * KTilesPerIteration) { + // + // Load data from A + // 
+ bf16x2x4_u32 a[KTilesPerIteration]; + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); + + // + // Load data from B and de-quantize as needed + // Each k-tile is bf16x2x2 + // + bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBase, + laneId, + b); + + // + // Now, perform the matrix multiplication + // + + // We accumulate across k-tiles here +#pragma unroll + for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) { + static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, ""); +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` + float4 cTmp[2]; + +#pragma unroll + for (int k = 0; k < 2; ++k) { + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), + "=f"(cTmp[k].y), + "=f"(cTmp[k].z), + "=f"(cTmp[k].w) + : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; + } + } + } + } // for all tiles under kTilesLimit + + // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles + // remaining. We guarantee that the number of warps is >= KTilesPerIteration / + // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its + // thing without needing more warps + static_assert(Warps >= KTilesPerIteration / kInnerKTiles, ""); + + auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles; + + // If we have any remainder k-tiles, some warps will handle them, processing + // kInnerKTiles k-tiles at a time + if (kTileBaseRemaining < kTiles) { + bf16x2x4_u32 a[kInnerKTiles]; + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); + + bf16x2x4_u32 b[1][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBaseRemaining, + laneId, + b); + +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. 
Instead, we only periodically accumulate into + // `c` + float4 cTmp[2]; + +#pragma unroll + for (int k = 0; k < 2; ++k) { + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w) + : "r"(a[j * 2 + k].vals[0]), + "r"(a[j * 2 + k].vals[1]), + "r"(a[j * 2 + k].vals[2]), + "r"(a[j * 2 + k].vals[3]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; + } + } + } + + // + // Reduce independent k-tiles (same m/n) across warps + // + __shared__ float4 smem_sum[Warps][kWarpSize]; + + // FIXME: this likely doesn't need to be a true reduction tree, can just be a + // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) + // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); + smem_sum[warpId][laneId] = c; + + __syncthreads(); + + if (warpId == 0) { + float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f}; + + // Reduce across the block in the first warp + for (int i = 0; i < Warps; ++i) { + float4 v = smem_sum[i][laneId]; + sum_f32.x += v.x; + sum_f32.y += v.y; + sum_f32.z += v.z; + sum_f32.w += v.w; + } + + // Write the reduced result (in the first warp) into the output + CLayout::store( + C, + m, + n, + mTiles, + mTile, + // n for C output becomes k for A input, so for m16n8k16, + // we need to halve the tiles + nTiles / 2, + nTile, + laneId, + sum_f32); + } + } + + // FIXME: parallelize better, smem staging etc? + template + __global__ void + matrix_to_m16n8k16_Bint4_layout( + // size [n][k] + const int32_t* in, + int32_t n, + int32_t depth, + int32_t depth_output1, + int32_t depth_output2, + int32_t depth_output3, + // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] + int32_t* out) { + // int4 values are packed into int32 values, which require at least 8. Given + // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of + // innermost k-tiles that we can use is 2. + static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // Two k-tiles are packed into an int32 at a time +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + // n dimension that this lane loads from + auto n0 = nTile * kNTileSize + (t / 4); + + bool n0Valid = n0 < n; + + int32_t ks[8]; + + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + ks[0] = kBase0 + (t % 4) * 2; + ks[1] = ks[0] + 1; + ks[2] = ks[0] + 8; + ks[3] = ks[0] + 8 + 1; + + auto kBase1 = kBase0 + kKTileSize; + ks[4] = kBase1 + (t % 4) * 2; + ks[5] = ks[4] + 1; + ks[6] = ks[4] + 8; + ks[7] = ks[4] + 8 + 1; + + auto pIn = in + (n0 * depth); + + uint32_t v[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + v[i] = (n0Valid && (ks[i] < depth)) ? 
pIn[ks[i]] : uint32_t(0); + } + + int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) | + (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0]; + + // inner k-tiles pack two at a time + out[nTile * depth_output3 + kOuterTile * depth_output2 + t * depth_output1 + innerKTile / 2] = pack; + } + } + + template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerWarp> + void launch_tinygemm_kernel( + const StorageView& A, + const StorageView& B, + const StorageView* qScaleAndZeros, /* optional */ + StorageView& C_final, + int32_t mTiles, + int32_t nTiles, + int32_t kTiles, + int32_t m, + int32_t n, + int32_t k, + cudaStream_t stream) { + // After intra-block reduction across the k dimension, we are left with this + // many tiles + // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp); + int32_t postKernelKTiles = 1; // we loop + + auto grid = dim3(postKernelKTiles, nTiles, mTiles); + auto block = dim3(kWarpSize, Warps); + + auto func = + tinygemm_m16n8k16_chunk_kernel; + + func<<>>( + A.data(), + B.data(), + qScaleAndZeros ? qScaleAndZeros->data() : nullptr, + C_final.data(), + m, + n, + k, + mTiles, + nTiles, + kTiles); + } + + template + void Gemm::compute(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c) const { + constexpr int32_t kMTileSize = 16; + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + // row major layout + auto m = a.rank() == 3 ? a.dim(0) * a.dim(1) : a.dim(0); + auto mTiles = divUp(m, kMTileSize); + + // tensor core layout + auto nTiles = b.dim(0); + auto n = nTiles * kNTileSize; + + // row major layout + auto k = a.rank() == 3 ? a.dim(2) : a.dim(1); + auto kTiles = divUp(k, kKTileSize); + + // The number of inner k tiles is the innermost dimension of times 2 + // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4 + // packed into 1 int32 for int4 B + const int32_t B_innerKTiles = b.dim(3) * 2; + + //TORCH_CHECK(qScaleAndZeros.dim() == 3); + auto numQGroups = scaleAndZero.dim(0); + // Output is a standard row-major matrix + c.resize({m, n}); + auto stream = cuda::get_cuda_stream(); +#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ + do { \ + using ACLayout = ALayout_RM; \ + \ + switch (B_innerKTiles) { \ + case 2: \ + if constexpr (K_TILES_PER_WARP >= 2) { \ + using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + a, \ + b, \ + &scaleAndZero, \ + c, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 4: \ + if constexpr (K_TILES_PER_WARP >= 4) { \ + using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + a, \ + b, \ + &scaleAndZero, \ + c, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 8: \ + if constexpr (K_TILES_PER_WARP >= 8) { \ + using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + a, \ + b, \ + &scaleAndZero, \ + c, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + default: \ + break; \ + } \ + } while (false) + +#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \ + do { \ + switch (_group_size) { \ + case 32: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 
32, REDUCE_TYPE); \ + break; \ + case 64: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE); \ + break; \ + case 128: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \ + break; \ + case 256: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \ + break; \ + } \ + } while (false) + + HANDLE_Q_GROUP(8, 8, KReductionType::None); + +#undef HANDLE_Q_GROUP +#undef RUN_GEMM + } + + template <> + void Gemm::convert_weight_to_int4pack( + const StorageView& a, + StorageView& b, + int32_t innerKTiles) { + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto nTiles = divUp(a.dim(0), kNTileSize); + + // k-tiles are packed back to back in the innermost dimension in order to + // allow for 4/8/16 byte loads + // kSuperTiles is the number of k-tiles assuming k is innerKTiles * kKTileSize + auto kSuperTiles = divUp(a.dim(1), innerKTiles * kKTileSize); + + // each block handles `innerKTiles` k-tiles. + // 2 k-tiles are a single int32 + b.resize({nTiles, kSuperTiles, 32, innerKTiles / 2}); + + auto stream = ctranslate2::cuda::get_cuda_stream(); + dim3 grid(kSuperTiles, nTiles); + + if (innerKTiles == 2) { + matrix_to_m16n8k16_Bint4_layout<2><<>>( + a.data(), + a.dim(0), + a.dim(-1), + b.dim(-1), + b.stride(1), + b.stride(0), + b.data()); + } else if (innerKTiles == 4) { + matrix_to_m16n8k16_Bint4_layout<4><<>>( + a.data(), + a.dim(0), + a.dim(-1), + b.dim( 1), + b.stride(1), + b.stride(0), + b.data()); + } else if (innerKTiles == 8) { + matrix_to_m16n8k16_Bint4_layout<8><<>>( + a.data(), + a.dim(0), + a.dim(-1), + b.dim(-1), + b.stride(1), + b.stride(0), + b.data()); + } + } + +#define DECLARE_IMPL(T) \ + template void \ + Gemm::compute(const StorageView& a, \ + const StorageView& b, \ + const StorageView& scaleAndZero, \ + StorageView& c) const; + + DECLARE_IMPL(bfloat16_t) +#endif + } +} \ No newline at end of file diff --git a/src/storage_view.cc b/src/storage_view.cc index 0cbdf25c2..2728c4480 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -3,7 +3,6 @@ #include "ctranslate2/primitives.h" #include "dispatch.h" - #define PRINT_MAX_VALUES 6 namespace ctranslate2 { diff --git a/src/types.cc b/src/types.cc index 2431bce66..b1440216a 100644 --- a/src/types.cc +++ b/src/types.cc @@ -87,6 +87,8 @@ namespace ctranslate2 { return "float16"; case ComputeType::BFLOAT16: return "bfloat16"; + case ComputeType::INT32_BFLOAT16: + return "int32_bfloat16"; }; throw std::invalid_argument("Invalid compute type value"); } @@ -162,8 +164,13 @@ namespace ctranslate2 { const bool support_float16 = mayiuse_float16(device, device_index); const bool support_int16 = mayiuse_int16(device, device_index); const bool support_int8 = mayiuse_int8(device, device_index); - - switch (requested_compute_type) { + const bool lower_bit_model = model_compute_type == ComputeType::INT32_BFLOAT16; + ComputeType accepted_request_compute_type = requested_compute_type; + if (lower_bit_model) { + // model quantized to lower bit could run in its mode only + accepted_request_compute_type = model_compute_type; + } + switch (accepted_request_compute_type) { case ComputeType::FLOAT32: { return ComputeType::FLOAT32; @@ -185,6 +192,12 @@ namespace ctranslate2 { return ComputeType::FLOAT32; } + case ComputeType::INT32_BFLOAT16: { + if (!support_bfloat16) + unsupported_compute_type("bfloat16"); + return ComputeType::INT32_BFLOAT16; + } + case ComputeType::INT16: { if (support_int16) return ComputeType::INT16; @@ -310,6 +323,8 @@ namespace ctranslate2 { return std::make_pair(DataType::FLOAT16, 
DataType::FLOAT16); case ComputeType::BFLOAT16: return std::make_pair(DataType::BFLOAT16, DataType::BFLOAT16); + case ComputeType::INT32_BFLOAT16: + return std::make_pair(DataType::INT32, DataType::BFLOAT16); default: throw std::invalid_argument("resolve_compute_type should be called first"); } @@ -329,6 +344,8 @@ namespace ctranslate2 { } case DataType::INT16: return ComputeType::INT16; + case DataType::INT32: + return ComputeType::INT32_BFLOAT16; case DataType::FLOAT16: return ComputeType::FLOAT16; case DataType::BFLOAT16: