Skip to content

Commit

Permalink
支持保存直接读取safetensors得到的llama类模型为flm格式,并加载推理
Browse files Browse the repository at this point in the history
  • Loading branch information
cgli authored and TylunasLi committed Jul 6, 2024
1 parent afa47d7 commit 2c07254
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion include/devices/cpu/alivethreadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ namespace fastllm {
auto duration = std::chrono::duration_cast<std::chrono::microseconds> (std::chrono::system_clock::now() - lastRunTime);
double gap = double(duration.count()) * std::chrono::microseconds::period::num / std::chrono::microseconds::period::den;
if (gap > 3) {
std::this_thread::sleep_for(std::chrono::seconds(0));
std::this_thread::sleep_for(std::chrono::microseconds(2));
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1976,6 +1976,8 @@ namespace fastllm {
}
tokenizer.SetSpecialTokens(specialTokens);
}
if (this->dicts.find("chat_template") != this->dicts.end())
tokenizer.chatTemplate = this->dicts["chat_template"];

int len = buffer.ReadInt();
for (int i = 0; i < len; i++) {
Expand Down
7 changes: 6 additions & 1 deletion src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ namespace fastllm {
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
if (!model->weight.tokenizer.chatTemplate.empty() && model->weight.dicts.find("chat_template") == model->weight.dicts.end())
model->weight.AddDict("chat_template", model->weight.tokenizer.chatTemplate);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast"
|| tokenizerClass == "Qwen2Tokenizer"
Expand All @@ -439,10 +441,13 @@ namespace fastllm {
spTokens[it["content"].string_value()] = it["id"].int_value();
}
model->weight.tokenizer.SetSpecialTokens(spTokens);
if (!spTokens.empty())
model->weight.AddDict("tokenizer_has_special_tokens", "1");

if (!tokenizer["decoder"].is_null() && !tokenizer["decoder"]["type"].is_null() &&
tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
model->weight.tokenizer.byteAsChar = true;
model->weight.AddDict("tokenizer_byte_as_char", "True");
}
} else if (tokenizerClass == "ChatGLM4Tokenizer") {
// GLM4御用的分词
Expand Down Expand Up @@ -515,7 +520,7 @@ namespace fastllm {
auto config = json11::Json::parse(ReadAllFile(configFile), error);
basellm *model = CreateModelWithType(config["model_type"].string_value());
for (auto &it : config.object_items()) {
model->weight.AddDict(it.first, it.second.dump().c_str());
model->weight.AddDict(it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
}
// 设置eos_token_id
if (config["eos_token_id"].is_array()) {
Expand Down
6 changes: 5 additions & 1 deletion src/models/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ namespace fastllm {
std::string mergeQkvWeightName = "model.layers." + std::to_string(i) + ".self_attn.mergeqkv.weight";
std::string mergeQkvBiasName = "model.layers." + std::to_string(i) + ".self_attn.mergeqkv.bias";

if (weight.weight.find(qkvWeightName) != weight.weight.end()) {
if (weight.weight.find(qkvWeightName) != weight.weight.end() || weight.weight.find(mergeQkvWeightName) != weight.weight.end()) {
mergeQKV = true;
break;
} else {
Expand Down Expand Up @@ -214,6 +214,10 @@ namespace fastllm {
std::string w3WeightName = "model.layers." + std::to_string(i) + ".mlp.up_proj.weight";
std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.gateup_proj.weight";

if (weight.weight.find(swigluWeightName) != weight.weight.end()) {
mergeQKV = true;
break;
}
Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];
if ((w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0) ||
(w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0)) {
Expand Down

0 comments on commit 2c07254

Please sign in to comment.