From 44286bb9b3ed6d9c48a520ef33fd1a7ae2797a3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Wed, 26 Jun 2024 09:18:31 +0200 Subject: [PATCH] Use symmetric quantization in the `quantize` subcommand Packing of asymmetric quantization is broken, all (q)zeros values of `0` get reset to `1`, resulting in a loss of accuracy. So instead use symmetric quantization. To be able to distinguish models with symmetric and asymmetric quantization, a new config tensor `gptq_sym` is added. If this tensor is not present, we assume `sym=False`. --- server/text_generation_server/cli.py | 1 + .../text_generation_server/layers/gptq/__init__.py | 12 ++++++++---- .../text_generation_server/layers/gptq/quantize.py | 3 +++ server/text_generation_server/utils/weights.py | 7 +++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd7e0..71ad18f7920 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -353,6 +353,7 @@ def quantize( upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index efcb3118f22..aaa7a68a0b1 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -393,11 +393,15 @@ def get_weights_row(self, weights: Weights, prefix: str): ) def _get_gptq_params(self, weights: Weights): - try: + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False - self.sym = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) self.quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - pass diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index c65d5e78d98..0271d913d7b 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -871,6 +871,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -946,6 +947,7 @@ def _unload(): percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -957,6 +959,7 @@ def _unload(): state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 1a62fb3bdd1..50a9167a091 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -146,6 +146,13 @@ def _get_slice(self, tensor_name: str): slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape()