diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py
index 2e988139d..92bb8734a 100644
--- a/neural_speed/convert/convert_chatglm.py
+++ b/neural_speed/convert/convert_chatglm.py
@@ -50,6 +50,7 @@ def bytes_to_unicode():


 class SentencePieceVocab:
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
@@ -149,11 +150,11 @@ def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams
     print("ChatGLM-2.gguf converting: ")
     list_vars = model.state_dict()
     for name in list_vars.keys():
-        print(name, list_vars[name].shape, list_vars[name].dtype)
+        print("%-80s" % name, list_vars[name].shape, list_vars[name].dtype)

     print(hparams)

-    gguf_file = fname_out + '.gguf'
+    gguf_file = fname_out
     gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm2")

     arch = "chatglm2."
@@ -285,35 +286,68 @@ def write_vocab_gguf(dir_model):
     print("gguf: get tensor metadata")
     for name in list_vars.keys():
         data = list_vars[name].squeeze().numpy()
-
-        print("Processing variable: " + name + " with shape: ", data.shape)
         if 'inv_freq' in name:
+            print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape))
             continue

+        print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape), end=" ")
         n_dims = len(data.shape)

         # ftype == 0 -> float32, ftype == 1 -> float16
         ftype_cur = 0
         if ftype != 0:
             if name[-7:] == ".weight" and n_dims == 2:
-                print("  Converting to float16")
+                print(" to float16".rjust(15))
                 data = data.astype(np.float16)
                 ftype_cur = 1
             else:
-                print("  Converting to float32")
+                print(" to float32".rjust(15))
                 data = data.astype(np.float32)
                 ftype_cur = 0
         else:
             if data.dtype != np.float32:
-                print("  Converting to float32")
+                print(" to float32".rjust(15))
                 data = data.astype(np.float32)
                 ftype_cur = 0

-        # print(f"[{i+1:{padi}d}/{len(model)}]
-        # Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4}")
-        gguf_writer.add_tensor(name, data)
+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: %-15s" % str(data_0.shape))
+            print("Converting: %-75s" % name_1, " shape: %-15s" % str(data_1.shape))
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            gguf_writer.add_tensor(name_0, data_0)
+            gguf_writer.add_tensor(name_1, data_1)
+
     print("gguf: write header")
     gguf_writer.write_header_to_file()
     print("gguf: write metadata")
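Both the GGUF path above and the legacy .bin path below split every mlp.dense_h_to_4h weight row-wise into two equal halves, written out as dense_h_to_4h_0 and dense_h_to_4h_1; per the chatglm2.cpp changes further down, the first half feeds the SiLU (gate) branch and the second half the linear (up) branch of the MLP. A minimal NumPy sketch of that split (tensor names come from this patch; the [2*ffn_hidden_size, n_embd] layout is assumed for illustration):

import numpy as np

def split_dense_h_to_4h(data: np.ndarray):
    # Row-wise split into the two SwiGLU branches, mirroring the converter above.
    half = data.shape[0] // 2
    data_0 = data[:half, :]   # -> "...mlp.dense_h_to_4h_0.weight" (SiLU / gate half)
    data_1 = data[half:, :]   # -> "...mlp.dense_h_to_4h_1.weight" (linear / up half)
    return data_0, data_1

w = np.arange(24, dtype=np.float32).reshape(8, 3)  # toy [2*ffn, n_embd] weight
w0, w1 = split_dense_h_to_4h(w)
assert w0.shape == w1.shape == (4, 3)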
@@ -363,9 +397,9 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     fout.write(struct.pack("f", 10000.0))  # freq_base
     fout.write(struct.pack("f", 1.0))  # rope_factor
-    fout.write(struct.pack("f", 0.0))       # config.json "rope_scaling.factor", not enabled
-    fout.write(struct.pack("i", 0))         # rope_scaling.original_max_position_embeddings
-    fout.write(struct.pack("i", 0))         # params["rope_scaling"]["type"] =="yarn" else 0))
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))

     fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
     fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
@@ -419,10 +453,56 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
         for i in range(n_dims):
             fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
         fout.write(str)
-        # data
         data.tofile(fout)

+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
+            print("Converting: %-75s" % name_1, " shape: ", data_1.shape)
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            str_0 = name_0.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
+            fout.write(str_0)
+            data_0.tofile(fout)
+
+            str_1 = name_1.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
+            fout.write(str_1)
+            data_1.tofile(fout)
+
     fout.close()

     print("Done. Output file: " + fname_out)
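For reference, the per-tensor record the legacy .bin path emits is exactly what the struct.pack calls above spell out: three int32 values (n_dims, name length, dtype tag), the dims as int32 in reverse order, the UTF-8 name bytes, then the raw tensor data. A small sketch of that layout as a standalone helper (hypothetical name; the script inlines this logic):

import struct
import numpy as np

def write_ne_tensor(fout, name: str, data: np.ndarray, ftype_cur: int) -> None:
    # Header: n_dims, byte length of the name, and the dtype tag (0 = f32, 1 = f16).
    encoded = name.encode("utf-8")
    n_dims = len(data.shape)
    fout.write(struct.pack("iii", n_dims, len(encoded), ftype_cur))
    # Shape is stored innermost-first, i.e. dims in reverse order.
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(encoded)
    data.tofile(fout)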
diff --git a/neural_speed/models/chatglm/chatglm2.cpp b/neural_speed/models/chatglm/chatglm2.cpp
index 0501dc68a..001ef95ab 100644
--- a/neural_speed/models/chatglm/chatglm2.cpp
+++ b/neural_speed/models/chatglm/chatglm2.cpp
@@ -150,11 +150,11 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
     // self-attention
     cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
-    cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
+    cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
     {
       // compute QKV
       cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
-      cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], cur), cur);
+      cur = ne_add(ctx0, cur, model.layers[il].attn[1]);

       struct ne_tensor* query_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
@@ -298,19 +298,25 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
     // mlp.forward
     struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
-    ne_set_name(mlp_output, "mlp_output");
-    // mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
-    mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);
-
-    mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
-    struct ne_tensor* x0 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
-    x0 = ne_silu(ctx0, x0);
-    struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
-                                      mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
-    ne_set_name(x0, "x0");
-    ne_set_name(x1, "x1");
-    mlp_output = ne_mul(ctx0, x0, x1);
-    mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
+    mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
+
+    if (model.layers[il].ffn_fusion &&
+        bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data,
+                                              model.layers[il].ffn[3]->data, N, int(cur->ne[0] / 2),
+                                              model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])) {
+      mlp_output =
+          ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output);
+    } else {
+      // mlp.forward
+      mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
+      struct ne_tensor* x0 =
+          ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
+      x0 = ne_silu(ctx0, x0);
+      struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
+                                        mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
+      mlp_output = ne_mul(ctx0, x0, x1);
+      mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
+    }

 #ifdef NS_TP_MODEL
     if (enable_tp) {
@@ -327,9 +333,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
   // norm
   {
     inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
-    ne_set_name(inpL, "inpL");
-    // inpL = ne_mul(ctx0, inpL, model.others[1]);
-    inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL);
+    inpL = ne_mul(ctx0, inpL, model.others[1]);
   }

   lctx.use_buf(ctx0, -1);
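The runtime change above keeps the MLP math unchanged: with the packed dense_h_to_4h weight it views the projection in two halves, applies SiLU to the first and multiplies by the second; with the split dense_h_to_4h_0/_1 weights and a supported shape it calls the fused ne_ffn_silu kernel instead. A NumPy reference for what both branches compute (weight layouts are assumptions for illustration; the runtime uses ne_* ops or the bestla fused kernel, not NumPy):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def chatglm2_mlp_reference(x, w_gate, w_up, w_down):
    # SwiGLU-style MLP: silu(x W_gate^T) * (x W_up^T), followed by the down projection.
    return (silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

rng = np.random.default_rng(0)
n_embd, ffn = 4, 6
x = rng.standard_normal((2, n_embd)).astype(np.float32)
w_packed = rng.standard_normal((2 * ffn, n_embd)).astype(np.float32)  # dense_h_to_4h
w_gate, w_up = w_packed[:ffn], w_packed[ffn:]                         # the _0 / _1 halves
w_down = rng.standard_normal((n_embd, ffn)).astype(np.float32)        # dense_4h_to_h
y = chatglm2_mlp_reference(x, w_gate, w_up, w_down)
assert y.shape == (2, n_embd)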
diff --git a/neural_speed/models/chatglm/chatglm2.h b/neural_speed/models/chatglm/chatglm2.h
index c5d551f58..35db8175b 100644
--- a/neural_speed/models/chatglm/chatglm2.h
+++ b/neural_speed/models/chatglm/chatglm2.h
@@ -26,7 +26,7 @@ enum chatglm2_model {
 static const model_scratch chatglm_mem_req(int n_layers) {
   switch (n_layers) {
     case 28:
-      return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+      return {4096ull * MB, 4096ull * MB, 8192ull * MB};
     default:
       MODEL_ASSERT(false);
   }
diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp
index 58ca1e87e..3fd38d2f3 100644
--- a/neural_speed/models/chatglm/chatglm2_utils.cpp
+++ b/neural_speed/models/chatglm/chatglm2_utils.cpp
@@ -67,14 +67,13 @@ void CHATGLM2::init(const char* path_model, model_context* ctx, int n_gpu_layer_
   model.hparams = ml->file_loaders.at(0)->hparams;
   model_file_version file_version = ml->file_loaders.at(0)->file_version;
   auto& hparams = model.hparams;
-  n_ff = 4 * hparams.n_embd;
-  fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
-  fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
-  fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
-  fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
-  fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
-  fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
-  fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+  fprintf(stderr, "%s: hparams.n_vocab = %u\n", __func__, hparams.n_vocab);
+  fprintf(stderr, "%s: hparams.n_embd = %u\n", __func__, hparams.n_embd);
+  fprintf(stderr, "%s: hparams.n_mult = %u\n", __func__, hparams.n_mult);
+  fprintf(stderr, "%s: hparams.n_head = %u\n", __func__, hparams.n_head);
+  fprintf(stderr, "%s: hparams.n_layer = %u\n", __func__, hparams.n_layer);
+  fprintf(stderr, "%s: hparams.n_rot = %u\n", __func__, hparams.n_rot);
+  fprintf(stderr, "%s: hparams.ffn_hidden_size = %u\n", __func__, hparams.ffn_hidden_size);
   fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
@@ -149,6 +148,15 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
     layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight",
                                   {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);

+    if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight") &&
+        ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
+      layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight",
+                                    {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+      layer.ffn[3] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight",
+                                    {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+      layer.ffn_fusion = true;
+    }
+
     // kv-cache
     layer.k_cache = nullptr;  // kv-cache will be init later in model_utils
     layer.v_cache = nullptr;  // kv-cache will be init later in model_utils
@@ -161,6 +169,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
     }
   }

+  fprintf(stderr, "%s: layers[0].ffn_fusion = %u\n", __func__, model.layers[0].ffn_fusion);
   // print memory requirements
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index d438dac33..33e7df888 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -160,6 +160,8 @@ struct model_layer {

   struct ne_tensor* k_cache;
   struct ne_tensor* v_cache;
+
+  bool ffn_fusion = false;
 };

 typedef int32_t model_pos;
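The loader only flips ffn_fusion to true when both split weights are actually present in the converted file, so models produced by the old converter still load and simply take the unfused branch. The check below mirrors that verify_tensor pair on a plain set of tensor names (the layer-name prefix is an assumption for illustration):

def layer_uses_ffn_fusion(tensor_names: set, layer_idx: int) -> bool:
    # Same condition as the two verify_tensor calls in CHATGLM2::load above.
    prefix = f"transformer.encoder.layers.{layer_idx}"  # assumed ChatGLM2 naming
    return (f"{prefix}.mlp.dense_h_to_4h_0.weight" in tensor_names
            and f"{prefix}.mlp.dense_h_to_4h_1.weight" in tensor_names)

names = {"transformer.encoder.layers.0.mlp.dense_h_to_4h_0.weight",
         "transformer.encoder.layers.0.mlp.dense_h_to_4h_1.weight"}
assert layer_uses_ffn_fusion(names, 0)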
diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh
index e5e45b2da..973c09892 100644
--- a/tests/model-test/cpp_graph_inference.sh
+++ b/tests/model-test/cpp_graph_inference.sh
@@ -221,7 +221,7 @@ function main() {
         infer_cmd="./build/bin/run_dolly"
     elif [[ "${model}" == "chatglm2" ]]; then
         quant_script="./build/bin/quant_chatglm2"
-        convert_script="${convert_script}/convert_chatglm.py"
+        convert_script="${convert_script}/convert_chatglm.py --format=GGUF"
        infer_cmd="./build/bin/run_chatglm2"
     elif [[ "${model}" == "chatglm-6b" ]]; then
         quant_script="./build/bin/quant_chatglm"
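The test script now passes --format=GGUF to convert_chatglm.py, which presumably selects the chatglm2_convert_gguf writer instead of the legacy chatglm2_convert .bin path. A hypothetical sketch of that dispatch (the actual argument parsing lives elsewhere in convert_chatglm.py and is not part of this diff; the "NE" default is an assumption):

from argparse import ArgumentParser

def parse_format(argv=None) -> str:
    # "--format=GGUF" selects the GGUF writer; anything else keeps the .bin path.
    parser = ArgumentParser(description="ChatGLM conversion (sketch)")
    parser.add_argument("--format", choices=["NE", "GGUF"], default="NE")
    args, _ = parser.parse_known_args(argv)
    return args.format

# e.g. parse_format(["--format=GGUF"]) == "GGUF" would route to the GGUF converter.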