Skip to content

Commit

Permalink
增加baichuan2转换脚本
Browse files Browse the repository at this point in the history
  • Loading branch information
siemon committed Oct 11, 2023
1 parent b942b49 commit bfa14e1
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 1 deletion.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@ python3 tools/baichuan2flm.py baichuan-13b-int8.flm int8 #导出int8模型
python3 tools/baichuan2flm.py baichuan-13b-int4.flm int4 #导出int4模型
```

### baichuan2模型导出 (默认脚本导出baichuan2-7b-chat模型)

``` sh
# 需要先安装baichuan2环境
# 如果使用自己finetune的模型需要修改baichuan2_2flm.py文件中创建tokenizer, model的代码
# 根据所需的精度,导出相应的模型
cd build
python3 tools/baichuan2_2flm.py baichuan2-7b-fp16.flm float16 #导出float16模型
python3 tools/baichuan2_2flm.py baichuan2-7b-int8.flm int8 #导出int8模型
python3 tools/baichuan2_2flm.py baichuan2-7b-int4.flm int4 #导出int4模型
```

### MOSS模型导出

``` sh
Expand Down
6 changes: 6 additions & 0 deletions tools/fastllm_pytools/torch2flm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def tofile(exportPath,
modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(model.generation_config.user_token_id) + ">") if hasattr(model.generation_config, "user_token_id") else "";
modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(model.generation_config.assistant_token_id) + ">") if hasattr(model.generation_config, "assistant_token_id") else "";
modelInfo["history_sep"] = ""
if (modelInfo["model_type"] == "baichuan" and modelInfo["vocab_size"] == 125696):
# Baichuan 2代 7B
modelInfo["pre_prompt"] = ""
modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(model.generation_config.user_token_id) + ">") if hasattr(model.generation_config, "user_token_id") else "";
modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(model.generation_config.assistant_token_id) + ">") if hasattr(model.generation_config, "assistant_token_id") else "";
modelInfo["history_sep"] = ""
if modelInfo["model_type"] == "qwen":
if modelInfo["chat_format"] == "chatml":
modelInfo["im_end_id"] = tokenizer.im_end_id
Expand Down
24 changes: 24 additions & 0 deletions tools/scripts/baichuan2_2flm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from fastllm_pytools import torch2flm

if __name__ == "__main__":
modelpath = "baichuan-inc/Baichuan2-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(modelpath, device_map="auto", torch_dtype=torch.float32, trust_remote_code=True)

# normalize lm_head
state_dict = model.state_dict()
state_dict['lm_head.weight'] = torch.nn.functional.normalize(state_dict['lm_head.weight'])
model.load_state_dict(state_dict)

try:
model.generation_config = GenerationConfig.from_pretrained(modelpath)
except:
pass

dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan2-7b-" + dtype + ".flm"
torch2flm.tofile(exportPath, model.to('cpu'), tokenizer, dtype=dtype)
2 changes: 1 addition & 1 deletion tools/scripts/baichuan2flm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
except:
pass
dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-' + dtype + '.flm"
exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-" + dtype + ".flm"
torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)

0 comments on commit bfa14e1

Please sign in to comment.