From 5f8f0d2c7bf9f90b6ac496d795d1160c1a8bd5e6 Mon Sep 17 00:00:00 2001
From: mobicham
Date: Thu, 29 Aug 2024 12:03:44 +0000
Subject: [PATCH] fix empty state_dict() and bump to 0.2.1

---
 hqq/__init__.py      |  2 +-
 hqq/core/quantize.py | 61 +++++++++++++++++++++++++++++++++++++++-----
 setup.py             |  2 +-
 3 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/hqq/__init__.py b/hqq/__init__.py
index 5edb029..ea2cce8 100755
--- a/hqq/__init__.py
+++ b/hqq/__init__.py
@@ -1,3 +1,3 @@
-__version__ = "0.2.0"
+__version__ = "0.2.1"
 __author__ = 'Dr. Hicham Badri'
 __credits__ = 'Mobius Labs GmbH'
diff --git a/hqq/core/quantize.py b/hqq/core/quantize.py
index e0d6361..68fb907 100755
--- a/hqq/core/quantize.py
+++ b/hqq/core/quantize.py
@@ -9,6 +9,7 @@
 from .utils import is_divisible, encode_safetensor_type, decode_safetensor_type
 from .optimize import optimize_weights_proximal
 from .bitpack import BitPack
+from termcolor import colored
 
 _META_TYPE = {
     "scale": torch.Tensor,
@@ -386,6 +387,8 @@ def __init__(
         self.ready = False
         self.in_gpu = False
         self.bias = None
+        self.axis = None
+        self.channel_wise = None
         self.device = device
         self.compute_dtype = compute_dtype
         self.quant_config = copy.deepcopy(quant_config)
@@ -408,6 +411,9 @@ def __init__(
         if initialize:
             self.initialize()
 
+    def is_initialized(self):
+        return False if (None in [self.W_q, self.meta]) else True
+
     def initialize(self):
         if self.linear_layer is not None:
             self.quantize(self.linear_layer.weight.data, **self.quant_config)
@@ -524,9 +530,11 @@ def cuda(self, device):
             )
 
         if self.bias is not None:
-            if(isinstance(self.bias, torch.nn.Parameter)):
-                self.bias.data = self.bias.data.to(device=device, dtype=self.compute_dtype)
-            if(isinstance(self.bias, torch.Tensor)):
+            if isinstance(self.bias, torch.nn.Parameter):
+                self.bias.data = self.bias.data.to(
+                    device=device, dtype=self.compute_dtype
+                )
+            if isinstance(self.bias, torch.Tensor):
                 self.bias = self.bias.to(device=device, dtype=self.compute_dtype)
 
         self.W_q = nn.Parameter(self.W_q, requires_grad=False)
@@ -569,7 +577,36 @@ def cpu(self):
     # state_dict is encoded by default for safetensors support. You can get the raw dict by setting self.encoded_state_dict=False. \
     # Note: you can't change the state once it's done
 
+    def state_dict_keys(self):
+        return set(
+            [
+                "W_q",
+                "nbits",
+                "group_size",
+                "shape",
+                "scale",
+                "zero",
+                "axis",
+                "packing",
+                "unpack_view_dtype",
+                "view_as_float",
+                "quant_scale",
+                "quant_zero",
+                "compute_dtype",
+                "bias",
+                "offload_meta",
+                "encoded_state_dict",
+                "stores_quant_config",
+                "channel_wise",
+                "optimize",
+                "round_zero",
+            ]
+        )
+
     def state_dict(self, *args, **kwargs):  # nn.Module override compatible
+        if not self.is_initialized():
+            return {k: None for k in self.state_dict_keys()}
+
         if (
             self.quant_config["scale_quant_params"]
             or self.quant_config["zero_quant_params"]
@@ -1027,11 +1064,21 @@ def hqq_base_quant_config(
         "view_as_float": view_as_float,
     }
 
-    if(quant_zero or quant_scale):
-        print(colored('Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
+    if quant_zero or quant_scale:
+        print(
+            colored(
+                "Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.",
+                "yellow",
+            )
+        )
 
-    if(offload_meta):
-        print(colored('Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
+    if offload_meta:
+        print(
+            colored(
+                "Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.",
+                "yellow",
+            )
+        )
 
     if offload_meta:
         if quant_scale != quant_zero:
diff --git a/setup.py b/setup.py
index 6a455de..6dc17e4 100755
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@ def run(self):
 
 setup(
     name="hqq",
-    version="0.2.0",
+    version="0.2.1",
     description="Half-Quadratic Quantization (HQQ)",
     url="https://github.com/mobiusml/hqq/",
     author="Dr. Hicham Badri",
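
Reviewer note (not part of the patch): a minimal sketch of the behavior this change fixes, using only names visible in the diff (HQQLinear, hqq_base_quant_config, is_initialized, state_dict_keys); the exact constructor arguments are assumptions inferred from the __init__ hunks. Before this patch, calling state_dict() on a layer built with initialize=False hit the still-None W_q/meta; it now returns a placeholder dict instead of failing.

import torch
from hqq.core.quantize import HQQLinear, hqq_base_quant_config

# Wrap a dense layer but defer quantization: W_q and meta remain None.
linear = torch.nn.Linear(128, 128)
cfg = hqq_base_quant_config(nbits=4, group_size=64)
hqq_layer = HQQLinear(
    linear, cfg, compute_dtype=torch.float16, device="cuda", initialize=False
)

assert not hqq_layer.is_initialized()  # W_q / meta are still None

# With this patch, state_dict() on an uninitialized layer returns
# {key: None, ...} over state_dict_keys() instead of raising.
sd = hqq_layer.state_dict()
assert set(sd.keys()) == hqq_layer.state_dict_keys()
assert all(v is None for v in sd.values())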