Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[Model] enable glm4-9b #291

Merged
merged 9 commits on Jun 14, 2024
3 changes: 2 additions & 1 deletion docs/supported_models.md
@@ -235,7 +235,8 @@ Neural Speed supports the following models:
<tr>
<td><a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank" rel="noopener noreferrer">ChatGLM-6B</a>,
<a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a>,
<a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a></td>
<a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a>,
<a href="https://huggingface.co/THUDM/glm-4-9b" target="_blank" rel="noopener noreferrer">GLM-4-9B</a></td>
<td>✅</td>
<td> </td>
<td> </td>
5 changes: 4 additions & 1 deletion neural_speed/__init__.py
@@ -95,7 +95,10 @@ def _get_model_type(model_config):
if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
# due to the same model architecture.
model_type = "chatglm2"

# For GLM4
if model_type == "chatglm" and "glm-4" in model_config._name_or_path:
# due to the same model architecture.
model_type = "chatglm2"
# for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
model_type = "falcon"
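For reference, here is a minimal sketch (not part of this PR) of the remapping that this hunk adds: GLM-4, like ChatGLM3, is routed onto the existing chatglm2 code path based on the Hugging Face config. The helper name resolve_model_type is hypothetical, and AutoConfig.from_pretrained needs the model files or network access to resolve the config.

from transformers import AutoConfig

def resolve_model_type(name_or_path: str) -> str:
    # Hypothetical helper mirroring the _get_model_type() logic above.
    config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)
    model_type = config.model_type
    # ChatGLM3 and GLM-4 reuse the ChatGLM2 graph, so both are remapped to "chatglm2".
    if model_type == "chatglm" and ("chatglm3" in config._name_or_path or "glm-4" in config._name_or_path):
        model_type = "chatglm2"
    return model_type

# Expected: resolve_model_type("THUDM/glm-4-9b") -> "chatglm2"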
147 changes: 146 additions & 1 deletion neural_speed/convert/convert_chatglm.py
@@ -524,6 +524,148 @@ def write_vocab_gguf(dir_model):
print("Done. Output file: " + fname_out)
print("")

def chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
print("GLM-4 converting: ")
list_vars = model.state_dict()
for name in list_vars.keys():
print(name, list_vars[name].shape, list_vars[name].dtype)

fout = open(fname_out, "wb")

print(hparams)

fout.write(struct.pack("i", 0x67676d66))
fout.write(struct.pack("i", 1))

fout.write(struct.pack("i", hparams["padded_vocab_size"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", hparams["num_layers"]))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", ftype))
fout.write(struct.pack("i", hparams["seq_length"]))
fout.write(struct.pack("f", 0))
fout.write(struct.pack("f", 0))
fout.write(struct.pack("i", 0))

fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)

fout.write(struct.pack("i", hparams["multi_query_group_num"]))
fout.write(struct.pack("i", hparams["ffn_hidden_size"]))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", hparams.get("rope_ratio", 1))) # rope_factor

fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # 1 if params["rope_scaling"]["type"] == "yarn" else 0; not enabled here

fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))


for i in range(hparams["vocab_size"]):
if i < tokenizer.vocab_size:
text = tokenizer.decode([i]).encode('utf-8')
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", 0.0 - i))
else:
text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", -10000))

for name in list_vars.keys():
data = list_vars[name].float().squeeze().numpy()
data = data.astype(np.float32)
if name == "transformer.rotary_pos_emb.inv_freq":
continue
# No gradients for these

n_dims = len(data.shape)
print(name, n_dims, data.shape)

# default type is fp32
ftype_cur = 0
if ftype == 1 and n_dims > 1:
print(" Converting to float16", data.shape, data[:3, :3].tolist())
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist())
data = data.astype(np.float32)

# header
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
print(str)
fout.write(str)

# data
data.tofile(fout)
if "mlp.dense_h_to_4h" in name:
name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
shape_0 = data.shape[0]
half_shape_0 = int(shape_0 / 2)
data_0 = data[0:half_shape_0, :]
data_1 = data[half_shape_0:shape_0, :]

print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
print("Converting: %-75s" % name_1, " shape: ", data_1.shape)

n_dims = len(data_0.shape)
assert (len(data_0.shape) == len(data_1.shape))
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name_0[-7:] == ".weight" and n_dims == 2:
print(" to float16".rjust(15))
data_0 = data_0.astype(np.float16)
data_1 = data_1.astype(np.float16)  # keep both halves fp16 so they match ftype_cur = 1
ftype_cur = 1
else:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0
else:
if data_0.dtype != np.float32:
print(" to float32".rjust(15))
data_0 = data_0.astype(np.float32)
data_1 = data_1.astype(np.float32)
ftype_cur = 0

str_0 = name_0.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
fout.write(str_0)
data_0.tofile(fout)

str_1 = name_1.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
fout.write(str_1)
data_1.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")

def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
print("ChatGLM-3 converting: ")
@@ -973,7 +1115,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
# ChatGLM3 shares the same architecture and model config with ChatGLM2
# but its tokenizer further supports system prompts,
# so we can check system token to discriminate ChatGLM3 from ChatGLM2.
if hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
# For GLM4-9B
if model.config.num_layers == 40:
chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
elif hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
if args.format == "GGUF":
chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
else:
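As a sanity check on the binary layout that chatglm4_convert writes, the sketch below (not part of this PR) reads the header back with struct.unpack. The field order and the 0x67676d66 magic mirror the struct.pack calls above; the function name and the labels chosen for the zero-filled slots are assumptions. Note that for each fused mlp.dense_h_to_4h tensor the converter also writes two extra row-wise halves (dense_h_to_4h_0 and dense_h_to_4h_1) after the fused tensor itself.

import struct

def read_glm4_header(path):
    # Hypothetical checker: the header is 32 little-endian 4-byte fields.
    with open(path, "rb") as f:
        vals = struct.unpack("<2i9i2f9i4f6i", f.read(32 * 4))
    assert vals[0] == 0x67676D66  # same magic the converter writes
    return {
        "version": vals[1],
        "padded_vocab_size": vals[2],
        "hidden_size": vals[3],
        "num_attention_heads": vals[5],
        "num_layers": vals[7],
        "ftype": vals[9],
        "seq_length": vals[10],
        "multi_query_group_num": vals[16],
        "ffn_hidden_size": vals[17],
        "layernorm_epsilon": vals[22],
        "freq_base": vals[23],
        "rope_ratio": vals[24],
        "bos_token_id": vals[28],
        "eos_token_id": vals[29],
        "pad_token_id": vals[30],
        "sep_token_id": vals[31],
    }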
6 changes: 6 additions & 0 deletions neural_speed/models/chatglm/chatglm2.h
@@ -31,6 +31,12 @@ static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_ratio) {
static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 40:
return {
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
default:
MODEL_ASSERT(false);
}
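The new case 40 entry (GLM-4-9B has 40 transformer layers) follows the same pattern as the existing ChatGLM entries: three scratch buffers sized in MiB and scaled by scratch_size_ratio. A small sketch of the arithmetic (not part of this PR), assuming the default ratio of 1.0:

MB = 1024 * 1024

def glm4_scratch_bytes(scratch_size_ratio: float = 1.0):
    # Mirrors the case 40 branch above: (4096, 2048, 4096) MiB scaled by the ratio.
    return (int(scratch_size_ratio * 4096) * MB,
            int(scratch_size_ratio * 2048) * MB,
            int(scratch_size_ratio * 4096) * MB)

# With the default ratio this reserves roughly 4 GiB, 2 GiB and 4 GiB of scratch space.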
6 changes: 6 additions & 0 deletions scripts/huggingface.py
@@ -401,6 +401,10 @@ def __init__(
if self.model_type == "chatglm" and "chatglm3" in self._config._name_or_path:
# due to the same model architecture.
self.model_type = "chatglm2"
# For GLM4
if self.model_type == "chatglm" and "glm-4" in self._config._name_or_path:
# due to the same model architecture.
self.model_type = "chatglm2"

@property
def config(self):
@@ -594,6 +598,8 @@ def _create_model(
if init_from_bin != "default_none":
if self.config.model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
model_type = "chatglm2"
elif self.config.model_type == "chatglm" and "glm-4" in self.config._name_or_path:
model_type = "chatglm2"
else:
model_type = self.config.model_type

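With the remap in place in both the package and this evaluation script, GLM-4-9B is expected to run through the same path as ChatGLM2. The end-to-end sketch below (not part of this PR) is modeled on the repository's README-style Python API; the weight_dtype/compute_dtype values and the generate arguments are assumptions, not settings verified for this model.

from transformers import AutoTokenizer
from neural_speed import Model

model_name = "THUDM/glm-4-9b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer("What is Neural Speed?", return_tensors="pt").input_ids

model = Model()
model.init(model_name, weight_dtype="int4", compute_dtype="int8")  # assumed quantization kwargs
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))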
8 changes: 7 additions & 1 deletion tests/model-test/cpp_graph_inference.sh
@@ -170,6 +170,7 @@ model_name_map["baichuan13b-gptq"]="Baichuan2-13B-Chat-GPTQ"
model_name_map["mistral-gptq"]="TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
model_name_map["phi3"]="microsoft/Phi-3-mini-128k-instruct"
model_name_map["llama3"]="meta-llama/Meta-Llama-3-8B"
model_name_map["glm4"]="THUDM/glm-4-9b"


function main() {
@@ -256,6 +257,11 @@ function main() {
extension=" --model_name chatglm3 --tokenizer $model_path"
requirements_file="$working_dir/neural_speed/models/requirements/chatglm-6b.sh"
input_list=(32 1024)
elif [[ "${model}" == "glm4" ]]; then
quant_script="./build/bin/quant_chatglm2"
convert_script="${convert_script}/convert_chatglm.py"
infer_cmd="./build/bin/run_chatglm2"
input_list=(32 1024)
elif [[ "${model}" == "chatglm-6b" ]]; then
quant_script="./build/bin/quant_chatglm"
convert_script="${convert_script}/convert_chatglm.py"
@@ -474,7 +480,7 @@ function main() {
$infer_cmd -f "/tf_dataset2/models/nlp_toolkit/whisper-tiny/jfk.wav" -m ${model}-${precision}.bin
else
real_ctx=$ctx # TODO(Zhenzhong): use same ctx for chatglm & baichuan
[[ "${model}" == "chatglm2" || "${model}" == "chatglm-6b" ||
[[ "${model}" == "chatglm2" || "${model}" == "chatglm-6b" || "${model}" == "glm4" ||
"${model}" == "baichuan-13b" || "${model}" == "baichuan2-13b" ]] && real_ctx=2048
if [[ "${model}" == *"gptq" ]]; then
NEURAL_SPEED_VERBOSE=1 OMP_NUM_THREADS=$cores_per_instance numactl -m 0 -C 0-$(($cores_per_instance - 1)) $infer_cmd 2>&1 | tee ${WORKSPACE}/${logs_file} || true &