Fix issues when directly reading Llama3 and Qwen2 HF models
cgli committed Jun 17, 2024
1 parent 7bc30ef commit 2bdcf14
Showing 3 changed files with 30 additions and 10 deletions.
35 changes: 25 additions & 10 deletions src/model.cpp
@@ -402,13 +402,37 @@ namespace fastllm {
for (auto &it : config.object_items()) {
model->weight.AddDict(it.first, it.second.dump().c_str());
}
// Set eos_token_id
if (config["eos_token_id"].is_array()) {
for (auto &it : config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
} else {
model->eos_token_id = config["eos_token_id"].int_value();
}

std::string generationConfigFile = path + "generation_config.json";
if (FileExists(generationConfigFile)) {
auto generation_config = json11::Json::parse(ReadAllFile(generationConfigFile), error);
for (auto &it : generation_config.object_items()) {
if ("eos_token_id" == it.first && it.second.type() == json11::Json::ARRAY)
continue;
model->weight.AddDict(it.first, it.second.dump().c_str());
}
// Update eos_token_id
if (generation_config["eos_token_id"].is_array()) {
for (auto &it : generation_config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
}
}

// 3. Load the tokenizer
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast") {
if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "Qwen2Tokenizer") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
@@ -445,15 +469,6 @@ namespace fastllm {
((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;

// Set eos_token_id
if (config["eos_token_id"].is_array()) {
for (auto &it : config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
} else {
model->eos_token_id = config["eos_token_id"].int_value();
}

// ChatGLM builds prompts by concatenating tokens, so the separator token IDs must be specified explicitly
model->pre_prompt = "";
model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
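Note on the model.cpp change above: Llama3-family checkpoints declare several stop tokens, so eos_token_id appears as an array in config.json and generation_config.json (Llama-3-Instruct, for instance, lists both an end-of-text and an end-of-turn id), while older models store a single integer; generation_config.json was previously not read at all. A minimal standalone sketch of handling both shapes with the same json11 calls the loader uses (values illustrative, not fastllm code):

    #include <iostream>
    #include <set>
    #include <string>
    #include "json11.hpp"

    int main() {
        std::set<int> eosTokenIds;
        std::string error;
        // The same HF config field can arrive in either shape:
        const char *configs[] = {
            "{\"eos_token_id\": 2}",                // classic scalar form
            "{\"eos_token_id\": [128001, 128009]}"  // Llama3-style array form
        };
        for (const char *raw : configs) {
            auto config = json11::Json::parse(raw, error);
            if (config["eos_token_id"].is_array()) {
                for (auto &it : config["eos_token_id"].array_items())
                    eosTokenIds.insert(it.int_value());
            } else {
                eosTokenIds.insert(config["eos_token_id"].int_value());
            }
        }
        for (int id : eosTokenIds)
            std::cout << id << "\n";  // 2, 128001, 128009
    }

Collecting every id into a set (rather than keeping one scalar eos_token_id) lets generation stop on whichever stop token the model emits first.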
1 change: 1 addition & 0 deletions src/models/basellm.cpp
@@ -984,6 +984,7 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
{"content", message.second}
});
}
ret["add_generation_prompt"] = fastllm::JinjaVar{1};
return ret;
}

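Note on the basellm.cpp change above: HF chat templates guard the closing assistant header with a Jinja conditional; a typical ChatML-style fragment (as shipped with Qwen2, quoted here for illustration) is

    {% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}

Previously the variable was never populated, so a rendered Llama3/Qwen2 prompt stopped after the last user turn. With JinjaVar{1} set, the rendered prompt ends with the assistant header and generation starts in the assistant role.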
4 changes: 4 additions & 0 deletions src/models/llama.cpp
@@ -926,6 +926,10 @@ namespace fastllm {
pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32),
Data(DataType::FLOAT32)));
}
if (this->weight.weight.find("lm_head.weight") == this->weight.weight.end()) {
this->weight["lm_head.weight"] = Data();
this->weight["lm_head.weight"].CopyFrom(this->weight["model.embed_tokens.weight"]);
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
printf("finish.\n");
}
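Note on the llama.cpp change above: checkpoints with tied word embeddings (e.g. the smaller Qwen2 models, which set "tie_word_embeddings": true in config.json) ship no separate lm_head.weight, because the output projection reuses the input embedding matrix; without the fallback, the warm-up Forward would look up a missing weight. A minimal sketch of the same fallback idea over a plain weight map (simplified stand-in types, not fastllm code):

    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    using Matrix = std::vector<float>;  // stand-in for fastllm::Data

    // If the checkpoint carries no lm_head.weight, reuse the tied input
    // embedding matrix as the output projection (a copy, like CopyFrom).
    void EnsureLmHead(std::map<std::string, Matrix> &weights) {
        if (weights.find("lm_head.weight") == weights.end()) {
            auto it = weights.find("model.embed_tokens.weight");
            if (it == weights.end())
                throw std::runtime_error("no embedding weight to tie to");
            weights["lm_head.weight"] = it->second;
        }
    }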
