diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py
index b58d4c3755a..f57c18a2109 100644
--- a/server/text_generation_server/layers/marlin.py
+++ b/server/text_generation_server/layers/marlin.py
@@ -59,7 +59,7 @@ def _get_perms() -> Tuple[List[int], List[int]]:
 _scale_perm, _scale_perm_single = _get_perms()
 
 
-def permute_scales(scales: torch.Tensor, in_features: int, group_size: int):
+def permute_scales(scales: torch.Tensor):
     out_features = scales.shape[1]
     if scales.shape[0] == 1:
         scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
@@ -95,8 +95,8 @@ def repack_gptq_for_marlin(
     g_idx: torch.Tensor,
     bits: int,
     desc_act: bool,
-    group_size: int,
-    is_sym: bool,
+    groupsize: int,
+    sym: bool,
     sharded_infeatures: bool,
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
@@ -109,12 +109,12 @@ def repack_gptq_for_marlin(
             f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
         )
 
-    if group_size not in GPTQ_MARLIN_GROUP_SIZES:
+    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
         supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
         raise RuntimeError(
-            f"Repacking GPTQ weights with group size {group_size} as Marlin is not supported, must be one of: {supported_sizes}"
+            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
         )
-    if not is_sym:
+    if not sym:
         raise RuntimeError(
             "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
         )
@@ -123,12 +123,12 @@ def repack_gptq_for_marlin(
     in_features = qweight.shape[0] * weights_per_int
     out_features = qweight.shape[1]
 
-    if in_features % group_size != 0:
+    if in_features % groupsize != 0:
         raise ValueError(
-            f"Number of input features ({in_features}) not divisible by group size ({group_size})"
+            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
         )
 
-    if desc_act and group_size != -1:
+    if desc_act and groupsize != -1:
         perm = torch.argsort(g_idx).to(torch.int)
         g_idx = g_idx[perm]
     else:
@@ -139,7 +139,7 @@ def repack_gptq_for_marlin(
         qweight, perm, in_features, out_features, bits
     )
 
-    scales = permute_scales(scales, in_features, group_size)
+    scales = permute_scales(scales)
 
     is_full_k = not (desc_act and sharded_infeatures)
 
@@ -249,11 +249,11 @@ def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
             out_features % 256 == 0
         ), f"Number of output features ({out_features}) not divisable by 256"
 
-        group_size = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
-        assert group_size in {
+        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+        assert groupsize in {
             -1,
             128,
-        }, f"Group size must be -1 or 128, was {group_size}"
+        }, f"Group size must be -1 or 128, was {groupsize}"
 
         self.register_buffer("B", weight.B)
         self.register_buffer("s", weight.s)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 5696aa6143f..45cfc073ca3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -251,8 +251,8 @@ def get_weights_col_packed(
                 g_idx=g_idx,
                 bits=gptq_params.bits,
                 desc_act=gptq_params.desc_act,
-                group_size=gptq_params.groupsize,
-                is_sym=gptq_params.sym,
+                groupsize=gptq_params.groupsize,
+                sym=gptq_params.sym,
                 sharded_infeatures=False,
             )
@@ -416,8 +416,8 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
                 g_idx=g_idx,
                 bits=gptq_params.bits,
                 desc_act=gptq_params.desc_act,
-                group_size=gptq_params.groupsize,
-                is_sym=gptq_params.sym,
+                groupsize=gptq_params.groupsize,
+                sym=gptq_params.sym,
                 sharded_infeatures=False,
             )
         else:
@@ -638,8 +638,8 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 g_idx=g_idx,
                 bits=gptq_params.bits,
                 desc_act=gptq_params.desc_act,
-                group_size=gptq_params.groupsize,
-                is_sym=gptq_params.sym,
+                groupsize=gptq_params.groupsize,
+                sym=gptq_params.sym,
                 sharded_infeatures=sharded_in_features,
             )
         else:
@@ -652,7 +652,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
         num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
 
         if num_groups == 1:
-            # The number of groups is 1 when group_size == -1. share
+            # The number of groups is 1 when groupsize == -1. share
             # scales between all shards in this case.
             s = self.get_tensor(f"{prefix}.s")
         else:
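
Reviewer note, not part of the patch: the rename aligns the keyword arguments at the `weights.py` call sites with the `GPTQParams` field names (`groupsize`, `sym`), and `permute_scales` now derives everything it needs from `scales.shape`, so its unused `in_features`/`group_size` parameters are dropped. A minimal sketch of the usage after this change; the `qweight`/`scales`/`g_idx` tensors and the `gptq_params` object are hypothetical placeholders:

```python
# Hypothetical call-site sketch after this patch; all inputs are placeholders.
weight = repack_gptq_for_marlin(
    qweight=qweight,  # packed int32 GPTQ weights, (in_features // weights_per_int, out_features)
    scales=scales,    # per-group scales, (in_features // groupsize, out_features)
    g_idx=g_idx,
    bits=gptq_params.bits,
    desc_act=gptq_params.desc_act,
    groupsize=gptq_params.groupsize,  # was: group_size=gptq_params.groupsize
    sym=gptq_params.sym,              # was: is_sym=gptq_params.sym
    sharded_infeatures=False,
)

# permute_scales infers out_features and the number of groups from
# scales.shape, so the extra arguments were redundant:
scales = permute_scales(scales)  # was: permute_scales(scales, in_features, group_size)
```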