This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[Model] enable glm4-9b #291

Merged 9 commits on Jun 14, 2024. Changes shown below are from 3 commits.

3 changes: 2 additions & 1 deletion docs/supported_models.md
@@ -235,7 +235,8 @@ Neural Speed supports the following models:
     <tr>
       <td><a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank" rel="noopener noreferrer">ChatGLM-6B</a>,
         <a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a>,
-        <a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a></td>
+        <a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a>,
+        <a href="https://huggingface.co/THUDM/glm-4-9b" target="_blank" rel="noopener noreferrer">GLM-4-9B</a></td>
       <td>✅</td>
       <td> </td>
       <td> </td>
5 changes: 4 additions & 1 deletion neural_speed/__init__.py
@@ -95,7 +95,10 @@ def _get_model_type(model_config):
if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
# due to the same model architecture.
model_type = "chatglm2"

# For GLM4
if model_type == "chatglm" and "glm-4" in model_config._name_or_path:
# due to the same model architecture.
model_type = "chatglm2"
     # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
     if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
         model_type = "falcon"
146 changes: 145 additions & 1 deletion neural_speed/convert/convert_chatglm.py
@@ -524,6 +524,148 @@ def write_vocab_gguf(dir_model):
     print("Done. Output file: " + fname_out)
     print("")
 
+def chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
+    print("GLM-4 converting: ")
+    list_vars = model.state_dict()
+    for name in list_vars.keys():
+        print(name, list_vars[name].shape, list_vars[name].dtype)
+
+    fout = open(fname_out, "wb")
+
+    print(hparams)
+
+    fout.write(struct.pack("i", 0x67676d66))  # magic ("ggmf")
+    fout.write(struct.pack("i", 1))  # file version
+
+    fout.write(struct.pack("i", hparams["padded_vocab_size"]))
+    fout.write(struct.pack("i", hparams["hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_attention_heads"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_layers"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", ftype))
+    fout.write(struct.pack("i", hparams["seq_length"]))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("i", 0))
+
+    fout.write(struct.pack("i", 0))  # word_embed_proj_dim (for opt)
+    fout.write(struct.pack("i", 0))  # do_layer_norm_before (for opt)
+
+    fout.write(struct.pack("i", hparams["multi_query_group_num"]))
+    fout.write(struct.pack("i", hparams["ffn_hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", 0))  # n_experts
+    fout.write(struct.pack("i", 0))  # n_expert_used
+    fout.write(struct.pack("i", 0))  # n_embd_head_k for gemma
+    fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5)))  # rms_norm_eps or layer_norm_eps
+    fout.write(struct.pack("f", 10000.0))  # freq_base
+    fout.write(struct.pack("f", hparams.get("rope_ratio", 1)))  # rope_factor
+
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # 1 if params["rope_scaling"]["type"] == "yarn" else 0, not enabled
+
+    fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+    fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+    fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+    fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+    # vocab: token bytes plus a pseudo-score; ids past tokenizer.vocab_size are padding
+    for i in range(hparams["vocab_size"]):
+        if i < tokenizer.vocab_size:
+            text = tokenizer.decode([i]).encode('utf-8')
+            fout.write(struct.pack("i", len(text)))
+            fout.write(text)
+            fout.write(struct.pack("f", 0.0 - i))
+        else:
+            text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+            fout.write(struct.pack("i", len(text)))
+            fout.write(text)
+            fout.write(struct.pack("f", -10000))

+    for name in list_vars.keys():
+        data = list_vars[name].float().squeeze().numpy()
+        data = data.astype(np.float32)
+        # No gradients for these
+        if name == "transformer.rotary_pos_emb.inv_freq":
+            continue
+
+        n_dims = len(data.shape)
+        print(name, n_dims, data.shape)
+
+        # default type is fp32
+        ftype_cur = 0
+        if ftype == 1 and n_dims > 1:
+            print(" Converting to float16", data.shape, data[:3, :3].tolist())
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist())
+            data = data.astype(np.float32)
+
+        # tensor header: n_dims, name length, dtype, then dims in reverse order
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        print(str)
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
if "mlp.dense_h_to_4h" in name:
name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
shape_0 = data.shape[0]
half_shape_0 = int(shape_0 / 2)
data_0 = data[0:half_shape_0, :]
data_1 = data[half_shape_0:shape_0, :]

print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
print("Converting: %-75s" % name_1, " shape: ", data_1.shape)

n_dims = len(data_0.shape)
assert (len(data_0.shape) == len(data_1.shape))
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name_0[-7:] == ".weight" and n_dims == 2:
print(" to float16".rjust(15))
data_0 = data_0.astype(np.float16)
data_1 = data_1.astype(np.float32)
ftype_cur = 1
else:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0
else:
if data_0.dtype != np.float32:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0

str_0 = name_0.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
fout.write(str_0)
data_0.tofile(fout)

str_1 = name_1.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
fout.write(str_1)
data_1.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")

 def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     print("ChatGLM-3 converting: ")
@@ -973,7 +1115,9 @@ def main(args_in: Optional[List[str]] = None) -> None:
     # ChatGLM3 shares the same architecture and model config with ChatGLM2
     # but its tokenizer further supports system prompts,
     # so we can check system token to discriminate ChatGLM3 from ChatGLM2.
-    if hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
+    if hasattr(model.config, "rope_ratio"):
+        chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
+    elif hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
         if args.format == "GGUF":
             chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
         else:
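
The dispatch above keys off rope_ratio: GLM-4 configs define it while ChatGLM2/ChatGLM3 configs do not, so that check runs before the ChatGLM3 system-token test. A minimal sketch of the same ordering, factored into a standalone helper (the helper name and arguments are illustrative, not part of this PR):

def pick_converter(config, special_tokens):
    """Mirror of main()'s converter dispatch for ChatGLM-family models."""
    if hasattr(config, "rope_ratio"):      # present only in GLM-4 configs
        return "chatglm4_convert"
    if "<|system|>" in special_tokens:     # ChatGLM3 supports system prompts
        return "chatglm3_convert"
    return "chatglm2_convert"

assert pick_converter(type("C", (), {"rope_ratio": 1})(), set()) == "chatglm4_convert"
assert pick_converter(type("C", (), {})(), {"<|system|>"}) == "chatglm3_convert"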
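
One detail worth calling out in chatglm4_convert: the fused MLP projection dense_h_to_4h is written twice, once whole and once split row-wise into dense_h_to_4h_0 / dense_h_to_4h_1, the two halves the chatglm2 kernels load. A toy numpy sketch of that split (shapes are made up for illustration):

import numpy as np

# Toy stand-in for a fused [2 * ffn_hidden, hidden] weight matrix.
ffn_hidden, hidden = 4, 3
fused = np.arange(2 * ffn_hidden * hidden, dtype=np.float32).reshape(2 * ffn_hidden, hidden)

half = fused.shape[0] // 2
data_0, data_1 = fused[:half, :], fused[half:, :]  # row-wise halves

assert data_0.shape == data_1.shape == (ffn_hidden, hidden)
# Both halves must be cast to the same dtype so the ftype flag written in
# each tensor header stays accurate; that is the constraint behind the
# float16 fix in the split branch above.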
6 changes: 6 additions & 0 deletions neural_speed/models/chatglm/chatglm2.h
@@ -31,6 +31,12 @@ static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_ratio)
           static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
           static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
       };
+    case 40:
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
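
For scale: GLM-4-9B has 40 transformer layers, which is what the new case 40 entry serves. A rough sketch of the sizes it reserves, assuming MB is the usual 1024 * 1024 constant used elsewhere in neural_speed:

# Hypothetical re-computation of the case-40 scratch sizes in Python.
MB = 1024 * 1024
scratch_size_ratio = 1.0  # default ratio; callers may scale it up

sizes = [int(scratch_size_ratio * n) * MB for n in (4096, 2048, 4096)]
print([s / 2**30 for s in sizes])  # -> [4.0, 2.0, 4.0] GiB at ratio 1.0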