From 496752f84a822eacf652c8d79356bbe9967d6474 Mon Sep 17 00:00:00 2001
From: cgli
Date: Sun, 11 Feb 2024 18:08:48 +0800
Subject: [PATCH] Basically align the BPE Tokenizer with huggingface tokenizers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/llama_cookbook.md             | 37 ++++++++++++++++++++++++++----
 src/model.cpp                      |  2 +-
 tools/fastllm_pytools/hf_model.py  | 30 ++++++++++++++++--------
 tools/fastllm_pytools/torch2flm.py | 14 ++++++++++-
 tools/scripts/alpaca2flm.py        |  8 ++++---
 5 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/docs/llama_cookbook.md b/docs/llama_cookbook.md
index 4a5097e3..8c4c225e 100644
--- a/docs/llama_cookbook.md
+++ b/docs/llama_cookbook.md
@@ -8,7 +8,13 @@
 LLaMA-family models share essentially the same structure, but differ in their weights and prompt construction.
 
 The following configurations are compiled from each model's source code; inference results are not guaranteed to match the original models exactly.
 
-## Modify the script and convert
+## Modification methods
+
+Currently, both the conversion script and the two-line acceleration approach can be used for llama-family models. Whichever approach you choose, reserve enough memory (swap space can be used).
+
+In float16 mode, the conversion needs roughly (4 × the number of parameters) + 1 GB of free memory.
+
+### Conversion script
 
 This document uses support for inference of base models with the various Llama structures as the running example.
 
@@ -40,17 +46,36 @@ LLaMA-family models share essentially the same structure, but differ in their weights and prompt construction.
 If you need to add a Token ID rather than a string (as with the baichuan-chat model), it can be added in the format "".
 
+* Run the script
+
+```shell
+python3 tools/alpaca2flm.py [output file] [precision] [original model name or path]
+```
+
 ### Two-line acceleration
 
 ```python
+    conf = model.config.__dict__
+    conf["model_type"] = "llama"
     llm.from_hf(model, tokenizer, pre_prompt = "", user_role = "", bot_role = "", history_sep = "", dtype = dtype)
 ```
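+
+A slightly fuller sketch of the same flow (the model path is a placeholder; the empty prompt strings match the base-model setup above):
+
+```python
+import torch
+from fastllm_pytools import llm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "path/to/llama-base"  # placeholder
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
+
+conf = model.config.__dict__
+conf["model_type"] = "llama"
+# from_hf converts the weights in memory and returns a fastllm model object
+fast_model = llm.from_hf(model, tokenizer, pre_prompt = "", user_role = "", bot_role = "", history_sep = "", dtype = "float16")
+```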
" \ + "Write a response that appropriately completes the request.\n\n", user_role="### Instruction:\n", bot_role="\n\n### Response:", history_sep="\n", dtype=dtype) ``` diff --git a/src/model.cpp b/src/model.cpp index 85c0b9cd..d51b1b82 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -88,7 +88,7 @@ namespace fastllm { model = (basellm*)(new ChatGLMModel()); } else if (modelType == "moss") { model = (basellm*)(new MOSSModel()); - model->weight.tokenizer.type = Tokenizer::TokenizerType::NORMAL; + model->weight.tokenizer.type = Tokenizer::TokenizerType::BPE; model->eos_token_id = 106068; } else if (modelType == "baichuan") { model = (basellm*)(new LlamaModel()); diff --git a/tools/fastllm_pytools/hf_model.py b/tools/fastllm_pytools/hf_model.py index 735bdccd..aa2afbe1 100644 --- a/tools/fastllm_pytools/hf_model.py +++ b/tools/fastllm_pytools/hf_model.py @@ -1,5 +1,6 @@ from fastllm_pytools import llm; import ctypes; +import builtins, os, json import numpy as np import torch from transformers import PreTrainedTokenizerFast @@ -118,20 +119,31 @@ def create(model, else: tokenizer = tokenizer.tokenizer if (hasattr(tokenizer, "sp_model")): - piece_size = tokenizer.sp_model.piece_size(); + piece_size = tokenizer.sp_model.piece_size() for i in range(piece_size): llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, tokenizer.sp_model.id_to_piece(i).encode(), i, ctypes.c_float(tokenizer.sp_model.get_score(i))); else: - vocab = tokenizer.get_vocab(); + merges = {} + if (modelInfo["model_type"] == "moss"): + merges = {("".join(bpe_tokens), token_index) for bpe_tokens, token_index in sorted(tokenizer.bpe_ranks.items(), key=lambda kv: kv[1])} + elif isinstance(tokenizer, PreTrainedTokenizerFast): + tokenizer_file = tokenizer.name_or_path + tokenizer.vocab_files_names['tokenizer_file'] + if os.path.exists(tokenizer_file): + with open(tokenizer_file, "r", encoding='utf-8') as f: + bpe_merges = json.load(f)["model"]["merges"] + bpe_merges = [pair.replace(" ", "") for pair in bpe_merges] + merges = builtins.dict(zip(bpe_merges, range(0, -len(bpe_merges), -1))) + vocab = tokenizer.get_vocab() for v in vocab.keys(): + score = merges[v] if v in merges else 1.0 if (modelInfo["model_type"] == "moss"): - vv = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]; - llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, vv, vocab[v], ctypes.c_float(1.0)); + s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v] + llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, s, vocab[v], ctypes.c_float(score)); elif (modelInfo["model_type"] == "qwen"): llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v, vocab[v], ctypes.c_float(1.0)); else: - llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v.encode(), vocab[v], ctypes.c_float(1.0)); + llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v.encode(), vocab[v], ctypes.c_float(score)); weight_type_dict = {} module_dict = {} @@ -157,13 +169,13 @@ def create(model, to_data_type = 0 if (cur_weight_type == 1): - to_data_type = fastllm_data_type_dict[dtype]; + to_data_type = fastllm_data_type_dict[dtype] if (to_data_type == 7): - ori_data_type = 7; - ori_np_data_type = np.float16; + ori_data_type = 7 + ori_np_data_type = np.float16 elif (cur_weight_type == 2): # TODO bfloat - to_data_type = 0; + to_data_type = 0 weight_name = key if hasattr(model, "peft_config"): diff --git a/tools/fastllm_pytools/torch2flm.py b/tools/fastllm_pytools/torch2flm.py index 7ae6e333..e8113098 
+
 ## Base Model
 
-See "[Modification plan](#modification-plan)" above.
+See "[Modification methods](#modification-methods)" above.
 
 Some models require specifying bos_token_id; assuming bos_token_id is 1, it can be configured as follows:
 
@@ -96,10 +121,12 @@ python3 tools/internlm2flm.py internlm-7b-int4.flm float16 internlm/internlm-cha
 ```python
     conf = model.config.__dict__
     conf["model_type"] = "llama"
+    conf["tokenizer_add_dummy_prefix"] = False
     torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", user_role = "Human: ",
                      bot_role = "\n\nAssistant: ", history_sep = "", dtype = dtype)
 ```
+The XVERSE-13B-Chat V1 release applies NFKC normalization to its input, which fastllm does not yet support, so the original tokenizer has to be used.
 
 ### Other llama1-series models
 
@@ -174,7 +201,7 @@ python3 tools/internlm2flm.py internlm-7b-int4.flm float16 internlm/internlm-cha
 ```python
 torch2flm.tofile(exportPath, model, tokenizer,
                  pre_prompt="The following is a conversation between a human and an AI assistant namely YuLan, developed by GSAI, Renmin University of China. " \
-                            "The AI assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+                            "The AI assistant gives helpful, detailed, and polite answers to the user's questions.\n",
                  user_role="[|Human|]:", bot_role="\n[|AI|]:", history_sep="\n", dtype=dtype)
 ```
@@ -185,7 +212,7 @@ python3 tools/internlm2flm.py internlm-7b-int4.flm float16 internlm/internlm-cha
 ```python
 torch2flm.tofile(exportPath, model, tokenizer,
-                 pre_prompt="Below is an instruction that describes a task. "
-                            "Write a response that appropriately completes the request.\n\n"
+                 pre_prompt="Below is an instruction that describes a task. " \
+                            "Write a response that appropriately completes the request.\n\n",
                  user_role="### Instruction:\n", bot_role="\n\n### Response:", history_sep="\n", dtype=dtype)
 ```
diff --git a/src/model.cpp b/src/model.cpp
index 85c0b9cd..d51b1b82 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -88,7 +88,7 @@ namespace fastllm {
             model = (basellm*)(new ChatGLMModel());
         } else if (modelType == "moss") {
             model = (basellm*)(new MOSSModel());
-            model->weight.tokenizer.type = Tokenizer::TokenizerType::NORMAL;
+            model->weight.tokenizer.type = Tokenizer::TokenizerType::BPE;
             model->eos_token_id = 106068;
         } else if (modelType == "baichuan") {
             model = (basellm*)(new LlamaModel());
diff --git a/tools/fastllm_pytools/hf_model.py b/tools/fastllm_pytools/hf_model.py
index 735bdccd..aa2afbe1 100644
--- a/tools/fastllm_pytools/hf_model.py
+++ b/tools/fastllm_pytools/hf_model.py
@@ -1,5 +1,6 @@
 from fastllm_pytools import llm;
 import ctypes;
+import builtins, os, json
 import numpy as np
 import torch
 from transformers import PreTrainedTokenizerFast
@@ -118,20 +119,31 @@ def create(model,
         else:
             tokenizer = tokenizer.tokenizer
         if (hasattr(tokenizer, "sp_model")):
-            piece_size = tokenizer.sp_model.piece_size();
+            piece_size = tokenizer.sp_model.piece_size()
             for i in range(piece_size):
                 llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, tokenizer.sp_model.id_to_piece(i).encode(),
                                                              i, ctypes.c_float(tokenizer.sp_model.get_score(i)));
         else:
-            vocab = tokenizer.get_vocab();
+            merges = {}
+            if (modelInfo["model_type"] == "moss"):
+                merges = {"".join(bpe_tokens): token_index for bpe_tokens, token_index in sorted(tokenizer.bpe_ranks.items(), key=lambda kv: kv[1])}
+            elif isinstance(tokenizer, PreTrainedTokenizerFast):
+                tokenizer_file = tokenizer.name_or_path + tokenizer.vocab_files_names['tokenizer_file']
+                if os.path.exists(tokenizer_file):
+                    with open(tokenizer_file, "r", encoding='utf-8') as f:
+                        bpe_merges = json.load(f)["model"]["merges"]
+                        bpe_merges = [pair.replace(" ", "") for pair in bpe_merges]
+                        merges = builtins.dict(zip(bpe_merges, range(0, -len(bpe_merges), -1)))
+            vocab = tokenizer.get_vocab()
             for v in vocab.keys():
+                score = merges[v] if v in merges else 1.0
                 if (modelInfo["model_type"] == "moss"):
-                    vv = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v];
-                    llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, vv, vocab[v], ctypes.c_float(1.0));
+                    s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]
+                    llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, s, vocab[v], ctypes.c_float(score));
                 elif (modelInfo["model_type"] == "qwen"):
                     llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v, vocab[v], ctypes.c_float(1.0));
                 else:
-                    llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v.encode(), vocab[v], ctypes.c_float(1.0));
+                    llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v.encode(), vocab[v], ctypes.c_float(score));
 
     weight_type_dict = {}
     module_dict = {}
@@ -157,13 +169,13 @@ def create(model,
 
         to_data_type = 0
         if (cur_weight_type == 1):
-            to_data_type = fastllm_data_type_dict[dtype];
+            to_data_type = fastllm_data_type_dict[dtype]
             if (to_data_type == 7):
-                ori_data_type = 7;
-                ori_np_data_type = np.float16;
+                ori_data_type = 7
+                ori_np_data_type = np.float16
         elif (cur_weight_type == 2):
             # TODO bfloat
-            to_data_type = 0;
+            to_data_type = 0
 
         weight_name = key
         if hasattr(model, "peft_config"):
diff --git a/tools/fastllm_pytools/torch2flm.py b/tools/fastllm_pytools/torch2flm.py
index 7ae6e333..e8113098 100644
--- a/tools/fastllm_pytools/torch2flm.py
+++ b/tools/fastllm_pytools/torch2flm.py
@@ -1,4 +1,5 @@
 import struct
+import builtins, os, json
 import numpy as np
 import torch
 from transformers import PreTrainedTokenizerFast
@@ -174,9 +175,20 @@ def tofile(exportPath,
                 fo.write(struct.pack('i', i))
                 fo.write(struct.pack('f', float(tokenizer.sp_model.get_score(i))))
         else:
+            merges = {}
+            if (modelInfo["model_type"] == "moss"):
+                merges = {"".join(bpe_tokens): token_index for bpe_tokens, token_index in sorted(tokenizer.bpe_ranks.items(), key=lambda kv: kv[1])}
+            elif isinstance(tokenizer, PreTrainedTokenizerFast):
+                tokenizer_file = tokenizer.name_or_path + tokenizer.vocab_files_names['tokenizer_file']
+                if os.path.exists(tokenizer_file):
+                    with open(tokenizer_file, "r", encoding='utf-8') as f:
+                        bpe_merges = json.load(f)["model"]["merges"]
+                        bpe_merges = [pair.replace(" ", "") for pair in bpe_merges]
+                        merges = builtins.dict(zip(bpe_merges, range(0, -len(bpe_merges), -1)))
             vocab = tokenizer.get_vocab()
             fo.write(struct.pack('i', len(vocab)))
             for v in vocab.keys():
+                score = merges[v] if v in merges else 1.0
                 if (modelInfo["model_type"] == "moss"):
                     s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]
                 elif (modelInfo["model_type"] == "qwen"):
@@ -187,7 +199,7 @@ def tofile(exportPath,
                 for c in s:
                     fo.write(struct.pack('i', c))
                 fo.write(struct.pack('i', vocab[v]))
-                fo.write(struct.pack('f', 1.0))
+                fo.write(struct.pack('f', score))
     else:
         fo.write(struct.pack('i', 0))
 
diff --git a/tools/scripts/alpaca2flm.py b/tools/scripts/alpaca2flm.py
index c8b473d2..e8103461 100644
--- a/tools/scripts/alpaca2flm.py
+++ b/tools/scripts/alpaca2flm.py
@@ -1,11 +1,13 @@
 import sys
-from transformers import LlamaTokenizer, LlamaForCausalLM
+import torch
+from transformers import AutoTokenizer, LlamaForCausalLM
 from fastllm_pytools import torch2flm
 
 if __name__ == "__main__":
     model_name = sys.argv[3] if len(sys.argv) >= 4 else 'minlik/chinese-alpaca-33b-merged'
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
-    model = LlamaForCausalLM.from_pretrained(model_name).float()
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    # `torch_dtype=torch.float16` is used here to reduce memory usage; if it does not cause an OOM error, you can load the model in float32 instead.
+    model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
     conf = model.config.__dict__
     conf["model_type"] = "llama"
     dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
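
Note: a standalone sketch (the file path is a placeholder; it assumes a `tokenizer.json` in the Hugging Face `tokenizers` layout, as the converters above do) of the scoring convention used in hf_model.py and torch2flm.py: merge rank r becomes score -r, so a higher score corresponds to an earlier (higher-priority) merge.

```python
import json

# Placeholder path; point it at the tokenizer.json shipped with the model.
with open("tokenizer.json", "r", encoding="utf-8") as f:
    bpe_merges = json.load(f)["model"]["merges"]  # ordered by merge rank, e.g. "i n"

merged_tokens = [pair.replace(" ", "") for pair in bpe_merges]  # "i n" -> "in"
scores = dict(zip(merged_tokens, range(0, -len(merged_tokens), -1)))

# Vocabulary entries that never appear as a merge result (special tokens,
# single-byte pieces, ...) keep the default score of 1.0 in the converters.
print(list(scores.items())[:3])
```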