From cc9552f8d7b7cdecabe3d78e2abea29149bbc295 Mon Sep 17 00:00:00 2001
From: cgli
Date: Sat, 2 Mar 2024 18:49:02 +0800
Subject: [PATCH] Fix MiniCPM GPU initialization / low-memory mode errors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/models/minicpm.h |  9 +++++
 src/model.cpp            |  2 +-
 src/models/minicpm.cpp   | 71 ++++++++++++++++------------------------
 3 files changed, 38 insertions(+), 44 deletions(-)

diff --git a/include/models/minicpm.h b/include/models/minicpm.h
index 241e6c72..b1849b6a 100644
--- a/include/models/minicpm.h
+++ b/include/models/minicpm.h
@@ -15,6 +15,8 @@ namespace fastllm {
     public:
         MiniCpmModel(); // constructor
 
+        virtual void InitParams(); // initialize parameter info
+
         // inference
         virtual int Forward(
                 const Data &inputIds,
@@ -65,6 +67,13 @@ namespace fastllm {
         virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // build the prompt from the history and the current input
 
         virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply
+
+    private:
+        float embed_scale = 1.f;
+
+        float attention_scale = 1.f / std::sqrt(block_cnt);
+
+        float rms_scale = 1.f / 4096.f;
     };
 }
 
diff --git a/src/model.cpp b/src/model.cpp
index d454291a..82566dfe 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -108,7 +108,7 @@ namespace fastllm {
             model = new LlamaModel();
             model->model_type = "qwen";
         } else if (modelType == "minicpm") {
-            model = (basellm*)(new MiniCpmModel());
+            model = new MiniCpmModel();
         } else if (modelType == "qwen") {
             model = (basellm *) (new QWenModel());
             model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
diff --git a/src/models/minicpm.cpp b/src/models/minicpm.cpp
index 7085b7a6..1ee4aa37 100644
--- a/src/models/minicpm.cpp
+++ b/src/models/minicpm.cpp
@@ -47,12 +47,6 @@ namespace fastllm {
 
     MiniCpmModel::MiniCpmModel() {
         this->model_type = "minicpm";
-        // use alpaca's prompt and instruction by default
-        /*
-        this->pre_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n";
-        this->user_role = "### Instruction:\n";
-        this->bot_role = "\n\n### Response:";
-        */
         this->history_sep = "";
         this->pre_prompt = "";
         this->user_role = "";
@@ -87,6 +81,21 @@
         weight.embeddingNames.insert("model.embed_tokens.weight");
     }
 
+    void MiniCpmModel::InitParams() {
+        basellm::InitParams();
+        if (this->weight.dicts.find("scale_emb") != this->weight.dicts.end()) {
+            this->embed_scale = std::stof(this->weight.dicts["scale_emb"]);
+        }
+        if (this->weight.dicts.find("scale_depth") != this->weight.dicts.end()) {
+            float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
+            this->attention_scale = scale_depth / std::sqrt(block_cnt);
+        }
+        if (this->weight.dicts.find("dim_model_base") != this->weight.dicts.end()) {
+            int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
+            this->rms_scale = 1.f / (this->embed_dim / dim_model_base);
+        }
+    }
+
     int MiniCpmModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
                               const fastllm::Data &positionIds, std::vector <std::pair <Data, Data> > &pastKeyValues,
                               const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
@@ -105,15 +114,8 @@
         Data attenLastOutput;
         Data w1, w2, w3;
 
-        float scale_emb = std::stof(this->weight.dicts["scale_emb"]);
-        float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
-        int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]);
-        int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]);
-        int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
-        float rms_scale = 1.f / (dim_model / dim_model_base);
-
         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
-        Mul(hiddenStates, scale_emb, hiddenStates);
+        Mul(hiddenStates, embed_scale, hiddenStates);
         for (int i = 0; i < block_cnt; i++) {
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
             RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
@@ -213,10 +215,8 @@
 
             attenOutput.Reshape({bsz, seqlen, -1});
             Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
-
-            Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput);
+            Mul(attenLastOutput, this->attention_scale, attenLastOutput);
             AddTo(hiddenStates, attenLastOutput);
-
             // 2. mlp
             RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
             Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
@@ -224,7 +224,7 @@
             Silu(w1, w1);
             MulTo(w1, w3);
             Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
-            Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2);
+            Mul(w2, this->attention_scale, w2);
             AddTo(hiddenStates, w2);
         }
         Data logits, topk;
@@ -241,8 +241,8 @@
         {
             auto &hiddenStates = *lastHiddenStates;
             RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
-            Mul(hiddenStates, rms_scale, hiddenStates);
-            Linear(hiddenStates, weight["model.embed_tokens.weight"], Data(), logits);
+            Mul(hiddenStates, this->rms_scale, hiddenStates);
+            Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
             if (generationConfig.output_logits && retLogits != nullptr) {
                 int size = logits.dims.back();
                 logits.ToDevice(DataDevice::CPU);
@@ -278,16 +278,9 @@
         Data attenWeights, attenOutput;
         Data attenLastOutput;
         Data w1, w2, w3;
-
-        float scale_emb = std::stof(this->weight.dicts["scale_emb"]);
-        float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
-        int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]);
-        int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]);
-        int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
-        float rms_scale = 1.f / (dim_model / dim_model_base);
 
         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
-        Mul(hiddenStates, scale_emb, hiddenStates);
+        Mul(hiddenStates, embed_scale, hiddenStates);
         int seqlen = hiddenStates.dims[1];
         for (int i = 0; i < block_cnt; i++) {
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -391,7 +384,7 @@
 
             PermuteSelf(attenOutput, {1, 0, 2});
             Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
-            Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput);
+            Mul(attenLastOutput, this->attention_scale, attenLastOutput);
             AddTo(hiddenStates, attenLastOutput);
             // 2. mlp
             RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
@@ -400,7 +393,7 @@
             Silu(w1, w1);
             MulTo(w1, w3);
             Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
-            Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2);
+            Mul(w2, this->attention_scale, w2);
             AddTo(hiddenStates, w2);
         }
 
@@ -418,7 +411,7 @@
         {
             auto &hiddenStates = *lastHiddenStates;
             RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
-            Mul(hiddenStates, rms_scale, hiddenStates);
+            Mul(hiddenStates, this->rms_scale, hiddenStates);
             Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
             if (generationConfig.IsSimpleGreedy()) {
                 TopK(logits, topk, 1);
@@ -459,15 +452,8 @@
         Data attenLastOutput;
         Data w1, w2, w3;
 
-        float scale_emb = std::stof(this->weight.dicts["scale_emb"]);
-        float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
-        int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]);
-        int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]);
-        int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
-        float rms_scale = 1.f / (dim_model / dim_model_base);
-
         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
-        Mul(hiddenStates, scale_emb, hiddenStates);
+        Mul(hiddenStates, embed_scale, hiddenStates);
         int seqlen = hiddenStates.dims[1];
         for (int i = 0; i < block_cnt; i++) {
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -594,9 +580,8 @@
             }
 
             Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
-            Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput);
+            Mul(attenLastOutput, this->attention_scale, attenLastOutput);
             AddTo(hiddenStates, attenLastOutput);
-
             // 2. mlp
             RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
             Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
@@ -604,13 +589,13 @@
             Silu(w1, w1);
             MulTo(w1, w3);
             Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
-            Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2);
+            Mul(w2, this->attention_scale, w2);
             AddTo(hiddenStates, w2);
         }
 
         Data logits, curLogit;
         RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
-        Mul(hiddenStates, rms_scale, hiddenStates);
+        Mul(hiddenStates, this->rms_scale, hiddenStates);
         Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
         std::vector <int> lastRet;
         int total = 0;
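
Notes:

The patch caches the three scale factors MiniCPM applies to embeddings,
residual branches, and final hidden states once in InitParams(), instead of
re-parsing weight.dicts on every Forward call, and points the first Forward's
logits projection at lm_head.weight like the other two paths. The sketch
below shows how the cached members are derived from the config entries. It
is a minimal standalone illustration, not code from this patch: the concrete
values (scale_emb=12, scale_depth=1.4, num_hidden_layers=40, hidden_size=2304,
dim_model_base=256) are assumptions in the spirit of a MiniCPM-2B
config.json; read the real values from the model's own config.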
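    // Standalone sketch of the scale computation cached by
    // MiniCpmModel::InitParams(). weight.dicts is mimicked with a plain
    // string map; all config values are assumptions (see the note above).
    #include <cmath>
    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
        // Stand-in for this->weight.dicts (string keys, string values).
        std::map<std::string, std::string> dicts = {
            {"scale_emb",         "12"},    // assumed
            {"scale_depth",       "1.4"},   // assumed
            {"num_hidden_layers", "40"},    // assumed (block_cnt)
            {"hidden_size",       "2304"},  // assumed (embed_dim)
            {"dim_model_base",    "256"}    // assumed
        };

        int block_cnt      = std::stoi(dicts["num_hidden_layers"]);
        int embed_dim      = std::stoi(dicts["hidden_size"]);
        int dim_model_base = std::stoi(dicts["dim_model_base"]);

        // Embedding outputs are multiplied by scale_emb.
        float embed_scale = std::stof(dicts["scale_emb"]);
        // Attention and MLP residual branches are scaled by
        // scale_depth / sqrt(num_hidden_layers).
        float attention_scale = std::stof(dicts["scale_depth"])
                                / std::sqrt((float) block_cnt);
        // Final hidden states are scaled by 1 / (hidden_size / dim_model_base);
        // the inner division is integral, as in the patch (exact here: 2304 / 256 = 9).
        float rms_scale = 1.f / (embed_dim / dim_model_base);

        std::printf("embed_scale=%g attention_scale=%g rms_scale=%g\n",
                    embed_scale, attention_scale, rms_scale);
        return 0;
    }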
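Caching has two effects visible in the diff: the per-token hot path no longer
performs std::map lookups plus std::stof/std::stoi for every layer pass, and
each config read is now guarded by find(), so a model whose metadata lacks
scale_emb, scale_depth, or dim_model_base falls back to the defaults declared
in minicpm.h (1.f, 1.f / std::sqrt(block_cnt), 1.f / 4096.f) rather than
having std::stof throw on the empty string that operator[] would insert.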