Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[FFN Fusion] Support FFN_fusion for the ChatGLM2 #142

Merged
merged 11 commits into from
Mar 4, 2024
108 changes: 94 additions & 14 deletions neural_speed/convert/convert_chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def bytes_to_unicode():


class SentencePieceVocab:

def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
Expand Down Expand Up @@ -149,11 +150,11 @@ def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams
print("ChatGLM-2.gguf converting: ")
list_vars = model.state_dict()
for name in list_vars.keys():
print(name, list_vars[name].shape, list_vars[name].dtype)
print("%-80s" % name, list_vars[name].shape, list_vars[name].dtype)

print(hparams)

gguf_file = fname_out + '.gguf'
gguf_file = fname_out
gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm2")

arch = "chatglm2."
Expand Down Expand Up @@ -285,35 +286,68 @@ def write_vocab_gguf(dir_model):
print("gguf: get tensor metadata")
for name in list_vars.keys():
data = list_vars[name].squeeze().numpy()

print("Processing variable: " + name + " with shape: ", data.shape)
if 'inv_freq' in name:
print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape))
continue

print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape), end=" ")
n_dims = len(data.shape)

# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
print(" to float16".rjust(15))
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
print(" to float32".rjust(15))
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
print(" Converting to float32")
print(" to float32".rjust(15))
data = data.astype(np.float32)
ftype_cur = 0

# print(f"[{i+1:{padi}d}/{len(model)}]
# Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4}")

gguf_writer.add_tensor(name, data)

if "mlp.dense_h_to_4h" in name:
name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
shape_0 = data.shape[0]
half_shape_0 = int(shape_0 / 2)
data_0 = data[0:half_shape_0, :]
data_1 = data[half_shape_0:shape_0, :]

print("Converting: %-75s" % name_0, " shape: %-15s" % str(data_0.shape))
print("Converting: %-75s" % name_1, " shape: %-15s" % str(data_1.shape))

n_dims = len(data_0.shape)
assert (len(data_0.shape) == len(data_1.shape))
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name_0[-7:] == ".weight" and n_dims == 2:
print(" to float16".rjust(15))
data_0 = data_0.astype(np.float16)
data_1 = data_1.astype(np.float32)
ftype_cur = 1
else:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0
else:
if data_0.dtype != np.float32:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0

gguf_writer.add_tensor(name_0, data_0)
gguf_writer.add_tensor(name_1, data_1)

print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
Expand Down Expand Up @@ -363,9 +397,9 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))

fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
Expand Down Expand Up @@ -419,10 +453,56 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str)

# data
data.tofile(fout)

if "mlp.dense_h_to_4h" in name:
name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
shape_0 = data.shape[0]
half_shape_0 = int(shape_0 / 2)
data_0 = data[0:half_shape_0, :]
data_1 = data[half_shape_0:shape_0, :]

print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
print("Converting: %-75s" % name_1, " shape: ", data_1.shape)

n_dims = len(data_0.shape)
assert (len(data_0.shape) == len(data_1.shape))
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name_0[-7:] == ".weight" and n_dims == 2:
print(" to float16".rjust(15))
data_0 = data_0.astype(np.float16)
data_1 = data_1.astype(np.float32)
ftype_cur = 1
else:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0
else:
if data_0.dtype != np.float32:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0

str_0 = name_0.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
fout.write(str_0)
data_0.tofile(fout)

str_1 = name_1.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
fout.write(str_1)
data_1.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
Expand Down
40 changes: 22 additions & 18 deletions neural_speed/models/chatglm/chatglm2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i

// self-attention
cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
{
// compute QKV
cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], cur), cur);
cur = ne_add(ctx0, cur, model.layers[il].attn[1]);

struct ne_tensor* query_layer =
ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
Expand Down Expand Up @@ -298,19 +298,25 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i

// mlp.forward
struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
ne_set_name(mlp_output, "mlp_output");
// mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);

mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
struct ne_tensor* x0 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
x0 = ne_silu(ctx0, x0);
struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
ne_set_name(x0, "x0");
ne_set_name(x1, "x1");
mlp_output = ne_mul(ctx0, x0, x1);
mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);

if (model.layers[il].ffn_fusion &&
bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data,
model.layers[il].ffn[3]->data, N, int(cur->ne[0] / 2),
model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])) {
mlp_output =
ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output);
} else {
// mlp.forward
mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
struct ne_tensor* x0 =
ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
x0 = ne_silu(ctx0, x0);
struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
mlp_output = ne_mul(ctx0, x0, x1);
mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
}

#ifdef NS_TP_MODEL
if (enable_tp) {
Expand All @@ -327,9 +333,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
// norm
{
inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
ne_set_name(inpL, "inpL");
// inpL = ne_mul(ctx0, inpL, model.others[1]);
inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL);
inpL = ne_mul(ctx0, inpL, model.others[1]);
}

lctx.use_buf(ctx0, -1);
Expand Down
2 changes: 1 addition & 1 deletion neural_speed/models/chatglm/chatglm2.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ enum chatglm2_model {
static const model_scratch chatglm_mem_req(int n_layers) {
switch (n_layers) {
case 28:
return {2048ull * MB, 2048ull * MB, 4096ull * MB};
return {4096ull * MB, 4096ull * MB, 8192ull * MB};
default:
MODEL_ASSERT(false);
}
Expand Down
25 changes: 17 additions & 8 deletions neural_speed/models/chatglm/chatglm2_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,13 @@ void CHATGLM2::init(const char* path_model, model_context* ctx, int n_gpu_layer_
model.hparams = ml->file_loaders.at(0)->hparams;
model_file_version file_version = ml->file_loaders.at(0)->file_version;
auto& hparams = model.hparams;
n_ff = 4 * hparams.n_embd;
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: hparams.n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: hparams.n_embd = %u\n", __func__, hparams.n_embd);
fprintf(stderr, "%s: hparams.n_mult = %u\n", __func__, hparams.n_mult);
fprintf(stderr, "%s: hparams.n_head = %u\n", __func__, hparams.n_head);
fprintf(stderr, "%s: hparams.n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: hparams.n_rot = %u\n", __func__, hparams.n_rot);
fprintf(stderr, "%s: hparams.ffn_hidden_size = %u\n", __func__, hparams.ffn_hidden_size);
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
Expand Down Expand Up @@ -149,6 +148,15 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
layer.ffn[1] =
ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);

if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight") &&
ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight",
{n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
layer.ffn[3] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight",
{n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
layer.ffn_fusion = true;
}

// kv-cache
layer.k_cache = nullptr; // kv-cache will be init later in model_utils
layer.v_cache = nullptr; // kv-cache will be init later in model_utils
Expand All @@ -161,6 +169,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
}
}

fprintf(stderr, "%s: layers[0].ffn_fusion = %u\n", __func__, model.layers[0].ffn_fusion);
// print memory requirements
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
Expand Down
2 changes: 2 additions & 0 deletions neural_speed/models/model_utils/model_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ struct model_layer {

struct ne_tensor* k_cache;
struct ne_tensor* v_cache;

bool ffn_fusion = false;
};

typedef int32_t model_pos;
Expand Down
2 changes: 1 addition & 1 deletion tests/model-test/cpp_graph_inference.sh
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ function main() {
infer_cmd="./build/bin/run_dolly"
elif [[ "${model}" == "chatglm2" ]]; then
quant_script="./build/bin/quant_chatglm2"
convert_script="${convert_script}/convert_chatglm.py"
convert_script="${convert_script}/convert_chatglm.py --format=GGUF"
infer_cmd="./build/bin/run_chatglm2"
elif [[ "${model}" == "chatglm-6b" ]]; then
quant_script="./build/bin/quant_chatglm"
Expand Down
Loading