From c35f39cf8361b4e227654483fc808135fe315d0b Mon Sep 17 00:00:00 2001
From: Abhinav M Kulkarni
Date: Mon, 25 Sep 2023 13:28:02 +0530
Subject: [PATCH] Add AWQ quantization inference support (#1019)

# Add AWQ quantization inference support

Fixes https://github.com/huggingface/text-generation-inference/issues/781

This PR (partially) adds support for AWQ quantization for inference. More information on AWQ is available [here](https://arxiv.org/abs/2306.00978). In general, AWQ is faster and more accurate than GPTQ, which is currently supported by TGI. This PR installs the 4-bit GEMM custom CUDA kernels released by the AWQ authors (in `requirements.txt`, just a one-line change).

A quick way to test this PR is to bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
    --huggingface-hub-cache ~/.cache/huggingface/hub/ \
    --model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
    --trust-remote-code --port 8080 \
    --max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
    --quantize awq
```

Please note:

* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not for quantizing models. That needs to be done outside of TGI; instructions are [here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far; I will add them later if maintainers are interested in this change.
* This PR can be tested on any of the models released [here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).

Please refer to the linked issue for benchmarks of [abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq) vs [TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note: the AWQ authors have released faster (and, in the case of Llama, fused) kernels for 4-bit GEMM, currently at the top of the `main` branch at https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit that has been tested to work. We can switch to the latest commit later on.

## Who can review?

@OlivierDehaene OR @Narsil

---------

Co-authored-by: Abhinav Kulkarni
---
 .../source/basic_tutorials/preparing_model.md |  2 +-
 launcher/src/main.rs                          |  4 ++
 server/requirements.txt                       |  2 +
 server/text_generation_server/cli.py          |  1 +
 .../text_generation_server/models/__init__.py |  4 ++
 .../models/flash_llama.py                     |  2 +-
 .../utils/awq/quantize/qmodule.py             | 53 +++++++++++++++++
 server/text_generation_server/utils/layers.py | 13 +++-
 .../text_generation_server/utils/weights.py   | 59 ++++++++++++++-----
 9 files changed, 122 insertions(+), 18 deletions(-)
 create mode 100644 server/text_generation_server/utils/awq/quantize/qmodule.py

diff --git a/docs/source/basic_tutorials/preparing_model.md b/docs/source/basic_tutorials/preparing_model.md
index 6b622d99f90..0f5739ea854 100644
--- a/docs/source/basic_tutorials/preparing_model.md
+++ b/docs/source/basic_tutorials/preparing_model.md
@@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.
 
 ## Quantization
 
-TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes` or `gptq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). To get more information about quantization, please refer to (./conceptual/quantization.md)
+TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq); when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to (./conceptual/quantization.md)
 
 ## RoPE Scaling
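For reference (not part of the patch): once the launcher command above is running, a single request against TGI's standard `/generate` endpoint is enough to confirm the AWQ path works end to end. The prompt, port and use of the `requests` package below are illustrative assumptions.

```python
# Illustrative smoke test against the server started with --quantize awq.
# Assumes the launcher from the PR description is listening on port 8080
# and that the `requests` package is installed.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "def fibonacci(n):", "parameters": {"max_new_tokens": 64}},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```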
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index cbb6f25d59c..09e32f8944b 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -25,6 +25,7 @@ enum Quantization {
     BitsandbytesNF4,
     BitsandbytesFP4,
     Gptq,
+    Awq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Awq => {
+                write!(f, "awq")
+            }
         }
     }
 }
diff --git a/server/requirements.txt b/server/requirements.txt
index 1b038ccac22..ac3ac9fad74 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+# Custom 4-bit GEMM AWQ kernels
+git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e3fda07f533..e0b8c0fec5b 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_nf4 = "bitsandbytes-nf4"
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
+    awq = "awq"
 
 
 class Dtype(str, Enum):
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 29b049cfa26..0d96d43b28f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -268,6 +268,10 @@ def get_model(
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
+    if quantize == "awq":
+        raise ValueError(
+            "awq quantization is not supported for AutoModel"
+        )
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError(
             "4bit quantization is not supported for AutoModel"
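As an aside (not part of the patch): after installing the pinned kernel package added to `server/requirements.txt`, a quick import check confirms the extension built correctly. `awq_inference_engine` is the module name imported by `qmodule.py` below; the rest of this snippet is illustrative.

```python
# Illustrative check that the AWQ CUDA extension from server/requirements.txt
# is importable and that a CUDA device is visible, since WQLinear.forward
# calls awq_inference_engine.gemm_forward_cuda.
import torch

try:
    import awq_inference_engine
except ImportError as exc:
    raise SystemExit(f"AWQ GEMM kernels are not installed: {exc}")

if not torch.cuda.is_available():
    raise SystemExit("A CUDA device is required to run the AWQ GEMM kernels.")

print("awq_inference_engine loaded; device:", torch.cuda.get_device_name(0))
```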
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 063aa01e1af..d2ed0b15a75 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -62,7 +62,7 @@ def __init__(
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
diff --git a/server/text_generation_server/utils/awq/quantize/qmodule.py b/server/text_generation_server/utils/awq/quantize/qmodule.py
new file mode 100644
index 00000000000..fb1adf5c48b
--- /dev/null
+++ b/server/text_generation_server/utils/awq/quantize/qmodule.py
@@ -0,0 +1,53 @@
+# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
+
+import math
+import torch
+import torch.nn as nn
+import awq_inference_engine  # with CUDA kernels
+
+
+class ScaledActivation(nn.Module):
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
+class WQLinear(nn.Module):
+    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit are supported for now.")
+
+        self.in_features = qweight.shape[0]
+        self.out_features = qweight.shape[1] * 32 // w_bit
+
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else self.in_features
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert self.out_features % (32 // self.w_bit) == 0
+
+        self.register_buffer('qweight', qweight)
+        self.register_buffer('qzeros', qzeros)
+        self.register_buffer('scales', scales)
+        if bias:
+            self.register_buffer('bias', bias)
+        else:
+            self.bias = None
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features, )
+        out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
+            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
+        )
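A short sketch (not from this PR) of the packed tensor shapes `WQLinear` expects, assuming the usual AWQ GEMM layout of eight 4-bit values per int32 packed along the output dimension; the sizes are made up, and importing `qmodule` requires the `awq_inference_engine` extension from `requirements.txt`.

```python
# Illustrative shapes for a 4-bit, group_size=128 WQLinear (sizes are made up).
import torch
from text_generation_server.utils.awq.quantize.qmodule import WQLinear

in_features, out_features, w_bit, group_size = 4096, 4096, 4, 128
pack = 32 // w_bit  # eight 4-bit values per int32

qweight = torch.zeros(in_features, out_features // pack, dtype=torch.int32)
qzeros = torch.zeros(in_features // group_size, out_features // pack, dtype=torch.int32)
scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)

layer = WQLinear(w_bit=w_bit, group_size=group_size,
                 qweight=qweight, qzeros=qzeros, scales=scales, bias=None)
assert (layer.in_features, layer.out_features) == (in_features, out_features)
# layer(x) dispatches to awq_inference_engine.gemm_forward_cuda, so the buffers
# (and x) must live on a CUDA device before running a forward pass.
```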
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c1c36194650..cfec58597f9 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -17,6 +17,7 @@
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -248,6 +249,14 @@
             bits,
             groupsize,
         )
+    elif quantize == "awq":
+        try:
+            qweight, qzeros, scales, _, bits, groupsize, _ = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `awq` compatible, loader needs to be updated."
+            )
+        linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -283,8 +292,8 @@ def load(config, prefix: str, weights):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        # GPTQ and AWQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq"]:
             quantize = None
         else:
             quantize = config.quantize
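For context (not part of the patch): the new `awq` branch of `get_linear` consumes the same 7-tuple convention the GPTQ path uses, simply ignoring the `g_idx` and `use_exllama` slots. A hedged sketch of that convention with dummy tensors and made-up sizes (it assumes the TGI server environment, including the AWQ kernels, is installed):

```python
# Sketch of the (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
# tuple built by weights.py and unpacked by get_linear; sizes are illustrative.
import torch
from text_generation_server.utils.layers import get_linear

bits, groupsize, in_features, out_features = 4, 128, 4096, 4096
qweight = torch.zeros(in_features, out_features // 8, dtype=torch.int32)
qzeros = torch.zeros(in_features // groupsize, out_features // 8, dtype=torch.int32)
scales = torch.zeros(in_features // groupsize, out_features, dtype=torch.float16)

# g_idx and use_exllama are GPTQ-specific; the awq branch discards them.
weight = (qweight, qzeros, scales, None, bits, groupsize, None)
linear = get_linear(weight, bias=None, quantize="awq")  # returns a WQLinear
```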
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 2ef7ad39f43..fdeabbe6473 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -135,18 +135,26 @@ def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
         """
-        if quantize == "gptq":
+        if quantize in ["gptq", "awq"]:
             try:
                 qweight = self._get_qweight(f"{prefix}.qweight")
             except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
+                if quantize == "gptq":
+                    raise RuntimeError(
+                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot load `awq` weight, make sure the model is already quantized"
+                    )
 
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            g_idx = self.get_tensor(f"{prefix}.g_idx")
+            try:
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+            except RuntimeError:
+                g_idx = None
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
@@ -171,15 +179,20 @@ def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
         return weight
 
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize == "gptq":
+        if quantize in ["gptq", "awq"]:
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                 )
             except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
+                if quantize == "gptq":
+                    raise RuntimeError(
+                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot load `awq` weight, make sure the model is already quantized"
+                    )
 
             qzeros = torch.cat(
                 [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -187,10 +200,14 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
             scales = torch.cat(
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
-            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
+
+            try:
+                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
+            except RuntimeError:
+                g_idx = None
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
@@ -216,7 +233,7 @@ def get_tensor_shard(self, var, dim):
         return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize in "gptq":
             use_exllama = True
             bits, groupsize = self._get_gptq_params()
 
@@ -282,6 +299,20 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "awq":
+            bits, groupsize = self._get_gptq_params()
+
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+
+            weight = (qweight, qzeros, scales, None, bits, groupsize, None)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
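One more aside (not from the patch), on why the column-wise loaders above can concatenate AWQ tensors with `torch.cat(..., dim=1)`: assuming the usual AWQ GEMM packing, `qweight` stores eight 4-bit values per int32 along the output dimension, so the output columns of Q, K and V can be stacked without repacking, and `qzeros`/`scales` follow the same rule. A tiny pure-PyTorch illustration with made-up sizes:

```python
# Illustration (made-up sizes): concatenating per-projection AWQ tensors along
# the output dimension yields the packed tensors of a fused QKV projection.
import torch

in_features, proj_out, group_size = 4096, 4096, 128
q, k, v = (torch.zeros(in_features, proj_out // 8, dtype=torch.int32) for _ in range(3))

qkv_qweight = torch.cat([q, k, v], dim=1)  # mirrors get_multi_weights_col
assert qkv_qweight.shape == (in_features, 3 * proj_out // 8)

scales = torch.zeros(in_features // group_size, proj_out, dtype=torch.float16)
qkv_scales = torch.cat([scales, scales.clone(), scales.clone()], dim=1)
assert qkv_scales.shape == (in_features // group_size, 3 * proj_out)
```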