[WIP] HQQ Quant 4 bit #1738

Open
wants to merge 10 commits into master
4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -198,6 +198,7 @@ set(SOURCES
src/ops/awq/gemv.cc
src/ops/awq/gemv_cpu.cc
src/ops/sum.cc
src/ops/int4gemm_cpu.cc
src/padder.cc
src/profiler.cc
src/random.cc
@@ -324,7 +325,7 @@ if(WITH_MKL)
endif()

# Find MKL libraries.
find_library(MKL_CORE_LIBRARY NAMES mkl_core HINTS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64)
find_library(MKL_CORE_LIBRARY NAMES mkl_core PATHS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64)
if(MKL_CORE_LIBRARY)
get_filename_component(MKL_LIBRARY_DIR ${MKL_CORE_LIBRARY} DIRECTORY)
message(STATUS "Found MKL library directory: ${MKL_LIBRARY_DIR}")
@@ -605,6 +606,7 @@ if (WITH_CUDA)
src/ops/awq/gemm_gpu.cu
src/ops/awq/gemv_gpu.cu
src/ops/awq/dequantize_gpu.cu
src/ops/int4gemm_gpu.cu
)

set_source_files_properties(
1 change: 1 addition & 0 deletions include/ctranslate2/layers/common.h
@@ -62,6 +62,7 @@ namespace ctranslate2 {
const StorageView& _embeddings;
const DataType _output_type;
const StorageView* _qscale;
const StorageView* _qzero;
};

// This enum order should remain fixed.
7 changes: 5 additions & 2 deletions include/ctranslate2/models/model.h
@@ -14,7 +14,8 @@ namespace ctranslate2 {
enum class QUANTIZATION_TYPE {
CT2,
AWQ_GEMM,
AWQ_GEMV
AWQ_GEMV,
HQQ_4BIT,
};

static const size_t current_binary_version = 6;
@@ -113,11 +114,13 @@
}

// If the model contains variables, they will be moved to the new device.
void set_device(const Device device, const int index = 0);
void set_device(const Device device, const int index = 0, const bool format_lower_bit = false);

// Copy the model to another device.
std::shared_ptr<const Model> copy_to(Device device, int device_index = 0) const;

template<typename T>
T get_config_if_exists(const std::string& name) const;
const StorageView* get_variable_if_exists(const std::string& name) const;
const StorageView& get_variable(const std::string& name) const;
std::unordered_map<std::string, StorageView> get_variables() const;
27 changes: 25 additions & 2 deletions include/ctranslate2/ops/gemm.h
@@ -18,14 +18,24 @@ namespace ctranslate2 {
bool trans_b = false,
bool a_is_packed = false,
bool b_is_packed = false,
const ActivationType* activation_type = nullptr);
const ActivationType* activation_type = nullptr,
const int _group_size = 0);

void operator()(const StorageView& a,
const StorageView& b,
StorageView& c,
const StorageView* a_shift_compensation = nullptr,
const StorageView* bias = nullptr) const;

void operator()(const StorageView& a,
const StorageView& b,
const StorageView& scaleAndZero,
StorageView& c,
const StorageView* bias = nullptr) const;

StorageView convert_to_int4pack(const StorageView& input,
int32_t innerKTiles);

// Return the packed representation of b, if implemented by the GEMM backend.
static StorageView pack_b_input(const StorageView& b,
const bool transpose,
@@ -49,13 +59,26 @@
bool _trans_b;
bool _a_is_packed;
bool _b_is_packed;
const int _group_size;

template <Device D, typename In, typename Out>
void compute(const StorageView& a,
const StorageView& b,
StorageView& c,
const StorageView* a_shift_compensation) const;
};

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
template <Device D, typename In, typename Out>
void compute(const StorageView& a,
const StorageView& b,
const StorageView& scaleAndZero,
StorageView& c) const;

template <Device D>
void convert_weight_to_int4pack(const StorageView& a,
StorageView& b,
int32_t innerKTiles);
#endif
};
}
}
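For context: the naming introduced here (convert_to_int4pack, innerKTiles, a combined scaleAndZero tensor) parallels PyTorch's packed-int4 "tinygemm" kernels. The sketch below is a reference for that PyTorch path only, not CTranslate2 code; the private op signatures shown match PyTorch 2.2 to 2.4 and may differ in other releases, and all shapes are illustrative assumptions.

import torch

# Output dim N, reduction dim K, activation rows M. Shapes chosen to satisfy
# the kernel's divisibility constraints (K divisible by group_size and by
# inner_k_tiles * 16).
N, K, M = 128, 256, 4
group_size, inner_k_tiles = 64, 8

# 4-bit weight codes stored as int32 values in [0, 15], shape [N, K].
w_q = torch.randint(0, 16, (N, K), dtype=torch.int32, device="cuda")
w_packed = torch._convert_weight_to_int4pack(w_q, inner_k_tiles)

# Combined scale/zero tensor with layout [K // group_size, N, 2] in bfloat16,
# the same layout _hqq_quants_to_torch_quants builds later in this PR.
scale_and_zero = torch.rand(K // group_size, N, 2, dtype=torch.bfloat16, device="cuda")

a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
c = torch._weight_int4pack_mm(a, w_packed, group_size, scale_and_zero)  # [M, N]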
1 change: 1 addition & 0 deletions include/ctranslate2/types.h
@@ -36,6 +36,7 @@ namespace ctranslate2 {
INT16,
FLOAT16,
BFLOAT16,
INT32_BFLOAT16,
};

ComputeType str_to_compute_type(const std::string& compute_type);
24 changes: 22 additions & 2 deletions python/ctranslate2/converters/converter.py
@@ -5,7 +5,7 @@

from typing import Optional

from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, ModelSpec
from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, DEVICE, ModelSpec


class Converter(abc.ABC):
@@ -30,6 +30,18 @@ def declare_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
choices=ACCEPTED_MODEL_TYPES,
help="Weight quantization type.",
)
parser.add_argument(
"--group_size",
type=int,
default=None,
help="Group size used in quantization lower bit",
)
parser.add_argument(
"--device",
default=None,
choices=DEVICE,
help="Device where the quantization runs on. Available only with hqq quantization",
)
parser.add_argument(
"--force",
action="store_true",
@@ -51,6 +63,8 @@ def convert_from_args(self, args: argparse.Namespace) -> str:
args.output_dir,
vmap=args.vocab_mapping,
quantization=args.quantization,
group_size=args.group_size,
device=args.device,
force=args.force,
)

@@ -59,6 +73,8 @@ def convert(
output_dir: str,
vmap: Optional[str] = None,
quantization: Optional[str] = None,
group_size: Optional[int] = None,
device: Optional[str] = None,
force: bool = False,
) -> str:
"""Converts the model to the CTranslate2 format.
@@ -69,6 +85,10 @@
in the converted model directory.
quantization: Weight quantization scheme (possible values are: int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
group_size: Group size used by lower-bit quantization (possible values
are: 32, 64, 128, ...).
device: Device on which the scales and zero points are computed for
lower-bit quantization (possible values are: cuda, cpu).
force: Override the output directory if it already exists.

Returns:
@@ -95,7 +115,7 @@
model_spec.register_vocabulary_mapping(vmap)

model_spec.validate()
model_spec.optimize(quantization=quantization)
model_spec.optimize(quantization=quantization, group_size=group_size, device=device)

# Create model directory.
if os.path.exists(output_dir):
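For reference, a hypothetical end-to-end invocation of the extended convert() API. The model name and output directory below are placeholders; group_size and device are the options added in this PR, and the hqq_int4 path additionally requires the hqq package and a bfloat16-capable device.

from ctranslate2.converters import TransformersConverter

converter = TransformersConverter("meta-llama/Llama-2-7b-hf")  # placeholder model
converter.convert(
    "llama2-ct2-hqq-int4",        # placeholder output directory
    quantization="hqq_int4",      # new entry in ACCEPTED_MODEL_TYPES
    group_size=64,                # HQQ group size (32, 64, 128, ...)
    device="cuda",                # where HQQ computes scales and zero points
    force=True,
)

The equivalent command-line call would pass --quantization hqq_int4 --group_size 64 --device cuda to the converter script.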
2 changes: 2 additions & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -29,6 +29,7 @@ class Quantization(enum.IntEnum):
CT2 = 0
AWQ_GEMM = 1
AWQ_GEMV = 2
HQQ_INT4 = 3


class LayerNormSpec(model_spec.LayerSpec):
@@ -62,4 +63,5 @@ class EmbeddingsSpec(model_spec.LayerSpec):
def __init__(self):
self.weight = None
self.weight_scale = model_spec.OPTIONAL
self.weight_zero = model_spec.OPTIONAL
self.multiply_by_sqrt_depth = model_spec.OPTIONAL
124 changes: 120 additions & 4 deletions python/ctranslate2/specs/model_spec.py
@@ -11,6 +11,11 @@
import struct

from typing import Dict, List, Optional
try:
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer
hqq_is_available = True
except ImportError:
hqq_is_available = False

import numpy as np

@@ -21,10 +26,15 @@
except ImportError:
torch_is_available = False

from ctranslate2.specs import (
common_spec,
)

OPTIONAL = "__optional"
CURRENT_BINARY_VERSION = 6

ACCEPTED_MODEL_TYPES = (
"hqq_int4",
"int8",
"int8_float32",
"int8_float16",
@@ -37,6 +47,10 @@

SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")

DEVICE = (
"cuda",
"cpu",
)

def _join_scope(scope, name):
if not scope:
@@ -188,7 +202,50 @@ def _alias_variables(self):
setattr(spec, attr_name, other_name)
break

def _quantize(self, quantization):
def _hqq_quants_to_torch_quants(self, w_q, scales, zeros, shape, nbits=4):
max_int = 2**nbits - 1
min_int = 0
dump = 2 ** (nbits - 1)

# HQQ -> torch logic
new_zeros = (scales * dump) - zeros * scales

min_val = new_zeros - scales * dump

# group_quantize_tensor_from_qparams
w_r = (w_q - zeros) * scales

w_q = (
w_r.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape(shape)
.contiguous()
)

# group_dequantize_tensor_from_qparams
# W_r = W_q*scales + min_val

scales = scales.contiguous().reshape(shape[0], -1)
new_zeros = new_zeros.contiguous().reshape(shape[0], -1)
scale_and_zero = (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
new_zeros.reshape(new_zeros.size(0), new_zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)


return w_q, scale_and_zero

def _quantize(self, quantization, group_size, device):
"""Possibly quantizes the variable of the layer."""
if quantization is not None and quantization not in ACCEPTED_MODEL_TYPES:
raise ValueError(
@@ -202,6 +259,7 @@ def _quantize(spec, name, value):

key = _split_scope(name)[-1]
scale = None
zero = None
is_quantizable = hasattr(spec, "%s_scale" % key)
is_convertible = value.dtype in ("float32", "float16", "bfloat16")

@@ -244,11 +302,37 @@ def _quantize(spec, name, value):
value = NumpyVariable(value)
elif quantization in ("float16", "bfloat16", "float32"):
value = value.to(quantization)
elif quantization in (
"hqq_int4",
) and len(value.shape) != 3 and hqq_is_available:
if 'embeddings' in name:
value = value.to("bfloat16")
else:
quant_config = BaseQuantizeConfig(nbits=4, group_size=group_size,
quant_zero=False, quant_scale=False, axis=1)
hqq_linear = HQQLinear(None, quant_config=quant_config,
compute_dtype=torch.bfloat16, device=device)
hqq_linear.quantize(value.to("bfloat16").tensor, **hqq_linear.quant_config)

value = hqq_linear.W_q.cpu()
scale = hqq_linear.meta['scale'].cpu()
zero = hqq_linear.meta['zero'].cpu()
old_shape = hqq_linear.meta['shape']
value = Quantizer.unpack[hqq_linear.meta["packing"]](value)
value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape)

scale = scale.cpu()
value = value.cpu()
scale = PyTorchVariable(scale)
value = PyTorchVariable(value)
del hqq_linear.W_q
del hqq_linear.meta['scale']
del hqq_linear.meta['zero']

elif is_convertible:
if quantization in ("float16", "int8_float16"):
value = value.to("float16")
elif quantization in ("bfloat16", "int8_bfloat16"):
elif quantization in ("bfloat16", "int8_bfloat16", "hqq_int4"):
value = value.to("bfloat16")
elif quantization in ("float32", "int16", "int8_float32"):
value = value.to("float32")
@@ -259,7 +343,9 @@

self._visit(_quantize)

def optimize(self, quantization: Optional[str] = None) -> None:
def optimize(self, quantization: Optional[str] = None,
group_size: Optional[int] = None,
device: Optional[str] = None) -> None:
"""Recursively applies some optimizations to this layer:

* Alias variables with the same shape and value.
@@ -268,9 +354,11 @@ def optimize(self, quantization: Optional[str] = None) -> None:
Arguments:
quantization: Weight quantization scheme (possible values are: int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
group_size: Group size used by lower-bit quantization.
device: Device on which the quantization runs (HQQ only).
"""
self._alias_variables()
self._quantize(quantization)
self._quantize(quantization, group_size, device)

def _visit(self, fn):
"""Recursively visits this layer and its children."""
@@ -414,6 +502,34 @@ def _write_string(string):
_write_string(variable_name)


def _save_quantization_type(self, quantization, group_size):
if quantization == 'hqq_int4':
self._config.add_attribute("quantization_type", common_spec.Quantization.HQQ_INT4)
elif quantization is not None:
self._config.add_attribute("quantization_type", common_spec.Quantization.CT2)

if group_size is not None:
self._config.add_attribute("quantization_group_size", group_size)

def optimize(self, quantization: Optional[str] = None,
group_size: Optional[int] = None,
device: Optional[str] = None) -> None:
"""Recursively applies some optimizations to its layer:

* Alias variables with the same shape and value.
* Quantize weights.

Arguments:
quantization: Weight quantization scheme (possible values are: int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
group_size: Group size used by lower-bit quantization.
device: Device on which the quantization runs (HQQ only).
"""
self._save_quantization_type(quantization, group_size)
self._alias_variables()
self._quantize(quantization, group_size, device)


def _flatten_vocabularies(vocabularies):
for name, vocabulary in vocabularies.items():
if len(vocabulary) == 1:
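A closing note on _hqq_quants_to_torch_quants: HQQ dequantizes as W_r = (W_q - zero) * scale, while the torch-style int4 kernel expects W_r = W_q * scale + min_val. The standalone check below (written for this review, not code from the PR) confirms that re-expressing the zero point this way preserves both the dequantized weights and the 4-bit codes.

import torch

nbits = 4
dump = 2 ** (nbits - 1)

scale = torch.tensor([[0.02]])
zero = torch.tensor([[7.5]])
w_q = torch.randint(0, 2**nbits, (1, 16)).to(torch.float32)

# HQQ convention: W_r = (W_q - zero) * scale
w_r_hqq = (w_q - zero) * scale

# Re-expressed zero point, mirroring _hqq_quants_to_torch_quants
new_zero = scale * dump - zero * scale
min_val = new_zero - scale * dump          # simplifies to -zero * scale

# torch int4 convention: W_r = W_q * scale + min_val
w_r_torch = w_q * scale + min_val
assert torch.allclose(w_r_hqq, w_r_torch)

# Re-quantizing against min_val recovers the original 4-bit codes.
w_q_back = (w_r_hqq - min_val).div(scale).round().clamp_(0, 2**nbits - 1)
assert torch.equal(w_q_back, w_q)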