diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index cb7c8d048..7099dae13 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -36,6 +36,7 @@ groupwise_affine_quantize_tensor_from_qparams, pack_tinygemm_scales_and_zeros, per_token_dynamic_quant, + prepare_int4_weight_and_scales_and_zeros, ) aten = torch.ops.aten @@ -533,6 +534,7 @@ def linear_forward_int4( weight_int4pack: torch.Tensor, scales_and_zeros: torch.Tensor, out_features: int, + in_features: int, groupsize: int, precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16, @@ -558,6 +560,136 @@ def linear_forward_int4( return c +def linear_forward_int4_dynamic_quantization_4bit( + x: torch.Tensor, + weight_int4pack: torch.Tensor, + scales_and_zeros: torch.Tensor, + out_features: int, + in_features: int, + groupsize: int, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, +): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._dyn_quant_matmul_4bit( + x.to(precision), weight_int4pack, groupsize, in_features, out_features) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + +def kai_roundup(a: int, b: int) -> int: + return ((a + b - 1) // b) * b + + +def get_kai_packed_weight_size(n_bits, N, K, groupsize): + if n_bits == 4: + if groupsize == K: # channelwise + # dotprod params only [1x8x32_neon_dotprod] + kai_nr = 8 + kai_kr = 16 + kai_sr = 2 + kai_num_bytes_sum_rhs = 4 # sizeof(int32_t) + kai_num_bytes_multiplier_rhs = 4 # sizeof(float) + kai_num_bytes_bias = 4 # sizeof(float) + + def kai_k_roundedup(k, kr, sr): + # Since we pack a float and int32 value at the end of the row, + # we must make sure that k is a multiple of 4 for alignment + kr_sr_roundedup4 = kai_roundup(kr * sr, 4) + return kai_roundup(k, kr_sr_roundedup4) + + def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + k, nr, kr, sr + ): + k_internal = kai_k_roundedup(k, kr, sr) + + assert (k_internal % 2) == 0, "k_internal must be even" + + return nr * ( + (k_internal // 2) + + kai_num_bytes_multiplier_rhs + + kai_num_bytes_sum_rhs + + kai_num_bytes_bias + ) + + def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + n, k, nr, kr, sr + ): + num_rows = kai_roundup(n, nr) // nr + + return ( + num_rows + * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + k, nr, kr, sr + ) + ) + + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + N, K, kai_nr, kai_kr, kai_sr + ) + elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise + kai_nr = 8 + kai_kr = 16 + kai_sr = 2 + kai_num_bytes_sum_rhs = 4 + kai_num_bytes_bias = 4 + kai_nr_multiple_of = 4 + kai_bl_multiple_of = 32 + + def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + n, k, nr, kr, sr, bl + ): + assert (bl % kr) == 0 + assert (nr % kai_nr_multiple_of) == 0 + assert (bl % kai_bl_multiple_of) == 0 + + num_rows = kai_roundup(n, nr) // nr + + return ( + num_rows + * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, bl + ) + ) + + def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, bl + ): + assert (bl % kr) == 0 + assert (nr % kai_nr_multiple_of) == 0 + assert (bl % kai_bl_multiple_of) == 0 + + # kr and sr are unused in the calculation + num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes() + num_blocks_per_row = kai_num_blocks_per_row(k, bl) + num_bytes_per_block = kai_num_bytes_per_block( + bl, num_bytes_multiplier_rhs + ) + + return nr * ( + (num_bytes_per_block * num_blocks_per_row) + + kai_num_bytes_sum_rhs + + kai_num_bytes_bias + ) + + # This funtion retuns size of these datatypes stored as enum. We modify it to just return bf16 datatype + # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55 + def kai_get_bf16_datatype_size_in_bytes(): + return 2 # 2 bytes + + def kai_num_blocks_per_row(k, bl): + assert (bl % kai_bl_multiple_of) == 0 + return kai_roundup(k, bl) // bl + + def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): + assert (bl % kai_bl_multiple_of) == 0 + return (bl // 2) + num_bytes_multiplier_rhs + + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + N, K, kai_nr, kai_kr, kai_sr, groupsize + ) + class WeightOnlyInt4Linear(torch.nn.Module): __constants__ = ["in_features", "out_features"] in_features: int @@ -576,21 +708,29 @@ def __init__( inner_k_tiles: int = 8, precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16, + scheme = None, ) -> None: super().__init__() - self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + self.padding = scheme is None and not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) if self.padding: self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) self.in_features = in_features self.out_features = out_features - assert not bias, "require bias=False" self.device = device self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles self.precision = precision self.scales_precision = scales_precision + self.scheme = scheme + self.bias = bias + + self.custom_forward = linear_forward_int4 + if torch.backends.kleidiai.is_available() and (self.scheme == "symmetric_channelwise" or self.scheme == "symmetric_groupwise"): + self.custom_forward = linear_forward_int4_dynamic_quantization_4bit + else: + self.scheme = None if dtype is not None: raise ValueError("Please specify 'precision' instead of 'dtype'") @@ -600,17 +740,36 @@ def __init__( in_features % (inner_k_tiles * 16) == 0 ), "require in_features % (innerKTiles * 16) == 0" if is_device(device.type, "cpu"): - self.register_buffer( - "weight", - torch.zeros( - ( - out_features, - in_features // 2, + if scheme is None: + assert not bias, "require bias=False" + self.register_buffer( + "weight", + torch.zeros( + ( + out_features, + in_features // 2, + ), + dtype=torch.uint8, + device=device, ), - dtype=torch.uint8, - device=device, - ), - ) + ) + else: + if torch.backends.kleidiai.is_available() and ( + (groupsize == in_features and scales_precision == torch.float) + or ( + groupsize < in_features + and groupsize % 32 == 0 + and in_features % groupsize == 0 + and scales_precision == torch.bfloat16 + ) + ): + packed_weight_size = get_kai_packed_weight_size(4,out_features, in_features, groupsize) + self.register_buffer( + "weight", + torch.empty((packed_weight_size), dtype=torch.uint8) + ) + else: + raise ValueError("KleidiAI backend can not be initiated") else: self.register_buffer( "weight", @@ -625,30 +784,49 @@ def __init__( device=device, ), ) - self.dtype = dtype - self.register_buffer( - "scales_and_zeros", - torch.zeros( - (in_features // groupsize, out_features, 2), - dtype=self.scales_precision, - device=device, - ), - ) + if scheme is None: + self.register_buffer( + "scales_and_zeros", + torch.zeros( + (in_features // groupsize, out_features, 2), + dtype=self.scales_precision, + device=device, + ), + ) + elif scheme == "symmetric_channelwise": + self.register_buffer( + "scales_and_zeros", + torch.zeros( + (out_features), + dtype=self.scales_precision, + device=device, + ), + ) + elif scheme == "symmetric_groupwise": + self.register_buffer( + "scales_and_zeros", + torch.zeros( + (out_features, in_features//groupsize), + dtype=self.scales_precision, + device=device, + ), + ) + def forward(self, input: torch.Tensor) -> torch.Tensor: if self.padding: input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - return linear_forward_int4( + return self.custom_forward( input, self.weight, self.scales_and_zeros, self.out_features, + self.in_features, self.groupsize, self.precision, self.scales_precision, ) - def _replace_linear_int4( module: torch.nn.Module, groupsize: int, @@ -659,27 +837,31 @@ def _replace_linear_int4( scales_precision: torch.dtype = torch.bfloat16, linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear, copy_weights: bool = False, + scheme=None, ): for name, child in module.named_children(): - # TODO: support linear bias - if ( - isinstance(child, nn.Linear) - and child.bias is None - and (skip_layer_func is None or not skip_layer_func(child.weight)) - ): - if ( - _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) - or padding_allowed - ): + if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): + if scheme is not None or _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + bias = False + if child.bias is not None: + bias = child.bias + if scheme == "symmetric_channelwise": + groupsize = child.in_features + scales_precision = torch.float32 + elif scheme == "symmetric_groupwise": + # Scales are actually f16 but they are packed along with weights + # To maintain api compatibility we populate scaled with no element in this scheme + scales_precision = torch.bfloat16 new_linear = linear_class( child.in_features, child.out_features, - bias=False, + bias=bias, device=child.weight.device, groupsize=groupsize, inner_k_tiles=inner_k_tiles, precision=precision, scales_precision=scales_precision, + scheme=scheme, ) # TODO: merge with 8da4w? # In distributed training, the model may be instantiated @@ -699,6 +881,7 @@ def _replace_linear_int4( scales_precision, linear_class, copy_weights, + scheme=scheme, ) @@ -723,10 +906,11 @@ def __init__( inner_k_tiles: Optional[int] = 8, device: torch.device = torch.device("cuda"), precision: torch.dtype = torch.bfloat16, + scheme = None ) -> None: super().__init__() assert inner_k_tiles in [2, 4, 8] - assert groupsize in [32, 64, 128, 256] + assert groupsize in [32, 64, 128, 256] # 0 group size is allowed for channelwise scheme where groupsize = row size self.inner_k_tiles = inner_k_tiles self.groupsize: int = groupsize @@ -734,6 +918,7 @@ def __init__( self.device: torch.device = device # precision and dtype are being used interchangeably here self.precision: torch.dtype = precision + self.scheme = scheme @torch.no_grad() def _create_quantized_state_dict( @@ -742,18 +927,19 @@ def _create_quantized_state_dict( cur_state_dict = model.state_dict() for fqn, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - assert not mod.bias + if self.scheme is None: + assert not mod.bias out_features = mod.out_features in_features = mod.in_features # assert out_features % 8 == 0, "require out_features % 8 == 0" - logging.info(f"linear: {fqn}, in={in_features}, out={out_features}") - assert ( - in_features % self.groupsize == 0 - ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0" + if self.scheme is None: + assert ( + in_features % self.groupsize == 0 + ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0" weight = mod.weight.data - if not _check_linear_int4_k( + if self.scheme is None and not _check_linear_int4_k( in_features, self.groupsize, self.inner_k_tiles ): if self.padding_allowed: @@ -772,30 +958,42 @@ def _create_quantized_state_dict( + "and that groupsize and inner_k_tiles*16 evenly divide into it" ) continue - (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor( - weight, - 4, # n_bit - self.groupsize, - self.precision, # dtype for scales_and_zeros - ) - # TODO: just get the device from mod.weight.device? - if ( - is_device(w_int4x8.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): - weight_int4pack = ( - torch.ops.aten._convert_weight_to_int4pack_for_cpu( - w_int4x8.to(self.device), self.inner_k_tiles - ) + if (self.scheme == "symmetric_channelwise" or self.scheme == "symmetric_groupwise") and is_device(weight.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if self.scheme == "symmetric_channelwise": + self.groupsize = mod.in_features + ( + w_int4x8, + scales_and_zeros + ) = prepare_int4_weight_and_scales_and_zeros( + weight.to(self.precision), self.groupsize, self.inner_k_tiles, self.scheme, precision=self.precision ) + weight_int4pack = torch.ops.aten._dyn_quant_pack_4bit_weight(w_int4x8.to(self.device), scales_and_zeros, mod.bias, self.groupsize, mod.in_features, mod.out_features) else: - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - w_int4x8.to(self.device), self.inner_k_tiles + (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor( + weight, + 4, # n_bit + self.groupsize, + self.precision, # dtype for scales_and_zeros ) + # TODO: just get the device from mod.weight.device? + if ( + is_device(w_int4x8.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + weight_int4pack = ( + torch.ops.aten._convert_weight_to_int4pack_for_cpu( + w_int4x8.to(self.device), self.inner_k_tiles + ) + ) + else: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + w_int4x8.to(self.device), self.inner_k_tiles + ) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( - self.device - ) + self.device + ) + logging.info(f"quantizing linear: {fqn}, in={in_features}, out={out_features}, to scheme={self.scheme}, using blocksize={self.groupsize}") return cur_state_dict def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: @@ -807,6 +1005,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: skip_layer_func=None, precision=self.precision, scales_precision=self.precision, + scheme=self.scheme, ) return model diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 74c136ad0..401ec1d93 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -454,6 +454,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( ) + def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype) w_int4x8 = groupwise_affine_quantize_tensor_from_qparams( @@ -463,6 +464,100 @@ def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bflo return w_int4x8, scales_and_zeros +def get_group_qparams(w, n_bit=4, groupsize=128, scheme="symmetric_channelwise", precision=torch.bfloat16): + if groupsize > w.shape[-1] or scheme == "symmetric_channelwise": + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + # improved symmetric 4 bit quantization that uses bin correspondingto -8 from [-8,7] ( -2^(b-1) , 2^(b-1)-1 ) range + if scheme == "symmetric_groupwise": + to_quant_abs = to_quant.abs() + max_abs_indices = to_quant_abs.argmax(dim=1, keepdim=True) + max_val = torch.gather(to_quant, 1, max_abs_indices) + scales = max_val / -8 + zeros = torch.zeros_like(scales) + + elif scheme == "symmetric_channelwise": + to_quant_abs = to_quant.abs() + max_abs_indices = to_quant_abs.argmax(dim=1, keepdim=True) + max_val = torch.gather(to_quant, 1, max_abs_indices) + scales = max_val / -8 + zeros = torch.zeros_like(scales) + return scales.to(precision).reshape(w.shape[0], -1), zeros.to( + precision + ).reshape(w.shape[0], -1) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128, scheme="symmetric_channelwise", precision=torch.bfloat16): + assert groupsize > 1 + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + if scheme == "symmetric_groupwise": + max_int = 2**n_bit - 1 + w_int8 = ( + to_quant.div(scales) + .add(8.5) + .to(torch.int8) + .clamp(max=max_int) + ) + elif scheme == "symmetric_channelwise": + max_int = 2**n_bit - 1 + w_int8 = ( + to_quant.div(scales) + .add(8.5) + .to(torch.int8) + .clamp(max=max_int) + ) + # We pack every odd index value in upper 4 bits and every even index value in lower 4 bits + w_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8) + return w_uint8 + + +def pack_scales_and_zeros(scales, zeros, scheme="symmetric_channelwise"): + assert scales.shape == zeros.shape + scales_zeros = scales.squeeze().contiguous() + return scales_zeros + + +def group_quantize_tensor(w, n_bit=4, groupsize=128, scheme="symmetric_channelwise", precision=torch.bfloat16): + scales, zeros = get_group_qparams( + w, n_bit, groupsize, scheme=scheme, precision=precision) + w_uint8 = group_quantize_tensor_from_qparams( + w, scales, zeros, n_bit, groupsize, scheme=scheme) + scales_and_zeros = pack_scales_and_zeros(scales, zeros, scheme=scheme) + return w_uint8, scales_and_zeros + + +def prepare_int4_weight_and_scales_and_zeros(weights, groupsize, inner_k_tiles, scheme="symmetric_channelwise", precision=torch.bfloat16): + assert weights.dim() == 2 + assert groupsize > 1 + if groupsize > weights.shape[-1] or scheme == "symmetric_channelwise": + groupsize = weights.shape[-1] + assert weights.shape[-1] % groupsize == 0 + weight_int4pack, scales_and_zeros = group_quantize_tensor( + weights, n_bit=4, groupsize=groupsize, scheme=scheme, precision=precision + ) + if scheme == "symmetric_channelwise": + scales_and_zeros = scales_and_zeros.to(dtype=torch.float32) + elif scheme == "symmetric_groupwise": + scales_and_zeros = scales_and_zeros.to(dtype=torch.bfloat16) + return weight_int4pack, scales_and_zeros + + def groupwise_affine_dequantize_tensor( w_int4x8, scales_and_zeros,