From 9c018ed765f84a0def2022575b19a1d265c07c5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Fri, 19 Jul 2024 18:20:01 +0800
Subject: [PATCH] =?UTF-8?q?=E8=AF=BB=E5=8F=96hf=E6=A8=A1=E5=9E=8B=E6=97=B6?=
 =?UTF-8?q?=E9=80=92=E5=BD=92=E8=AF=BB=E5=8F=96config?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/model.cpp | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/model.cpp b/src/model.cpp
index db5c5a07..4edc6e35 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -514,6 +514,17 @@ namespace fastllm {
         return std::unique_ptr<fastllm::basellm> (model);
     }
 
+    // 将config中的内容递归地加入model->dict中
+    void AddDictRecursion(basellm *model, const std::string &pre, const json11::Json &config) {
+        for (auto &it : config.object_items()) {
+            if (it.second.is_object()) {
+                AddDictRecursion(model, pre + it.first + ".", it.second);
+            } else {
+                model->weight.AddDict(pre + it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
+            }
+        }
+    }
+
     // 从hf文件夹读取,仅支持safetensor格式的模型
     std::unique_ptr<basellm> CreateLLMModelFromHF(const std::string &modelPath, DataType linearDataType,
                                                   int groupCnt, bool skipTokenizer, const std::string &modelConfig) {
@@ -545,9 +556,7 @@ namespace fastllm {
         if (isJsonModel) {
             ((GraphLLMModel*)model)->graphLLMModelConfig->Init(modelConfig);
         }
-        for (auto &it : config.object_items()) {
-            model->weight.AddDict(it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
-        }
+        AddDictRecursion(model, "", config);
         // 设置eos_token_id
         if (config["eos_token_id"].is_array()) {
            for (auto &it : config["eos_token_id"].array_items()) {