Use symmetric quantization in the quantize subcommand

Packing of asymmetric quantization is broken: all (q)zeros values
of `0` get reset to `1`, resulting in a loss of accuracy. So use
symmetric quantization instead. To distinguish models with symmetric
and asymmetric quantization, a new config tensor `gptq_sym` is added.
If this tensor is not present, we assume `sym=False`.
danieldk committed Jul 12, 2024
1 parent d789de3 commit 44286bb
Showing 4 changed files with 19 additions and 4 deletions.
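
As background (not part of the commit), a minimal numeric sketch of why a clobbered zero point hurts: GPTQ-style kernels dequantize roughly as `w = scale * (q - zero)`, so a stored zero of `0` that comes back as `1` after packing shifts every weight in the affected group by one full quantization step. Symmetric quantization sidesteps this because the zero point is pinned to the midpoint of the integer range and never needs to be packed per group.

    import numpy as np

    # Illustrative 4-bit group whose asymmetric zero point is 0.
    scale = 0.1
    q = np.array([0, 3, 7, 15])  # quantized values in [0, 15]

    w_intended = scale * (q - 0)    # what the quantizer meant
    w_unpacked = scale * (q - 1)    # zero of 0 clobbered to 1 during packing
    print(w_unpacked - w_intended)  # uniform bias of -scale across the group

    # Symmetric quantization fixes the zero point at the midpoint,
    # 2 ** (bits - 1) == 8 for 4 bits, so there is nothing to mis-pack.
    w_symmetric = scale * (q - 8)
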
1 change: 1 addition & 0 deletions server/text_generation_server/cli.py
@@ -353,6 +353,7 @@ def quantize(
         upload_to_model_id=upload_to_model_id,
         percdamp=percdamp,
         act_order=act_order,
+        sym=True,
     )


12 changes: 8 additions & 4 deletions server/text_generation_server/layers/gptq/__init__.py
@@ -393,11 +393,15 @@ def get_weights_row(self, weights: Weights, prefix: str):
         )

     def _get_gptq_params(self, weights: Weights):
-        try:
+        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
             self.bits = weights.get_tensor("gptq_bits").item()
             self.groupsize = weights.get_tensor("gptq_groupsize").item()
             self.desc_act = False
-            self.sym = False
+            # `server quantize` used asymmetric quantization unconditionally
+            # before the `gptq_sym` setting tensor was added.
+            self.sym = (
+                weights.get_tensor("gptq_sym").item()
+                if weights._has_tensor("gptq_sym")
+                else False
+            )
             self.quant_method = "gptq"
-        except (SafetensorError, RuntimeError) as e:
-            pass
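
The `else False` fallback keeps older checkpoints loadable: anything quantized before this commit carries no `gptq_sym` tensor and was packed asymmetrically. A minimal sketch (shard filename assumed, not from this repository) of probing a checkpoint for these config tensors with the `safetensors` library:

    from safetensors import safe_open

    # "model.safetensors" is a placeholder; real checkpoints may be sharded.
    with safe_open("model.safetensors", framework="pt") as f:
        names = set(f.keys())
        bits = f.get_tensor("gptq_bits").item() if "gptq_bits" in names else None
        # Checkpoints written before `gptq_sym` existed default to asymmetric.
        sym = f.get_tensor("gptq_sym").item() if "gptq_sym" in names else False
        print(f"bits={bits} sym={sym}")
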
3 changes: 3 additions & 0 deletions server/text_generation_server/layers/gptq/quantize.py
@@ -871,6 +871,7 @@ def quantize(
     upload_to_model_id: Optional[str],
     percdamp: float,
     act_order: bool,
+    sym: bool,
 ):
     print("loading model")
     config = AutoConfig.from_pretrained(
@@ -946,6 +947,7 @@ def _unload():
         percdamp=percdamp,
         act_order=act_order,
         hooks=hooks,
+        sym=sym,
     )
     print(time.time() - tick)

@@ -957,6 +959,7 @@ def _unload():
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
     state_dict["gptq_bits"] = torch.LongTensor([bits])
     state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
+    state_dict["gptq_sym"] = torch.BoolTensor([sym])

     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
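
The flag is persisted just like the existing `gptq_bits` and `gptq_groupsize` config tensors: as a one-element tensor stored alongside the weights. A small round-trip sketch (filename hypothetical) showing how the loader's `.item()` calls recover plain Python values:

    import torch
    from safetensors.torch import save_file, load_file

    state_dict = {
        "gptq_bits": torch.LongTensor([4]),
        "gptq_groupsize": torch.LongTensor([128]),
        "gptq_sym": torch.BoolTensor([True]),  # the tensor added by this commit
    }
    save_file(state_dict, "gptq_config.safetensors")

    loaded = load_file("gptq_config.safetensors")
    assert loaded["gptq_sym"].item() is True  # .item() yields a Python bool
    assert loaded["gptq_bits"].item() == 4
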
7 changes: 7 additions & 0 deletions server/text_generation_server/utils/weights.py
@@ -146,6 +146,13 @@ def _get_slice(self, tensor_name: str):
         slice_ = f.get_slice(tensor_name)
         return slice_

+    def _has_tensor(self, tensor_name: str):
+        try:
+            self.get_filename(tensor_name)
+        except Exception:
+            return False
+        return True
+
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

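Note that `get_filename` raises when a tensor name cannot be resolved in any shard of the checkpoint, so `_has_tensor` amounts to a try/except presence probe; the blanket `except Exception` deliberately treats any lookup failure as absence, matching the conservative `sym=False` default above.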
