This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

enable llama2 awq (#20)
zhenwei-intel authored Jan 4, 2024
1 parent fad80b1 commit 9be307f
Showing 6 changed files with 61 additions and 22 deletions.
10 changes: 9 additions & 1 deletion neural_speed/__init__.py
@@ -73,7 +73,7 @@ def get_model_type(model_config):
             model_type = "chatglm2"
         return model_type
 
-    def init(self, model_name, not_quant=False, use_cache=False,
+    def init(self, model_name, not_quant=False, use_cache=False, use_gptq=False, use_awq=False,
              weight_dtype="int4", alg="sym", group_size=32,
              scale_dtype="fp32", compute_dtype="int8", use_ggml=False):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -94,6 +94,10 @@ def init(self, model_name, not_quant=False, use_cache=False,
             quant_desc += "_pc"
         else:
             quant_desc += "_g{}".format(group_size)
+        if use_gptq:
+            quant_desc = "gptq"
+        if use_awq:
+            quant_desc = "awq"
         quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
 
         if not_quant:
@@ -103,6 +107,10 @@ def init(self, model_name, not_quant=False, use_cache=False,
         if use_cache and os.path.exists(self.bin_file):
             return
 
+        if use_gptq or use_awq:
+            convert_model(model_name, quant_bin, "f32")
+            return
+
         if not use_cache or not os.path.exists(fp32_bin):
             convert_model(model_name, fp32_bin, "f32")
             assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
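With the new flags, a checkpoint that is already GPTQ- or AWQ-quantized skips the FP32 dump and requantization and goes straight through convert_model(model_name, quant_bin, "f32"). A minimal usage sketch, assuming the usual from neural_speed import Model entry point; the model id below is a placeholder for any AWQ checkpoint:

    from neural_speed import Model

    model = Model()
    # use_awq=True short-circuits init(): the pre-quantized weights are converted
    # directly into the ne_*_q_awq.bin file instead of being requantized.
    model.init("TheBloke/Llama-2-7B-AWQ", use_awq=True)  # hypothetical model id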
6 changes: 3 additions & 3 deletions neural_speed/convert/__init__.py
@@ -26,9 +26,9 @@ def convert_model(model, outfile, outtype):
     config = AutoConfig.from_pretrained(model, trust_remote_code=True)
     model_type = model_maps.get(config.model_type, config.model_type)
 
-    gpt_model = 'gptq' in str(model).lower()
-    if gpt_model:
-        path = Path(Path(__file__).parent.absolute(), "convert_gptq_{}.py".format(model_type))
+    quantized_model = 'gptq' in str(model).lower() or 'awq' in str(model).lower()
+    if quantized_model:
+        path = Path(Path(__file__).parent.absolute(), "convert_quantized_{}.py".format(model_type))
     else:
         path = Path(Path(__file__).parent.absolute(), "convert_{}.py".format(model_type))
     cmd = []
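Both GPTQ and AWQ checkpoints are now routed to convert_quantized_{model_type}.py, keyed off the model path string. A standalone sketch of the same selection rule (directory layout and model ids are illustrative only):

    from pathlib import Path

    def pick_converter(model_id: str, model_type: str) -> Path:
        # Mirrors the branch above: any id containing "gptq" or "awq" goes to the
        # convert_quantized_* script for that architecture.
        quantized = 'gptq' in model_id.lower() or 'awq' in model_id.lower()
        prefix = "convert_quantized_" if quantized else "convert_"
        return Path("neural_speed/convert", f"{prefix}{model_type}.py")

    print(pick_converter("TheBloke/Llama-2-7B-AWQ", "llama"))   # .../convert_quantized_llama.py
    print(pick_converter("meta-llama/Llama-2-7b-hf", "llama"))  # .../convert_llama.py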
63 changes: 47 additions & 16 deletions neural_speed/convert/common.py
@@ -20,6 +20,7 @@
 import numpy as np
 import struct
 import json
+import warnings
 from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
                     Literal, Optional, Sequence, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor # type: ignore
@@ -169,8 +170,18 @@ def qzeros_to_zeros(qzeros, bits=4):
         col += 1
     return zeros
 
-
 def unpack_weight(qweight, scales, qzeros, q_config):
+    if "quant_method" not in q_config:
+        raise ValueError(f"Unsupported q_config without quant_method: {q_config}")
+    quant_method = q_config["quant_method"]
+    if quant_method == "gptq":
+        return unpack_gptq_weight(qweight, scales, qzeros, q_config)
+    if quant_method == "awq":
+        return unpack_awq_weight(qweight, scales, qzeros, q_config)
+    raise ValueError(f"Unsupported quant_method: {quant_method}")
+
+
+def unpack_gptq_weight(qweight, scales, qzeros, q_config):
     group_size = q_config['group_size']
     bits = q_config['bits']
     wf = torch.tensor([[ 0, 4, 8, 12, 16, 20, 24, 28]], dtype=torch.int32)
@@ -179,16 +190,29 @@ def unpack_weight(qweight, scales, qzeros, q_config):
     torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
 
     zeros = zeros + 1
-    # zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
     zeros = zeros.reshape(scales.shape)
 
-    # scales = scales
-    # scales = scales.reshape(-1, 1, scales.shape[-1])
 
     weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
                                        wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
     torch.bitwise_and(weight,(2 ** bits) - 1, out=weight)
-    # int_weight = weight.reshape(-1, group_size, weight.shape[2])
 
     return weight, scales, zeros
 
+
+def unpack_awq_weight(qweight, scales, qzeros, q_config):
+    group_size = q_config['group_size']
+    bits = q_config['bits']
+    order_map = [0, 4, 1, 5, 2, 6, 3, 7]
+
+    pack_num = 8
+    weight = torch.zeros(qweight.shape[0], qweight.shape[1] * pack_num)
+    zeros = torch.zeros(qzeros.shape[0], qzeros.shape[1] * pack_num)
+    for col in range(qweight.shape[1]):
+        for i in range(pack_num):
+            w_col = torch.bitwise_right_shift(qweight[:, col], 4 * order_map[i])
+            weight[:, col * pack_num + i] = torch.bitwise_and(w_col, (2 ** bits) - 1)
+            z_col = torch.bitwise_right_shift(qzeros[:, col], 4 * order_map[i])
+            zeros[:, col * pack_num + i] = torch.bitwise_and(z_col, (2 ** bits) - 1)
+
+    return weight, scales, zeros
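unpack_awq_weight assumes the AWQ layout: eight 4-bit values per int32, interleaved according to order_map, and (unlike the GPTQ path) no zeros + 1 offset. A self-contained round-trip check of that nibble layout, using toy values rather than a real checkpoint:

    import torch

    order_map = [0, 4, 1, 5, 2, 6, 3, 7]
    bits, pack_num = 4, 8

    values = torch.arange(pack_num, dtype=torch.int32).unsqueeze(0)  # one row: 0..7

    # Pack: logical element i lands at bit offset 4 * order_map[i] of one int32.
    packed = torch.zeros(1, 1, dtype=torch.int32)
    for i in range(pack_num):
        packed[:, 0] |= values[:, i] << (4 * order_map[i])

    # Unpack with the same shift-and-mask loop as unpack_awq_weight above.
    unpacked = torch.zeros(1, pack_num, dtype=torch.int32)
    for i in range(pack_num):
        col = torch.bitwise_right_shift(packed[:, 0], 4 * order_map[i])
        unpacked[:, i] = torch.bitwise_and(col, (2 ** bits) - 1)

    assert torch.equal(unpacked, values)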

@@ -210,7 +234,7 @@ def find_quantized_model_file(model_path):
     print(f"Detected model file {found[0]}")
     return str(found[0])
 
-def load_gptq_model(model_path):
+def load_quantized_model(model_path):
     input_path = find_quantized_model_file(model_path)
     model = None
     if input_path.endswith('pt'):
@@ -224,9 +248,10 @@ def load_gptq_model(model_path):
     with open(model_path + '/config.json', "r", encoding="utf-8") as f:
         config = json.load(f)
 
-    with open(model_path + '/quantize_config.json', "r", encoding="utf-8") as f:
-        quantize_config = json.load(f)
-    return model, config, quantize_config
+    quantize_config = config["quantization_config"]
+    if "zero_point" in quantize_config:
+        quantize_config["sym"] = not quantize_config["zero_point"]
+    return model, config, config["quantization_config"]
 
 
 def convert_fp32_tensor(src_name, dst_name, model, fout):
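load_quantized_model now pulls the quantization settings from the quantization_config block that GPTQ/AWQ exporters embed in config.json, instead of a separate quantize_config.json. A sketch of what such a block typically contains; the keys and values below are common examples, not taken from this repository:

    # Illustrative quantization_config block as exported by an AWQ tool
    # (values are typical, not taken from this repo):
    quantization_config = {
        "bits": 4,
        "group_size": 128,
        "quant_method": "awq",   # selects unpack_awq_weight in unpack_weight()
        "version": "gemm",
        "zero_point": True,      # asymmetric; the loader flips this into sym=False
    }
    sym = not quantization_config["zero_point"]
    assert quantization_config["quant_method"] in ("gptq", "awq")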
@@ -304,15 +329,17 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
     qweight = model[f"{src_name}.qweight"]
 
     weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
     weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
     # import pdb; pdb.set_trace()
     # weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
     # num_itr = g_idx.shape[0]//x.shape[-1]
-    if q_config['desc_act']:
+    if 'desc_act' in q_config and q_config['desc_act']:
         g_idx = model[f"{src_name}.g_idx"]
         weight = (gptq_scales[g_idx.long()] * (weight - gptq_zeros[g_idx.long()]))
     else:
         infeatures = weight.shape[0]
         g_idx = torch.tensor([i // q_config["group_size"] for i in range(infeatures)], dtype=torch.int32)
-        weight = (gptq_scales[g_idx.long()] * (weight - gptq_zeros[g_idx.long()]))
+        scale_zeros = gptq_zeros * gptq_scales
+        weight = (gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()])
 
     weight = weight.t()
     weight = weight.float()
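The rewritten else branch computes the same dequantization, just factored as scale * q - (zero * scale) instead of scale * (q - zero). A toy check of that identity with made-up shapes and values:

    import torch

    torch.manual_seed(0)
    group_size, infeatures, outfeatures = 4, 8, 3
    q = torch.randint(0, 16, (infeatures, outfeatures)).float()    # unpacked int4 weights
    scales = torch.rand(infeatures // group_size, outfeatures)     # one row of scales per group
    zeros = torch.randint(0, 16, scales.shape).float()             # one row of zero points per group
    g_idx = torch.tensor([i // group_size for i in range(infeatures)])

    old_form = scales[g_idx] * (q - zeros[g_idx])
    scale_zeros = zeros * scales
    new_form = scales[g_idx] * q - scale_zeros[g_idx]
    assert torch.allclose(old_form, new_form)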
@@ -327,7 +354,8 @@
 
 
 def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
-    import intel_extension_for_transformers.llm.runtime.graph.llama_cpp as cpp_model
+    # unpack weight and repack into jblas format
+    import neural_speed.llama_cpp as cpp_model
     qzeros = model[f"{src_name}.qzeros"]
     zeros = qzeros_to_zeros(qzeros)
     scales = model[f"{src_name}.scales"]
@@ -336,12 +364,14 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
     int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
     int_weight = int_weight.view(-1,int_weight.shape[-1])
 
+    # permute_func for llama-like model
     if permute_func:
         int_weight = permute_func(int_weight.t(), n_head, n_head_kv).t().contiguous()
         gptq_scales = permute_func(gptq_scales.t(), n_head, n_head_kv).t().contiguous()
         gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous()
 
-    if q_config['desc_act']:
+    # shuffle weight in GPTQ when act order is on
+    if 'desc_act'in q_config and q_config['desc_act']:
         g_idx = model[f"{src_name}.g_idx"]
         int_weight2 = int_weight.clone()
         group_size=q_config['group_size']
@@ -371,11 +401,12 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
         gptq_zeros = np.empty(0, dtype=np.int8)
     else:
         gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
-    if q_config['desc_act']:
+    if 'desc_act'in q_config and q_config['desc_act']:
         g_idx = np.ascontiguousarray(g_idx.numpy())
     else:
         g_idx = np.empty(0, dtype=np.int32)
 
+    # pack int weight in bestla format
     byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
                                                 weight_dtype="int4" if q_config['bits'] == 4 else "int8",
                                                 group_size=q_config['group_size'],
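The permute_func hook used in convert_q4_bestla_tensor is supplied by the llama-family converter to undo the rotary Q/K row interleaving before repacking. The sketch below is a common llama.cpp-style permute, shown only to make the hook concrete; it is an assumption about the hook's shape, not necessarily the exact function this repo passes in:

    import torch

    def permute_func(weights, n_head, n_head_kv):
        # Reorder the output rows of a Q/K projection from "interleaved rotary
        # halves" to "first half, second half" per head (llama.cpp-style).
        if n_head_kv and n_head != n_head_kv:
            n_head = n_head_kv
        rows = weights.shape[0]
        return (weights.reshape(n_head, 2, rows // n_head // 2, *weights.shape[1:])
                       .swapaxes(1, 2)
                       .reshape(weights.shape))

    # Shape-level smoke test with toy dimensions.
    w = torch.arange(8 * 4).reshape(8, 4).float()   # (n_rows, n_cols)
    out = permute_func(w, n_head=2, n_head_kv=2)
    assert out.shape == w.shape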
Additional changed file (path not shown):
@@ -37,7 +37,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     out_path = args.outfile.as_posix()
     model_path = args.model.as_posix()
 
-    model, config, quantize_config = load_gptq_model(model_path)
+    model, config, quantize_config = load_quantized_model(model_path)
     f = open(out_path, "wb")
 
     # 1. write hparams
Additional changed file (path not shown):
@@ -40,7 +40,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     out_path = args.outfile.as_posix()
     model_path = args.model.as_posix()
 
-    model, config, quantize_config = load_gptq_model(model_path)
+    model, config, quantize_config = load_quantized_model(model_path)
     f = open(out_path, "wb")
 
     # 1. write hparams

