From 2bdcf14028f4352ca5f0ee35a47035037e13d290 Mon Sep 17 00:00:00 2001
From: cgli
Date: Sun, 16 Jun 2024 20:30:30 +0800
Subject: [PATCH] Fix issues when directly loading Llama3 / Qwen2 HF models

Loading Llama3 / Qwen2 checkpoints straight from a HuggingFace model
directory did not work correctly. This patch:

- reads eos_token_id (a single id or an array of ids) from config.json,
  and merges in the ids from generation_config.json when that file exists;
- accepts tokenizer_class "Qwen2Tokenizer" in addition to
  "PreTrainedTokenizerFast";
- sets add_generation_prompt when building the chat-template variables;
- copies model.embed_tokens.weight into lm_head.weight when the latter is
  missing (tied word embeddings).
---
 src/model.cpp          | 35 +++++++++++++++++++++++++----------
 src/models/basellm.cpp |  1 +
 src/models/llama.cpp   |  4 ++++
 3 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/src/model.cpp b/src/model.cpp
index d3939e2c..ff1f7853 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -402,13 +402,37 @@ namespace fastllm {
         for (auto &it : config.object_items()) {
             model->weight.AddDict(it.first, it.second.dump().c_str());
         }
+        // Set eos_token_id (config.json stores either a single id or an array of ids)
+        if (config["eos_token_id"].is_array()) {
+            for (auto &it : config["eos_token_id"].array_items()) {
+                model->eos_token_ids.insert(it.int_value());
+            }
+        } else {
+            model->eos_token_id = config["eos_token_id"].int_value();
+        }
+
+        std::string generationConfigFile = path + "generation_config.json";
+        if (FileExists(generationConfigFile)) {
+            auto generation_config = json11::Json::parse(ReadAllFile(generationConfigFile), error);
+            for (auto &it : generation_config.object_items()) {
+                if ("eos_token_id" == it.first && it.second.type() == json11::Json::ARRAY)
+                    continue;
+                model->weight.AddDict(it.first, it.second.dump().c_str());
+            }
+            // Merge eos_token_id values from generation_config.json
+            if (generation_config["eos_token_id"].is_array()) {
+                for (auto &it : generation_config["eos_token_id"].array_items()) {
+                    model->eos_token_ids.insert(it.int_value());
+                }
+            }
+        }
 
         // 3. Load the tokenizer
         std::string tokenizerConfigFile = path + "tokenizer_config.json";
         auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
         model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
         std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
-        if (tokenizerClass == "PreTrainedTokenizerFast") {
+        if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "Qwen2Tokenizer") {
             // PreTrainedTokenizerFast
             std::string tokenizerFile = path + "tokenizer.json";
             auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
@@ -445,15 +469,6 @@ namespace fastllm {
             ((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
             ((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;
 
-            // Set eos_token_id
-            if (config["eos_token_id"].is_array()) {
-                for (auto &it : config["eos_token_id"].array_items()) {
-                    model->eos_token_ids.insert(it.int_value());
-                }
-            } else {
-                model->eos_token_id = config["eos_token_id"].int_value();
-            }
-
             // ChatGLM builds its prompt by concatenating tokens, so the separator token ids must be specified explicitly
             model->pre_prompt = "";
             model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index dc320f44..2a91a49d 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -984,6 +984,7 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
                 {"content", message.second}
             });
         }
+        ret["add_generation_prompt"] = fastllm::JinjaVar{1};
         return ret;
     }
 
diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index 2c72a06e..1e94c7ca 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -926,6 +926,10 @@ namespace fastllm {
             pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32),
                                                    Data(DataType::FLOAT32)));
         }
+        if (this->weight.weight.find("lm_head.weight") == this->weight.weight.end()) {
+            this->weight["lm_head.weight"] = Data();
+            this->weight["lm_head.weight"].CopyFrom(this->weight["model.embed_tokens.weight"]);
+        }
         Forward(inputIds, attentionMask, positionIds, pastKeyValues);
         printf("finish.\n");
     }