This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

rename func names & add comments
Zhenzhong1 committed Jan 31, 2024
1 parent 4558296 commit 48fd856
Showing 2 changed files with 21 additions and 13 deletions.
13 changes: 9 additions & 4 deletions neural_speed/convert/common.py
@@ -213,7 +213,11 @@ def unpack_weight(qweight, scales, qzeros, q_config):
def unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config):
    group_size = q_config['group_size']
    bits = q_config['bits']
    wf = torch.tensor([[ 0, 4, 8, 12, 16, 20, 24, 28]], dtype=torch.int32)
    s32_bits = 32

    assert bits == 4
    # An int32 can store 8 * 4-bit values. These are the bit offsets of each value.
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
                                      wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
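For readers unfamiliar with this packing scheme, here is a minimal standalone sketch (not the repository's code; the helper name and sample value are invented for illustration) of how a bit-offset tensor like wf pulls eight 4-bit values out of each packed int32, using the same shift-and-mask pattern as unpack_gptq_weight_4bits:

```python
import torch

def unpack_int32_to_uint4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a (rows, cols) int32 tensor into (rows, cols * 8) unsigned 4-bit values."""
    bits = 4
    # Shift offsets 0, 4, 8, ..., 28: eight 4-bit fields per 32-bit word.
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
    # Broadcast every int32 against all eight shift amounts, then keep the low 4 bits.
    shifted = torch.bitwise_right_shift(packed.unsqueeze(-1), wf)
    out = torch.bitwise_and(shifted, (1 << bits) - 1).to(torch.int8)
    return out.reshape(packed.shape[0], -1)

# 0x76543210 stores the nibbles 0..7, least-significant nibble first.
packed = torch.tensor([[0x76543210]], dtype=torch.int32)
print(unpack_int32_to_uint4(packed))  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int8)
```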
@@ -229,11 +233,13 @@ def unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config):

def unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config):
    print("unpack_gptq_weight_3bits... ", end='')

    group_size = q_config['group_size']
    bits = q_config['bits']
    s32_bits = 32

    wf = torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27]], dtype=torch.int32)
    assert bits == 3
    # An int32 can only store 10 * 3-bit values. These are the bit offsets of each value.
    wf = torch.tensor([[ i for i in range(0, s32_bits - bits, bits)]], dtype=torch.int32)
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
                                      wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
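The 3-bit case works the same way, except that only ten 3-bit fields (30 bits) fit into an int32 and the top two bits are padding, which is why wf stops at offset 27. A similar illustrative sketch, again with invented names:

```python
import torch

def unpack_int32_to_uint3(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a (rows, cols) int32 tensor into (rows, cols * 10) unsigned 3-bit values."""
    bits = 3
    # Ten offsets 0, 3, ..., 27; bits 30 and 31 of each word carry no data.
    wf = torch.tensor([[i for i in range(0, 32 - bits, bits)]], dtype=torch.int32)
    shifted = torch.bitwise_right_shift(packed.unsqueeze(-1), wf)
    return torch.bitwise_and(shifted, (1 << bits) - 1).to(torch.int8).reshape(packed.shape[0], -1)

# Pack ten 3-bit values by hand and round-trip them.
values = [0, 1, 2, 3, 4, 5, 6, 7, 0, 1]
word = sum(v << (3 * i) for i, v in enumerate(values))
packed = torch.tensor([[word]], dtype=torch.int32)
print(unpack_int32_to_uint3(packed).tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7, 0, 1]]
```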
@@ -384,7 +390,6 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
    qweight = model[f"{src_name}.qweight"]

    weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
    # import pdb; pdb.set_trace()
    # weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
    # num_itr = g_idx.shape[0]//x.shape[-1]
    if 'desc_act' in q_config and q_config['desc_act']:
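The collapsed desc_act branch handles GPTQ act-order checkpoints, where a g_idx tensor records which quantization group each input channel belongs to instead of assuming contiguous groups of group_size rows. The general dequantization rule looks roughly like the sketch below (an assumption-laden illustration with invented names and shapes, not the code hidden in this hunk):

```python
import torch

def dequant_gptq(int_weight, scales, zeros, g_idx):
    """int_weight: (in_features, out_features) unpacked integers,
    scales / zeros: (n_groups, out_features),
    g_idx: (in_features,) group index of each input channel."""
    g = g_idx.long()
    # With desc_act, a row's group comes from g_idx rather than row // group_size.
    return scales[g] * (int_weight.float() - zeros[g].float())
```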
21 changes: 12 additions & 9 deletions neural_speed/convert/convert_quantized_gptj.py
@@ -29,8 +29,8 @@ def permute_func(weights, n_head: int, n_head_kv: int):
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2,
                            *weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))

def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config):
    # unpack weight and repack into jblas format
def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
    # unpack weight and repack into 3-bit / 4-bit BesTLA format
    import neural_speed.llama_cpp as cpp_model
    if ".weight" in src_name:
        src_name = src_name.replace(".weight", "")
@@ -62,6 +62,9 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config):
    shape = int_weight.shape
    write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)

    # INC stores signed int4 values as u4 (range 0~15) by adding an offset.
    # BesTLA requires s4_clip ((-8..7) * 16), so we subtract the offset and then multiply by 16.
    # Int3 is handled the same way as int4, but with offset=4 and scale multiplier 32.
    if q_config['bits'] == 4:
        int_weight = (int_weight - 8) * 16
        gptq_scales = gptq_scales / 16
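The comment added above describes the key transform: INC's unpacked weights are unsigned (offset 8 for int4, 4 for int3), while BesTLA expects signed values shifted into the high bits of a byte. A hedged NumPy sketch of that remapping (the helper names are mine, and the 3-bit variant follows the comment rather than code shown in this hunk):

```python
import numpy as np

def to_bestla_s4(int_weight: np.ndarray, scales: np.ndarray):
    """Map INC-style u4 weights (0..15) to BesTLA s4_clip bytes.

    Subtracting 8 recentres the range to [-8, 7]; multiplying by 16 moves the
    nibble into the high 4 bits of the byte, and dividing the scales by 16
    keeps scale * stored_value (the dequantized weight) unchanged.
    """
    return (int_weight.astype(np.int8) - 8) * 16, scales / 16

def to_bestla_s3(int_weight: np.ndarray, scales: np.ndarray):
    """Same idea for 3-bit weights: offset 4, shift factor 32."""
    return (int_weight.astype(np.int8) - 4) * 32, scales / 32
```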
@@ -82,7 +85,7 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config):
    else:
        g_idx = np.empty(0, dtype=np.int32)

    # pack int weight in bestla format
    # repack int weight in BesTLA format
    byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
                                                weight_dtype="int4" if q_config['bits'] == 4 else "int8",
                                                group_size=q_config['group_size'],
@@ -169,18 +172,18 @@ def main(args_in: Optional[List[str]] = None) -> None:
    convert_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)

    for i in tqdm(range(n_layer), desc="Processing layers"):
        convert_q4_bestla_tensor(f"transformer.h.{i}.attn.q_proj.weight",
        convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.q_proj.weight",
                                    f"transformer.h.{i}.attn.q_proj.weight", list_vars, fout, quantize_config)
        convert_q4_bestla_tensor(f"transformer.h.{i}.attn.k_proj.weight",
        convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.k_proj.weight",
                                    f"transformer.h.{i}.attn.k_proj.weight", list_vars, fout, quantize_config)
        convert_q4_bestla_tensor(f"transformer.h.{i}.attn.v_proj.weight",
        convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.v_proj.weight",
                                    f"transformer.h.{i}.attn.v_proj.weight", list_vars, fout, quantize_config)

        convert_q4_bestla_tensor(f"transformer.h.{i}.attn.out_proj.weight",
        convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.out_proj.weight",
                                    f"transformer.h.{i}.attn.out_proj.weight", list_vars, fout, quantize_config)
        convert_q4_bestla_tensor(f"transformer.h.{i}.mlp.fc_in.weight",
        convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.fc_in.weight",
                                    f"transformer.h.{i}.mlp.fc_in.weight", list_vars, fout, quantize_config)
        convert_q4_bestla_tensor(f"transformer.h.{i}.mlp.fc_out.weight",
        convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.fc_out.weight",
                                    f"transformer.h.{i}.mlp.fc_out.weight", list_vars, fout, quantize_config)

        convert_fp32_tensor(f"transformer.h.{i}.mlp.fc_in.bias",
