From b8125533218865e207599523201640aafd06ace5 Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Tue, 30 Jan 2024 22:52:23 -0800
Subject: [PATCH] fixed the dtype issue

---
 neural_speed/convert/convert_quantized_llama.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/neural_speed/convert/convert_quantized_llama.py b/neural_speed/convert/convert_quantized_llama.py
index cbe55e473..6d001ba39 100644
--- a/neural_speed/convert/convert_quantized_llama.py
+++ b/neural_speed/convert/convert_quantized_llama.py
@@ -66,10 +66,13 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
     shape = int_weight.shape
     write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)
 
+    weight_dtype = "int8"
     if q_config['bits'] == 4:
         int_weight = (int_weight - 8) * 16
         gptq_scales = gptq_scales / 16
         gptq_zeros = (gptq_zeros - 8) * 16
+        weight_dtype = "int4"
+
     dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
     int_weight = np.ascontiguousarray(int_weight.numpy())
     gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
@@ -84,7 +87,7 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
 
     # pack int weight in bestla format
     byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
-                                                weight_dtype="int4" if q_config['bits'] == 4 else "int8",
+                                                weight_dtype=weight_dtype,
                                                 group_size=q_config['group_size'],
                                                 alg="sym" if q_config['sym'] else "asym",
                                                 compute_dtype="int8")
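
Review context, not part of the patch: when `q_config['bits'] == 4`, the hunk above remaps GPTQ's unsigned 4-bit values (0..15) into the int8 range (-128..112) by shifting and stretching by 16, while dividing the scales and shifting the zero points so dequantization is unchanged; the fix makes `weight_dtype` track that branch before it is passed to `np_bestla_qpack`. Below is a minimal sketch of the remap with made-up toy weights, scales, and zeros (hypothetical values, not data from the repo), taking dequantization to mean `scale * (w - zero)`:

```python
import numpy as np

# Toy stand-ins for what the converter sees (hypothetical values).
q_config = {'bits': 4}
int_weight = np.array([[0, 7, 8, 15]], dtype=np.int32)  # GPTQ uint4 layout, 0..15
gptq_scales = np.array([[0.5]], dtype=np.float32)
gptq_zeros = np.array([[8]], dtype=np.int32)

weight_dtype = "int8"
if q_config['bits'] == 4:
    # Shift 0..15 into a signed range and stretch into int8 (-128..112);
    # compensate in the scales and zeros so dequantization is unchanged.
    int_weight = (int_weight - 8) * 16
    gptq_scales = gptq_scales / 16
    gptq_zeros = (gptq_zeros - 8) * 16
    weight_dtype = "int4"  # record that the packed payload is 4-bit

# scale * (w - zero) is invariant under the remap:
print(gptq_scales * (int_weight - gptq_zeros))  # [[-4.  -0.5  0.   3.5]]
print(weight_dtype)                             # int4
```

Hoisting the dtype decision into a variable assigned next to the remap keeps the 4-bit handling and the `np_bestla_qpack` call in sync, instead of duplicating the `bits` check inline at the call site.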