Skip to content

Commit

Permalink
safetensors-compatible state_dict iteration: add quant_config for float scale/zero
Browse files Browse the repository at this point in the history
  • Loading branch information
mobicham committed Jul 16, 2024
1 parent 96a5cc1 commit 74bbe01
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
53 changes: 51 additions & 2 deletions hqq/core/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
"group_size": int,
"nbits": int,
"shape": torch.Size,
"channel_wise": bool,
"optimize": bool,
"round_zero": bool,
}


Expand Down Expand Up @@ -561,19 +564,36 @@ def cpu(self):
# TODO: later
return self

# state_dict is encoded by default for safetensors support. You can get the raw dict by setting self.encoded_state_dict=False. \
# Note: you can't change the state once it's done
def state_dict(self, *args, **kwargs): # nn.Module override compatible
if (
self.quant_config["scale_quant_params"]
or self.quant_config["zero_quant_params"]
) and self.encoded_state_dict:
raise Exception(
"Unsupported serialization for quantized scale/zero and self.encoded_state_dict=True"
)
# TODO: add support for quantized zero/scale case (quant_config and zero/scale)

_encode_type = (
encode_safetensor_type if (self.encoded_state_dict) else lambda z: z
)

# Core data
state = {"W_q": self.W_q} | {k: _encode_type(v) for k, v in self.meta.items()}
if self.bias is not None:
state["bias"] = self.bias
state["offload_meta"] = _encode_type(self.offload_meta)

# Encoding flag
if self.encoded_state_dict:
state["encoded_state_dict"] = _encode_type(self.encoded_state_dict)
# TODO: add support for quant zero/scale
# TODO: add quant_config

# Quant config
state["stores_quant_config"] = _encode_type(True)
for k in self.quant_config["weight_quant_params"]:
state[k] = _encode_type(self.quant_config["weight_quant_params"][k])

if "destination" in kwargs and "prefix" in kwargs:
for key, value in state.items():
Expand Down Expand Up @@ -620,8 +640,37 @@ def load_state_dict(self, state_dict, strict=True, assign=False):
decode_safetensor_type if (encoded_state_dict) else lambda z, w: z
)

# Quant-config
if state_dict.pop(
"stores_quant_config", False
): # check for backward compatibility
self.quant_config = {
"weight_quant_params": {
k: _decode_type(state_dict[k], _META_TYPE[k])
for k in [
"nbits",
"channel_wise",
"group_size",
"optimize",
"round_zero",
"axis",
"view_as_float",
]
}
}
# TODO: scale/zero quant use-case
self.quant_config["scale_quant_params"] = state_dict.pop(
"scale_quant_params", None
)
self.quant_config["zero_quant_params"] = state_dict.pop(
"zero_quant_params", None
)

# W_q/ bias
self.W_q = state_dict.pop("W_q")
self.bias = state_dict.pop("bias", None)

# Meta
self.offload_meta = _decode_type(state_dict.pop("offload_meta", False), bool)
if "meta" in state_dict:
self.meta = state_dict["meta"] # Backward compatibility
Expand Down
5 changes: 1 addition & 4 deletions hqq/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def zero_pad_row(

# Map a Pytorch dtype into a safetensor dtype
def encode_safetensor_type(data):
if data is None:
return None
if isinstance(data, (torch.Tensor, torch.nn.Parameter)):
return data
if isinstance(data, torch.Size):
Expand All @@ -48,10 +46,9 @@ def encode_safetensor_type(data):
if isinstance(data, str):
return torch.tensor([ord(i) for i in data], dtype=torch.uint8)


# Decode a safetensor dtype into a Pytorch dtype
def decode_safetensor_type(data, data_type):
if data_type is None:
return None
if data_type in [torch.Tensor, torch.nn.Parameter]:
return data
if data_type is torch.Size:
Expand Down
12 changes: 7 additions & 5 deletions hqq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@

try:
from ..backends.bitblas import HQQLinearBitBlas

_HQQ_BACKEND_CLASSES.append(HQQLinearBitBlas)
except Exception:
pass
pass

try:
from ..backends.marlin import MarlinLinear

_HQQ_BACKEND_CLASSES.append(MarlinLinear)
except Exception:
pass
pass


# Defined what is qualified as "linear layer"
Expand Down Expand Up @@ -401,10 +403,10 @@ def serialize_weights(cls, model, verbose: bool = False) -> dict:
if name in ignore_keys:
continue
try:
module.encoded_state_dict = (
False # disable state_dict encoding for safetensors
)
# disable state_dict encoding for safetensors
module.encoded_state_dict = False
state_dict = module.state_dict()

if len(state_dict) > 0:
weights[name] = dict(state_dict)
except Exception:
Expand Down

0 comments on commit 74bbe01

Please sign in to comment.