From cff66465edbd691898be9bb7a20ceec674089c64 Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Mon, 26 Feb 2024 20:05:26 -0800
Subject: [PATCH 01/10] convert done

---
 neural_speed/convert/convert_chatglm.py | 54 +++++++++++++++++++------
 1 file changed, 42 insertions(+), 12 deletions(-)

diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py
index 2e988139d..0aa62ce6f 100644
--- a/neural_speed/convert/convert_chatglm.py
+++ b/neural_speed/convert/convert_chatglm.py
@@ -50,6 +50,7 @@ def bytes_to_unicode():
 
 
 class SentencePieceVocab:
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
@@ -149,7 +150,7 @@ def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams
     print("ChatGLM-2.gguf converting: ")
     list_vars = model.state_dict()
     for name in list_vars.keys():
-        print(name, list_vars[name].shape, list_vars[name].dtype)
+        print("%-80s" % name, list_vars[name].shape, list_vars[name].dtype)
 
     print(hparams)
 
@@ -285,9 +286,41 @@ def write_vocab_gguf(dir_model):
     print("gguf: get tensor metadata")
     for name in list_vars.keys():
         data = list_vars[name].squeeze().numpy()
-
-        print("Processing variable: " + name + " with shape: ", data.shape)
         if 'inv_freq' in name:
+            print("Converting: %-80s" % name, " shape: %-20s" % str(data.shape))
+            continue
+
+        print("Converting: %-80s" % name, " shape: %-20s" % str(data.shape), end=" ")
+        if "mlp.dense_h_to_4h" in name:
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+            n_dims = len(data.shape)
+
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data = data.astype(np.float16)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data = data.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data = data.astype(np.float32)
+                    ftype_cur = 0
+
+            name_0 = name + "_0"
+            name_1 = name + "_1"
+            gguf_writer.add_tensor(name_0, data_0)
+            gguf_writer.add_tensor(name_1, data_1)
+            print("Converting: %-80s" % name_0, " shape: %-20s" % str(data_0.shape))
+            print("Converting: %-80s" % name_1, " shape: %-20s" % str(data_1.shape))
             continue
 
         n_dims = len(data.shape)
@@ -296,22 +329,19 @@ def write_vocab_gguf(dir_model):
         ftype_cur = 0
         if ftype != 0:
             if name[-7:] == ".weight" and n_dims == 2:
-                print(" Converting to float16")
+                print(" to float16".rjust(15))
                 data = data.astype(np.float16)
                 ftype_cur = 1
             else:
-                print(" Converting to float32")
+                print(" to float32".rjust(15))
                 data = data.astype(np.float32)
                 ftype_cur = 0
         else:
             if data.dtype != np.float32:
-                print(" Converting to float32")
+                print(" to float32".rjust(15))
                 data = data.astype(np.float32)
                 ftype_cur = 0
 
-        # print(f"[{i+1:{padi}d}/{len(model)}]
-        # Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4}")
-
         gguf_writer.add_tensor(name, data)
 
     print("gguf: write header")
@@ -363,9 +393,9 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     fout.write(struct.pack("f", 10000.0))  # freq_base
     fout.write(struct.pack("f", 1.0))  # rope_factor
 
-    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
-    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) + fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2)) From ae963e7d79a9d199df89d1098e6e3ad8296e49ee Mon Sep 17 00:00:00 2001 From: Zhenzhong1 Date: Mon, 26 Feb 2024 20:14:03 -0800 Subject: [PATCH 02/10] fixed an issue, missing original tensor --- neural_speed/convert/convert_chatglm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py index 0aa62ce6f..24a7b94bf 100644 --- a/neural_speed/convert/convert_chatglm.py +++ b/neural_speed/convert/convert_chatglm.py @@ -287,10 +287,10 @@ def write_vocab_gguf(dir_model): for name in list_vars.keys(): data = list_vars[name].squeeze().numpy() if 'inv_freq' in name: - print("Converting: %-80s" % name, " shape: %-20s" % str(data.shape)) + print("Converting: %-80s" % name, " shape: %-15s" % str(data.shape)) continue - print("Converting: %-80s" % name, " shape: %-20s" % str(data.shape), end=" ") + print("Converting: %-80s" % name, " shape: %-15s" % str(data.shape), end=" ") if "mlp.dense_h_to_4h" in name: shape_0 = data.shape[0] half_shape_0 = int(shape_0 / 2) @@ -315,12 +315,13 @@ def write_vocab_gguf(dir_model): data = data.astype(np.float32) ftype_cur = 0 + gguf_writer.add_tensor(name, data) name_0 = name + "_0" name_1 = name + "_1" gguf_writer.add_tensor(name_0, data_0) gguf_writer.add_tensor(name_1, data_1) - print("Converting: %-80s" % name_0, " shape: %-20s" % str(data_0.shape)) - print("Converting: %-80s" % name_1, " shape: %-20s" % str(data_1.shape)) + print("Converting: %-80s" % name_0, " shape: %-15s" % str(data_0.shape)) + print("Converting: %-80s" % name_1, " shape: %-15s" % str(data_1.shape)) continue n_dims = len(data.shape) From 21e917f62cafc59f5f583cca95dd461722d1bb56 Mon Sep 17 00:00:00 2001 From: Zhenzhong1 Date: Mon, 26 Feb 2024 20:54:35 -0800 Subject: [PATCH 03/10] inference pass --- neural_speed/convert/convert_chatglm.py | 73 ++++++++++--------- .../models/chatglm/chatglm2_utils.cpp | 25 +++++-- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py index 24a7b94bf..ab4416687 100644 --- a/neural_speed/convert/convert_chatglm.py +++ b/neural_speed/convert/convert_chatglm.py @@ -287,43 +287,10 @@ def write_vocab_gguf(dir_model): for name in list_vars.keys(): data = list_vars[name].squeeze().numpy() if 'inv_freq' in name: - print("Converting: %-80s" % name, " shape: %-15s" % str(data.shape)) - continue - - print("Converting: %-80s" % name, " shape: %-15s" % str(data.shape), end=" ") - if "mlp.dense_h_to_4h" in name: - shape_0 = data.shape[0] - half_shape_0 = int(shape_0 / 2) - data_0 = data[0:half_shape_0, :] - data_1 = data[half_shape_0:shape_0, :] - n_dims = len(data.shape) - - # ftype == 0 -> float32, ftype == 1 -> float16 - ftype_cur = 0 - if ftype != 0: - if name[-7:] == ".weight" and n_dims == 2: - print(" to float16".rjust(15)) - data = data.astype(np.float16) - ftype_cur = 1 - else: - print(" to float32".rjust(15)) - data = data.astype(np.float32) - ftype_cur = 0 - 
-            else:
-                if data.dtype != np.float32:
-                    print(" to float32".rjust(15))
-                    data = data.astype(np.float32)
-                    ftype_cur = 0
-
-            gguf_writer.add_tensor(name, data)
-            name_0 = name + "_0"
-            name_1 = name + "_1"
-            gguf_writer.add_tensor(name_0, data_0)
-            gguf_writer.add_tensor(name_1, data_1)
-            print("Converting: %-80s" % name_0, " shape: %-15s" % str(data_0.shape))
-            print("Converting: %-80s" % name_1, " shape: %-15s" % str(data_1.shape))
+            print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape))
             continue
 
+        print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape), end=" ")
         n_dims = len(data.shape)
 
         # ftype == 0 -> float32, ftype == 1 -> float16
@@ -345,6 +312,42 @@ def write_vocab_gguf(dir_model):
 
         gguf_writer.add_tensor(name, data)
 
+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: %-15s" % str(data_0.shape))
+            print("Converting: %-75s" % name_1, " shape: %-15s" % str(data_1.shape))
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float16)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            gguf_writer.add_tensor(name_0, data_0)
+            gguf_writer.add_tensor(name_1, data_1)
+
     print("gguf: write header")
     gguf_writer.write_header_to_file()
     print("gguf: write metadata")

diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp
index 58ca1e87e..f34409555 100644
--- a/neural_speed/models/chatglm/chatglm2_utils.cpp
+++ b/neural_speed/models/chatglm/chatglm2_utils.cpp
@@ -67,14 +67,13 @@ void CHATGLM2::init(const char* path_model, model_context* ctx, int n_gpu_layer_
   model.hparams = ml->file_loaders.at(0)->hparams;
   model_file_version file_version = ml->file_loaders.at(0)->file_version;
   auto& hparams = model.hparams;
-  n_ff = 4 * hparams.n_embd;
-  fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
-  fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
-  fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
-  fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
-  fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
-  fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
-  fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+  fprintf(stderr, "%s: hparams.n_vocab = %u\n", __func__, hparams.n_vocab);
+  fprintf(stderr, "%s: hparams.n_embd = %u\n", __func__, hparams.n_embd);
+  fprintf(stderr, "%s: hparams.n_mult = %u\n", __func__, hparams.n_mult);
+  fprintf(stderr, "%s: hparams.n_head = %u\n", __func__, hparams.n_head);
+  fprintf(stderr, "%s: hparams.n_layer = %u\n", __func__, hparams.n_layer);
+  fprintf(stderr, "%s: hparams.n_rot = %u\n", __func__, hparams.n_rot);
+  fprintf(stderr, "%s: hparams.ffn_hidden_size = %u\n", __func__, hparams.ffn_hidden_size);
   fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
@@ -149,6 +148,16 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
       layer.ffn[1] =
           ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);
 
+      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight")) {
+        layer.ffn[2] =
+            ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight",{n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+      }
+
+      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
+        layer.ffn[3] =
+            ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+      }
+
       // kv-cache
       layer.k_cache = nullptr;  // kv-cache will be init later in model_utils
       layer.v_cache = nullptr;  // kv-cache will be init later in model_utils

From 00a8b864d53cb108b8c102b4a52684c4ef0858df Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Mon, 26 Feb 2024 23:34:37 -0800
Subject: [PATCH 04/10] accuracy pass

---
 neural_speed/models/chatglm/chatglm2.cpp | 25 +++++++++++--------
 .../models/chatglm/chatglm2_utils.cpp    |  7 ++----
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/neural_speed/models/chatglm/chatglm2.cpp b/neural_speed/models/chatglm/chatglm2.cpp
index 0501dc68a..b3e1013c9 100644
--- a/neural_speed/models/chatglm/chatglm2.cpp
+++ b/neural_speed/models/chatglm/chatglm2.cpp
@@ -298,19 +298,22 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
 
     // mlp.forward
     struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
-    ne_set_name(mlp_output, "mlp_output");
-    // mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
     mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);
-    mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
 
-    struct ne_tensor* x0 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
-    x0 = ne_silu(ctx0, x0);
-    struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
-                                      mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
-    ne_set_name(x0, "x0");
-    ne_set_name(x1, "x1");
-    mlp_output = ne_mul(ctx0, x0, x1);
-    mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
+    if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data,
+                                              model.layers[il].ffn[3]->data, N, int(cur->ne[0] / 2),
+                                              model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])){
+      mlp_output = ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output);
+    } else {
+      // mlp.forward
+      mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
+      struct ne_tensor* x0 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
+      x0 = ne_silu(ctx0, x0);
+      struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
+                                        mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
+      mlp_output = ne_mul(ctx0, x0, x1);
+      mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
+    }
 
 #ifdef NS_TP_MODEL
     if (enable_tp) {

diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp
index f34409555..645a4e7e7 100644
--- a/neural_speed/models/chatglm/chatglm2_utils.cpp
+++ b/neural_speed/models/chatglm/chatglm2_utils.cpp
@@ -148,12 +148,9 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
       layer.ffn[1] =
           ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);
 
-      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight")) {
+      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight") && ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
         layer.ffn[2] =
-            ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight",{n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
-      }
-
-      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
+            ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
         layer.ffn[3] =
             ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
       }

From db8409e10392c5f0cee5958a5e4665333b56ee3e Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Tue, 27 Feb 2024 00:16:13 -0800
Subject: [PATCH 05/10] compatibility enhancement

---
 neural_speed/models/chatglm/chatglm2.cpp       | 5 +++--
 neural_speed/models/chatglm/chatglm2_utils.cpp | 1 +
 neural_speed/models/model_utils/model_types.h  | 2 ++
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/neural_speed/models/chatglm/chatglm2.cpp b/neural_speed/models/chatglm/chatglm2.cpp
index b3e1013c9..cd0e180b3 100644
--- a/neural_speed/models/chatglm/chatglm2.cpp
+++ b/neural_speed/models/chatglm/chatglm2.cpp
@@ -300,8 +300,9 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
     struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
     mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);
 
-    if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data,
-                                              model.layers[il].ffn[3]->data, N, int(cur->ne[0] / 2),
+    if (model.layers[il].ffn_fusion && bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data,
+                                           model.layers[il].ffn[1]->data, model.layers[il].ffn[3]->data,
+                                           N, int(cur->ne[0] / 2),
                                               model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])){
       mlp_output = ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output);
     } else {

diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp
index 645a4e7e7..655dacfde 100644
--- a/neural_speed/models/chatglm/chatglm2_utils.cpp
+++ b/neural_speed/models/chatglm/chatglm2_utils.cpp
@@ -153,6 +153,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
             ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
         layer.ffn[3] =
             ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+        layer.ffn_fusion = true;
       }
 
       // kv-cache

diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index d438dac33..33e7df888 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -160,6 +160,8 @@ struct model_layer {
 
   struct ne_tensor* k_cache;
   struct ne_tensor* v_cache;
+
+  bool ffn_fusion = false;
 };
 
 typedef int32_t model_pos;

From 8f41d05c281c856f70b10af75335b870b4e668a5 Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Tue, 27 Feb 2024 03:12:16 -0800
Subject: [PATCH 06/10] add --format=GGUF to extension tests
---
 tests/model-test/cpp_graph_inference.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh
index e5e45b2da..973c09892 100644
--- a/tests/model-test/cpp_graph_inference.sh
+++ b/tests/model-test/cpp_graph_inference.sh
@@ -221,7 +221,7 @@ function main() {
         infer_cmd="./build/bin/run_dolly"
     elif [[ "${model}" == "chatglm2" ]]; then
         quant_script="./build/bin/quant_chatglm2"
-        convert_script="${convert_script}/convert_chatglm.py"
+        convert_script="${convert_script}/convert_chatglm.py --format=GGUF"
         infer_cmd="./build/bin/run_chatglm2"
     elif [[ "${model}" == "chatglm-6b" ]]; then
        quant_script="./build/bin/quant_chatglm"

From 867369b78637d89830aeb18b29faff8ee8df7407 Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Tue, 27 Feb 2024 19:51:10 -0800
Subject: [PATCH 07/10] split tensors for the ne bin

---
 neural_speed/convert/convert_chatglm.py | 48 ++++++++++++++++++-
 .../models/chatglm/chatglm2_utils.cpp   |  1 +
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py
index ab4416687..57ef0e5fe 100644
--- a/neural_speed/convert/convert_chatglm.py
+++ b/neural_speed/convert/convert_chatglm.py
@@ -453,10 +453,56 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
         for i in range(n_dims):
            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
         fout.write(str)
-
         # data
         data.tofile(fout)
 
+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
+            print("Converting: %-75s" % name_1, " shape: ", data_1.shape)
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float16)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            str_0 = name_0.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
+            fout.write(str_0)
+            data_0.tofile(fout)
+
+            str_1 = name_1.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
+            fout.write(str_1)
+            data_1.tofile(fout)
+
     fout.close()
     print("Done. Output file: " + fname_out)
Output file: " + fname_out) diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp index 655dacfde..018719aba 100644 --- a/neural_speed/models/chatglm/chatglm2_utils.cpp +++ b/neural_speed/models/chatglm/chatglm2_utils.cpp @@ -168,6 +168,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac } } + fprintf(stderr, "%s: layers[0].ffn_fusion = %u\n", __func__, model.layers[0].ffn_fusion); // print memory requirements // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory From 3264d2465510ba13f869fc4db7110796a59ce820 Mon Sep 17 00:00:00 2001 From: Zhenzhong1 Date: Wed, 28 Feb 2024 18:09:08 -0800 Subject: [PATCH 08/10] remove extended name & clang-format --- neural_speed/convert/convert_chatglm.py | 2 +- neural_speed/models/chatglm/chatglm2.cpp | 14 ++++++++------ neural_speed/models/chatglm/chatglm2_utils.cpp | 11 ++++++----- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py index 57ef0e5fe..92bb8734a 100644 --- a/neural_speed/convert/convert_chatglm.py +++ b/neural_speed/convert/convert_chatglm.py @@ -154,7 +154,7 @@ def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams print(hparams) - gguf_file = fname_out + '.gguf' + gguf_file = fname_out gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm2") arch = "chatglm2." diff --git a/neural_speed/models/chatglm/chatglm2.cpp b/neural_speed/models/chatglm/chatglm2.cpp index cd0e180b3..00ee3e57e 100644 --- a/neural_speed/models/chatglm/chatglm2.cpp +++ b/neural_speed/models/chatglm/chatglm2.cpp @@ -300,15 +300,17 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps); mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output); - if (model.layers[il].ffn_fusion && bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, - model.layers[il].ffn[1]->data, model.layers[il].ffn[3]->data, - N, int(cur->ne[0] / 2), - model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])){ - mlp_output = ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output); + if (model.layers[il].ffn_fusion && + bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data, + model.layers[il].ffn[3]->data, N, int(cur->ne[0] / 2), + model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])) { + mlp_output = + ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output); } else { // mlp.forward mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output); - struct ne_tensor* x0 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0); + struct ne_tensor* x0 = + ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0); x0 = ne_silu(ctx0, x0); struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], mlp_output->ne[0] / 2 * ne_element_size(mlp_output)); diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp index 018719aba..3fd38d2f3 100644 --- a/neural_speed/models/chatglm/chatglm2_utils.cpp +++ 
@@ -148,11 +148,12 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
       layer.ffn[1] =
           ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);
 
-      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight") && ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
-        layer.ffn[2] =
-            ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
-        layer.ffn[3] =
-            ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight", {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight") &&
+          ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
+        layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight",
+                                      {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+        layer.ffn[3] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight",
+                                      {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
         layer.ffn_fusion = true;
       }

From 6ebad2f7d950d374eb458fe5db5b7ffa8504dd8e Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Thu, 29 Feb 2024 22:25:21 -0800
Subject: [PATCH 09/10] extend mem

---
 neural_speed/models/chatglm/chatglm2.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neural_speed/models/chatglm/chatglm2.h b/neural_speed/models/chatglm/chatglm2.h
index c5d551f58..35db8175b 100644
--- a/neural_speed/models/chatglm/chatglm2.h
+++ b/neural_speed/models/chatglm/chatglm2.h
@@ -26,7 +26,7 @@ enum chatglm2_model {
 static const model_scratch chatglm_mem_req(int n_layers) {
   switch (n_layers) {
     case 28:
-      return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+      return {4096ull * MB, 4096ull * MB, 8192ull * MB};
     default:
       MODEL_ASSERT(false);
   }

From 30d9d9d8cdd137d876eaf9e34a29fe4ac3840664 Mon Sep 17 00:00:00 2001
From: Zhenzhong1
Date: Thu, 29 Feb 2024 22:53:52 -0800
Subject: [PATCH 10/10] remove ne_repeat

---
 neural_speed/models/chatglm/chatglm2.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/neural_speed/models/chatglm/chatglm2.cpp b/neural_speed/models/chatglm/chatglm2.cpp
index 00ee3e57e..001ef95ab 100644
--- a/neural_speed/models/chatglm/chatglm2.cpp
+++ b/neural_speed/models/chatglm/chatglm2.cpp
@@ -150,11 +150,11 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
 
     // self-attention
     cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
-    cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
+    cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
     {
       // compute QKV
       cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
-      cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], cur), cur);
+      cur = ne_add(ctx0, cur, model.layers[il].attn[1]);
 
       struct ne_tensor* query_layer =
          ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
@@ -298,7 +298,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
 
     // mlp.forward
    struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
-    mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);
+    mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
 
     if (model.layers[il].ffn_fusion &&
         bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data,
@@ -333,9 +333,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
   // norm
   {
     inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
-    ne_set_name(inpL, "inpL");
-    // inpL = ne_mul(ctx0, inpL, model.others[1]);
-    inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL);
+    inpL = ne_mul(ctx0, inpL, model.others[1]);
   }
 
   lctx.use_buf(ctx0, -1);
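
Reviewer note, appended for clarity (not one of the ten patches above): the series splits the fused mlp.dense_h_to_4h weight along dim 0 into dense_h_to_4h_0 and dense_h_to_4h_1 so that chatglm2.cpp can pass the two halves directly to ne_ffn_silu whenever bestla_fusion_FFN_SiLu_f32f32_support accepts their shapes. Below is a minimal NumPy sketch of why that split is lossless; the sizes are made up and only the identity matters: silu of the first half of x @ W^T times the second half equals silu(x @ W0^T) * (x @ W1^T) when W0 and W1 are the top and bottom halves of W.

    import numpy as np

    def silu(v):
        return v / (1.0 + np.exp(-v))

    rng = np.random.default_rng(0)
    n_embd, ffn_hidden = 8, 6
    w = rng.standard_normal((2 * ffn_hidden, n_embd)).astype(np.float32)  # fused dense_h_to_4h weight
    x = rng.standard_normal((3, n_embd)).astype(np.float32)               # a small batch of activations

    # Fused path: one matmul, then view/split the activation in half
    # (the else-branch kept in chatglm2.cpp as a fallback).
    h = x @ w.T
    fused = silu(h[:, :ffn_hidden]) * h[:, ffn_hidden:]

    # Split path: dense_h_to_4h_0 / dense_h_to_4h_1, as the converter now writes them.
    w0, w1 = w[:ffn_hidden, :], w[ffn_hidden:, :]
    split = silu(x @ w0.T) * (x @ w1.T)

    assert np.allclose(fused, split, atol=1e-6)

Because the dim-0 split commutes with the matmul, loading ffn[2]/ffn[3] from the split tensors gives bit-identical MLP output to the original fused weight, which is why the ffn_fusion flag can be flipped on only when both split tensors are present in the file.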