diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 8f88b1f80a1..3089f70429b 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -1,13 +1,47 @@ import pytest import torch -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.marlin import MarlinWeight +from text_generation_server.utils.weights import ( + _DefaultWeightsLoader, + Weights, + WeightsLoader, +) +from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader +from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader +from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path + +@pytest.fixture +def gptq_weights_loader(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="gptq", + quantize="gptq", + sym=True, + ) + + +@pytest.fixture +def gptq_weights_loader_awq(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="awq", + quantize="awq", + sym=True, + ) + + +@pytest.fixture +def marlin_weights_loader(): + return MarlinWeightsLoader(bits=4, is_marlin_24=False) + + dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -308,6 +342,7 @@ def __init__( dummy_fs, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, ): routing = {} self.dummy_fs = dummy_fs @@ -327,6 +362,9 @@ def __init__( self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = ( + _DefaultWeightsLoader() if weights_loader is None else weights_loader + ) self._handles = {} def _get_handle(self, filename: Union[Path, str]): @@ -412,12 +450,10 @@ def test_get_weights_col_packed(): ) prefix = "weight" - quantize = None block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size(): ) prefix = "weight" - quantize = None block_sizes = 2 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr(): ) prefix = "weight" - quantize = None block_sizes = [1, 1] w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -519,11 +551,9 @@ def test_get_multi_weights_col(): ) prefixes = ["weight", "weight"] - quantize = None w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -557,11 +587,9 @@ def test_get_multi_weights_row(): ) prefix = "weight" - quantize = None w = weights.get_multi_weights_row( prefix=prefix, - quantize=quantize, ) assert torch.allclose( @@ -576,7 +604,7 @@ def test_get_multi_weights_row(): # test_get_weights_col -def test_get_weights_col_awq(): +def test_get_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -585,14 +613,13 @@ def test_get_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = 
GPTQWeight( @@ -617,7 +644,7 @@ def test_get_weights_col_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_gtpq(): +def test_get_weights_col_gtpq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -626,14 +653,13 @@ def test_get_weights_col_gtpq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -664,14 +690,13 @@ def test_get_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) scaled_scale_max = 0.3906 * 256 @@ -692,7 +717,7 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(): +def test_get_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_marlin", @@ -701,14 +726,13 @@ def test_get_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( @@ -723,7 +747,7 @@ def test_get_weights_col_marlin(): # test_get_weights_col_packed -def test_get_weights_col_packed_awq(): +def test_get_weights_col_packed_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_packed_gptq(): +def test_get_weights_col_packed_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(): +def test_get_weights_col_packed_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_marlin", @@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, 
) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin(): # test_get_multi_weights_col -def test_get_multi_weights_col_awq(): +def test_get_multi_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefixes = ["weight"] - quantize = "awq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" try: w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) except ValueError as e: assert e.args[0] == "get_multi_weights_col is not supported for exl2" -def test_get_multi_weights_col_gptq(): +def test_get_multi_weights_col_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(): +def test_get_multi_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_marlin", @@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -1010,7 +1026,7 @@ def test_get_multi_weights_col_marlin(): # test_get_multi_weights_row -def test_get_multi_weights_row_awq(): +def test_get_multi_weights_row_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_multi_weights_row_gptq", @@ -1019,14 +1035,13 @@ def test_get_multi_weights_row_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" w = weights.get_multi_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1057,14 +1072,13 @@ def test_get_multi_weights_row_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" w = weights.get_multi_weights_row( prefix=prefix, - quantize=quantize, ) print(w) @@ -1086,7 +1100,7 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_multi_weights_row_gptq(): +def test_get_multi_weights_row_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_multi_weights_row_gptq", @@ -1095,14 +1109,13 @@ def test_get_multi_weights_row_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + 
weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" w = weights.get_multi_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1124,7 +1137,7 @@ assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_marlin(): +def test_get_multi_weights_row_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_multi_weights_row_marlin", ], device="cpu", dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index f6cb729ed6a..3c7a034c7a3 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,6 +1,9 @@ import torch +from typing import List, Union from dataclasses import dataclass +from text_generation_server.utils.weights import WeightsLoader, Weights + @dataclass class Exl2Weight: @@ -21,3 +24,58 @@ def __post_init__(self): @property def device(self) -> torch.device: return self.q_weight.device + + +class Exl2WeightsLoader(WeightsLoader): + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl2") + + def get_weights_col(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + raise ValueError("get_multi_weights_col is not supported for exl2") + + def get_multi_weights_row(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+ ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 56080145028..0e05915ae4c 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,20 +1,14 @@ from dataclasses import dataclass +from loguru import logger import os -from typing import Optional +from typing import List, Optional, Union +from safetensors import SafetensorError +from text_generation_server.utils.weights import Weights, WeightsLoader import torch from text_generation_server.utils.import_utils import ( SYSTEM, ) - - -@dataclass -class GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool +from text_generation_server.utils.log import log_once @dataclass @@ -69,3 +63,340 @@ def device(self) -> torch.device: pass from text_generation_server.layers.gptq.quant_linear import QuantLinear + + +class GPTQWeightsLoader(WeightsLoader): + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=False, + ) + + def get_weights_col(self, weights: Weights, prefix: str): + return self.get_multi_weights_col(weights, [prefix], dim=0) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + use_exllama = ( + self.bits == 4 + and HAS_EXLLAMA + and self.quantize == "gptq" + and not self.desc_act + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def get_multi_weights_row(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + + if weights.process_group.size() > 1: + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // self.groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + if use_exllama and self.groupsize != -1: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, 
"Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def _get_gptq_params(self, weights: Weights): + try: + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + self.sym = False + self.quant_method = "gptq" + except (SafetensorError, RuntimeError) as e: + pass diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a1af67a3f5f..71ad4309f3f 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union +from text_generation_server.utils.weights import Weights, WeightsLoader import torch import torch.nn as nn -from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.import_utils import SYSTEM try: @@ -24,16 +24,133 @@ MARLIN_TILE_SIZE = 16 -def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: +class MarlinWeightsLoader(WeightsLoader): + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_col(self, weights: Weights, prefix: str): + return self.get_multi_weights_col(weights, [prefix], dim=0) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in 
prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_row(self, weights: Weights, prefix: str): + is_marlin_24 = self.is_marlin_24 + if is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: return ( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 and quantize == "gptq" - and gptq_params.quant_method == "gptq" - and gptq_params.bits in GPTQ_MARLIN_BITS - and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES - and gptq_params.sym + and quant_method == "gptq" + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + and sym ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 038de25815a..d0809af086a 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -52,7 +52,7 @@ def load(config, prefix: str, weights): weight = weights.get_tensor(f"{prefix}.weight") except: # ...otherwise they are quantized. 
- weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 elif weights.process_group.size() > 1: try: @@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_gate_up(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_gate_up( - prefix, quantize=config.quantize - ) + weight = weights.get_weights_col_packed_gate_up(prefix) if bias: raise NotImplementedError("packed_gate_up only implemented without bias") else: @@ -152,7 +150,6 @@ def load_qkv( """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv( prefix, - quantize=config.quantize, num_heads=num_heads, num_key_value_heads=num_key_value_heads, ) @@ -165,7 +162,7 @@ def load_qkv( @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: @@ -178,14 +175,12 @@ def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): if config.quantize == "exl2": linears = [] for prefix in prefixes: - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None linears.append(get_linear(weight, b, config.quantize)) linear = LayerConcat(linears) else: - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) + weight = weights.get_multi_weights_col(prefixes, dim=dim) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = torch.cat(b, dim=dim) @@ -202,7 +197,7 @@ def __init__(self, linear, process_group): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index f993fe72094..25719b999dc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index beff08b3080..a3ce55213ff 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 14b62b00b0b..34a7efa2550 
100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d5dc25cff1a..2353bb2d866 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights): # Weights weight = weights.get_weights_col_packed_qkv( f"{prefix}.c_attn", - config.quantize, config.num_attention_heads, config.num_attention_heads, ) @@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool): """load_row, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix) else: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T @@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool): """load_col, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=1 - ) + weight = weights.get_multi_weights_col([prefix], dim=1) else: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 429793ea575..49c0e9030fc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 0eca181b67b..b6b55ba526a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -48,7 +48,7 @@ def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): - weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) + weight = weights.get_multi_weights_col([prefix], dim=0) if isinstance(weight, torch.Tensor): # Only on non quantized versions weight = ( diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 
7401bc27a7f..6c508264293 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 4813e2df988..32d97788916 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -23,7 +23,7 @@ def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 21a22046c9e..de40d5288e8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -81,11 +82,13 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - gptq_params = weights._get_gptq_params() - if gptq_params.quant_method == "gptq": + loader = weights.weights_loader + assert isinstance(loader, GPTQWeightsLoader) + loader._get_gptq_params(weights) + if loader.quant_method == "gptq": g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - elif gptq_params.quant_method == "awq": + elif loader.quant_method == "awq": g_idx = None from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, @@ -100,8 +103,8 @@ def _load_multi_mqa_gptq( qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, + bits=loader.bits, + groupsize=loader.groupsize, use_exllama=HAS_EXLLAMA, ) @@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T else: - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=0 - ) + weight = weights.get_multi_weights_col([prefix], dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) @@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T else: - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bf1fda4a602..4375d49471d 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -50,6 +50,7 
@@ from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.quantization import get_quantizer_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( @@ -881,12 +882,16 @@ def __init__( torch.distributed.barrier(group=self.process_group) + weights_loader = get_quantizer_loader(quantize, model_id, revision) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device, dtype, process_group=self.process_group, aliases=aliases + filenames, + device, + dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py new file mode 100644 index 00000000000..90a3d34e24e --- /dev/null +++ b/server/text_generation_server/utils/quantization.py @@ -0,0 +1,119 @@ +from typing import Optional +import os +import json +from dataclasses import dataclass + +from huggingface_hub import hf_hub_download + +from text_generation_server.utils.weights import WeightsLoader + + +@dataclass +class _QuantizerConfig: + bits: int + checkpoint_format: Optional[str] + desc_act: bool + groupsize: int + quant_method: str + sym: bool + + +# We should probably do this with Pydantic JSON deserialization, +# but for now we'll stay close to the old _set_gptq_params. 
+def _get_quantizer_config(model_id, revision): + bits = 4 + groupsize = -1 + quant_method = "gptq" + checkpoint_format = None + sym = True + desc_act = False + + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + data = json.load(f) + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] + # Order is important here, desc_act is missing on some real models + quant_method = data["quantization_config"]["quant_method"] + checkpoint_format = data["quantization_config"].get("checkpoint_format") + sym = data["quantization_config"]["sym"] + desc_act = data["quantization_config"]["desc_act"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["bits"] + groupsize = data["group_size"] + sym = data["sym"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + filename = "quant_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["w_bit"] + groupsize = data["q_group_size"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + pass + + return _QuantizerConfig( + bits=bits, + groupsize=groupsize, + quant_method=quant_method, + checkpoint_format=checkpoint_format, + sym=sym, + desc_act=desc_act, + ) + + +def get_quantizer_loader( + quantize: Optional[str], model_id: str, revision: Optional[str] +) -> Optional[WeightsLoader]: + quantizer_config = _get_quantizer_config(model_id, revision) + if quantize in {"awq", "gptq"}: + from text_generation_server.layers.gptq import GPTQWeightsLoader + + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + elif quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2WeightsLoader + + return Exl2WeightsLoader() + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeightsLoader + + return MarlinWeightsLoader( + bits=quantizer_config.bits, + is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", + ) + else: + return None diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 3731fd249f7..9bd16e42d9c 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,13 +1,66 @@ -import os +from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional, Union -from safetensors import safe_open, SafetensorError +from safetensors import safe_open import torch -from loguru import logger -from huggingface_hub import hf_hub_download -import json -from text_generation_server.layers.gptq import GPTQParams -from 
text_generation_server.utils.log import log_once + + +class WeightsLoader(ABC): + """ + Instances of this type implement higher-level weight loading. + + At a low level, every weight is stored in the Safetensors format. + The interpretation of weights may differ, however; for instance, they + could be packed or quantized weights. Loaders are responsible for + interpreting the raw tensors, sharding tensors in a manner compatible + with the format, etc. + """ + + @abstractmethod + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): ... + + @abstractmethod + def get_weights_col(self, weights: "Weights", prefix: str): ... + + @abstractmethod + def get_multi_weights_col( + self, weights: "Weights", prefixes: List[str], dim: int + ): ... + + @abstractmethod + def get_multi_weights_row(self, weights: "Weights", prefix: str): ... + + +class _DefaultWeightsLoader(WeightsLoader): + """ + This loader uses tensors as-is with the exception of applying sharding + and/or concatenation. + """ + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + return weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + def get_weights_col(self, weights: "Weights", prefix: str): + return weights.get_multi_weights_col([prefix], 0) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + return torch.cat(w, dim=dim) + + def get_multi_weights_row(self, weights: "Weights", prefix: str): + return weights.get_sharded(f"{prefix}.weight", dim=1) class Weights: @@ -19,6 +72,7 @@ def __init__( process_group, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, ): routing = {} for filename in filenames: @@ -37,6 +91,9 @@ def __init__( self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = ( + _DefaultWeightsLoader() if weights_loader is None else weights_loader + ) self._handles = {} def _get_handle(self, filename): @@ -181,18 +238,14 @@ def get_weights_col_packed_qkv( num_key_value_heads: int, ): return self.get_weights_col_packed( - prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + prefix, [num_heads, num_key_value_heads, num_key_value_heads] ) - def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 2) + def get_weights_col_packed_gate_up(self, prefix: str): - return self.get_weights_col_packed(prefix, 2) + return self.get_weights_col_packed(prefix, 2) - def get_weights_col_packed( - self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] - ): + def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]): """ - Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor. The columns are split in equally sized blocks when blocks is an `int`, or in blocks proportional given to the sizes. For instance `[2, 1, 1]` will divide an input with dimensionality `1024` in `[512, 256, 256]`. This is convenient for e.g. splitting QKV without knowing the storage details of quantized weights. 
""" - if quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized." - ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - g_idx = self.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = self.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - if quantize == "gptq" and gptq_params.quant_method == "gptq": - g_idx = self.get_tensor(f"{prefix}.g_idx") - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=False, - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - repack_gptq_for_marlin, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - B = self.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = self.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - B = self.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - return weight - - def get_weights_col(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." 
- ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - return self.get_multi_weights_col([prefix], quantize, 0) - - def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize == "exl2": - raise ValueError("get_multi_weights_col is not supported for exl2") - elif quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - gptq_params.bits == 4 - and HAS_EXLLAMA - and quantize == "gptq" - and not gptq_params.desc_act - ) - - if quantize == "gptq" and gptq_params.quant_method == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
- ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) + return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes) - else: - w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=dim) + def get_weights_col(self, prefix: str): + return self.weights_loader.get_weights_col(self, prefix) - return weight + def get_multi_weights_col(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights_col(self, prefixes, dim) def get_tensor_shard(self, var, dim): world_size = self.process_group.size() @@ -487,318 +277,8 @@ def get_tensor_shard(self, var, dim): tensor = tensor.to(device=self.device) return tensor - def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." 
- ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - elif quantize == "gptq": - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) - - use_exllama = True - if gptq_params.bits != 4: - use_exllama = False - - if gptq_params.desc_act: - log_once(logger.warning, "Disabling exllama because desc_act=True") - use_exllama = False - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - - if gptq_params.quant_method == "gptq": - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - elif gptq_params.quant_method == "awq": - g_idx = None - - if self.process_group.size() > 1: - if g_idx is not None: - if ( - not torch.equal( - g_idx.cpu(), - torch.tensor( - [ - i // gptq_params.groupsize - for i in range(g_idx.shape[0]) - ], - dtype=torch.int32, - ), - ) - and not (g_idx == 0).all() - ): - # Exllama implementation does not support row tensor parallelism with act-order, as - # it would require to reorder input activations that are split unto several GPUs - use_exllama = False - - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and gptq_params.groupsize != -1: - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - - if gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
-                )
-                from text_generation_server.layers.awq.conversion_utils import (
-                    fast_awq_to_gptq,
-                )
-
-                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                if use_exllama:
-                    g_idx = None
-                else:
-                    g_idx = (
-                        torch.arange(
-                            qweight.shape[0] * (32 // gptq_params.bits),
-                            device=qweight.device,
-                        )
-                        // gptq_params.groupsize
-                    ).to(dtype=torch.int32)
-
-            weight = GPTQWeight(
-                qweight=qweight,
-                qzeros=qzeros,
-                scales=scales,
-                g_idx=g_idx,
-                bits=gptq_params.bits,
-                groupsize=gptq_params.groupsize,
-                use_exllama=use_exllama,
-            )
-        elif quantize == "awq":
-            from text_generation_server.layers.gptq import GPTQWeight
-
-            gptq_params = self._get_gptq_params()
-
-            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
-
-            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-            scales = self.get_sharded(f"{prefix}.scales", dim=0)
-            g_idx = None
-            use_exllama = False
-
-            weight = GPTQWeight(
-                qweight=qweight,
-                qzeros=qzeros,
-                scales=scales,
-                g_idx=g_idx,
-                bits=gptq_params.bits,
-                groupsize=gptq_params.groupsize,
-                use_exllama=use_exllama,
-            )
-        elif quantize == "marlin":
-            from text_generation_server.layers.gptq import GPTQWeight
-            from text_generation_server.layers.marlin import (
-                GPTQMarlin24Weight,
-                MarlinWeight,
-            )
-
-            is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
-            if is_marlin_24:
-                try:
-                    B = self.get_sharded(f"{prefix}.B_24", dim=0)
-                except RuntimeError:
-                    raise RuntimeError(
-                        "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
-                    )
-
-                B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
-                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
-                if num_groups == 1:
-                    # The number of groups is 1 when groupsize == -1. share
-                    # scales between all shards in this case.
-                    s = self.get_tensor(f"{prefix}.s")
-                else:
-                    s = self.get_sharded(f"{prefix}.s", dim=0)
-
-                gptq_params = self._get_gptq_params()
-                weight = GPTQMarlin24Weight(
-                    B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
-                )
-            else:
-                try:
-                    B = self.get_sharded(f"{prefix}.B", dim=0)
-                except RuntimeError:
-                    raise RuntimeError(
-                        "Cannot load `marlin` weight, make sure the model is already quantized."
-                    )
-
-                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
-                if num_groups == 1:
-                    # The number of groups is 1 when groupsize == -1. share
-                    # scales between all shards in this case.
-                    s = self.get_tensor(f"{prefix}.s")
-                else:
-                    s = self.get_sharded(f"{prefix}.s", dim=0)
-                weight = MarlinWeight(B=B, s=s)
-        else:
-            weight = self.get_sharded(f"{prefix}.weight", dim=1)
-        return weight
-
-    def _get_gptq_params(self) -> GPTQParams:
-        try:
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
-            checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
-            desc_act = False
-            sym = False
-            quant_method = "gptq"
-        except (SafetensorError, RuntimeError) as e:
-            try:
-                bits = self.gptq_bits
-                groupsize = self.gptq_groupsize
-                checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
-                desc_act = getattr(self, "gptq_desc_act", False)
-                quant_method = getattr(self, "quant_method", "gptq")
-                sym = getattr(self, "sym", True)
-            except Exception:
-                raise e
-
-        return GPTQParams(
-            bits=bits,
-            checkpoint_format=checkpoint_format,
-            desc_act=desc_act,
-            groupsize=groupsize,
-            quant_method=quant_method,
-            sym=sym,
-        )
-
-    def _set_gptq_params(self, model_id, revision):
-        filename = "config.json"
-        try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
-            self.gptq_bits = data["quantization_config"]["bits"]
-            self.gptq_groupsize = data["quantization_config"]["group_size"]
-            # Order is important here, desc_act is missing on some real models
-            self.quant_method = data["quantization_config"]["quant_method"]
-            self.gptq_checkpoint_format = data["quantization_config"].get(
-                "checkpoint_format"
-            )
-            self.gptq_sym = data["quantization_config"]["sym"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
-        except Exception:
-            filename = "quantize_config.json"
-            try:
-                if os.path.exists(os.path.join(model_id, filename)):
-                    filename = os.path.join(model_id, filename)
-                else:
-                    filename = hf_hub_download(
-                        model_id, filename=filename, revision=revision
-                    )
-                with open(filename, "r") as f:
-                    data = json.load(f)
-                self.gptq_bits = data["bits"]
-                self.gptq_groupsize = data["group_size"]
-                self.gptq_sym = data["sym"]
-                self.gptq_desc_act = data["desc_act"]
-                if "version" in data and data["version"] == "GEMM":
-                    self.quant_method = "awq"
-            except Exception:
-                filename = "quant_config.json"
-                try:
-                    if os.path.exists(os.path.join(model_id, filename)):
-                        filename = os.path.join(model_id, filename)
-                    else:
-                        filename = hf_hub_download(
-                            model_id, filename=filename, revision=revision
-                        )
-                    with open(filename, "r") as f:
-                        data = json.load(f)
-                    self.gptq_bits = data["w_bit"]
-                    self.gptq_groupsize = data["q_group_size"]
-                    self.gptq_desc_act = data["desc_act"]
-                    if "version" in data and data["version"] == "GEMM":
-                        self.quant_method = "awq"
-                except Exception:
-                    pass
+    def get_multi_weights_row(self, prefix: str):
+        return self.weights_loader.get_multi_weights_row(self, prefix)


 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
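The net effect of these hunks: `Weights` no longer branches on a `quantize` string; every accessor forwards to whichever `WeightsLoader` was injected at construction time, and the GPTQ/AWQ/exl2/Marlin logic deleted here lives on in the corresponding loader classes. A minimal sketch of the dispatch pattern, assuming only the four hook signatures visible in this diff (`WeightsLoaderSketch` and `DefaultLoaderSketch` are illustrative stand-ins, not the real `WeightsLoader`/`_DefaultWeightsLoader` implementations):

    from typing import List

    import torch


    class WeightsLoaderSketch:
        # One implementation per quantization format (gptq, awq, exl2, marlin).
        # Each hook receives the Weights object so it can fetch whatever tensors
        # its format needs (qweight/qzeros/scales, B/s, q_weight/q_scale, ...).
        def get_weights_col_packed(self, weights, prefix: str, block_sizes):
            raise NotImplementedError

        def get_weights_col(self, weights, prefix: str):
            raise NotImplementedError

        def get_multi_weights_col(self, weights, prefixes: List[str], dim: int):
            raise NotImplementedError

        def get_multi_weights_row(self, weights, prefix: str):
            raise NotImplementedError


    class DefaultLoaderSketch(WeightsLoaderSketch):
        # Unquantized path: these bodies mirror the plain-tensor `else` branches
        # that this diff deletes from Weights.
        def get_weights_col(self, weights, prefix: str):
            return weights.get_sharded(f"{prefix}.weight", dim=0)

        def get_multi_weights_col(self, weights, prefixes: List[str], dim: int):
            w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            return torch.cat(w, dim=dim)

        def get_multi_weights_row(self, weights, prefix: str):
            return weights.get_sharded(f"{prefix}.weight", dim=1)

With loaders injected once at construction, supporting a new quantization scheme means writing one loader class rather than threading another `quantize == "..."` branch through every accessor.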