From 97cb612649648667b1b0331caa322dcbb20bd472 Mon Sep 17 00:00:00 2001
From: zhangshiyu <328574108@qq.com>
Date: Fri, 26 Jan 2024 11:55:50 +0800
Subject: [PATCH 1/4] Fix the error reported when the decoded output is too
 long and stream_response_raw is used.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/fastllm_pytools/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 4a3017da..648ec093 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -182,7 +182,7 @@ def tokenizer_decode_token(self, token_id: int) -> bytes:
             return cache_result
 
         output_buffer_init_len = 256
-        if self.thread_local_obj.tokenizer_decode_token__output_buffer is None:
+        if "tokenizer_decode_token__output_buffer" not in dir(self.thread_local_obj) or self.thread_local_obj.tokenizer_decode_token__output_buffer is None:
             self.thread_local_obj.tokenizer_decode_token__output_buffer = ctypes.create_string_buffer(output_buffer_init_len)
 
         buffer = self.thread_local_obj.tokenizer_decode_token__output_buffer

From f79b6da4f31d32e49737cf474b42e3471a6b444f Mon Sep 17 00:00:00 2001
From: zhangshiyu <328574108@qq.com>
Date: Mon, 29 Jan 2024 11:01:27 +0800
Subject: [PATCH 2/4] Fix cases where tokenizer_decode_token__output_buffer
 and tokenizer_encode_string__output_buffer could fail in a multi-threaded
 environment.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/fastllm_pytools/llm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 648ec093..ff3e620b 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -116,8 +116,8 @@ def __init__ (self, path : str,
         # 为了减少重复申请释放buffer对象而使用的线程局部存储区对象池
         self.thread_local_obj = threading.local()
-        self.thread_local_obj.tokenizer_encode_string__output_buffer = None
-        self.thread_local_obj.tokenizer_decode_token__output_buffer = None
+        #self.thread_local_obj.tokenizer_encode_string__output_buffer = None
+        #self.thread_local_obj.tokenizer_decode_token__output_buffer = None
 
         # tokenizer_decode_token 输出结果的静态缓存,手工触发构建
         # 由于token数量有限且不太多,所以缓存该结果来减少调用较为适合。
@@ -154,7 +154,7 @@ def build_tokenizer_decode_token_cache(self):
     def tokenizer_encode_string(self, content: str) -> List[int]:
         output_buffer_init_len = 1024
-        if self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
+        if "tokenizer_encode_string__output_buffer" not in self.thread_local_obj or self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
             self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()
 
         buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer

From 2159c65d16effe5f49ddb6cfe06d27a6eeada542 Mon Sep 17 00:00:00 2001
From: zhangshiyu <328574108@qq.com>
Date: Mon, 29 Jan 2024 11:25:08 +0800
Subject: [PATCH 3/4] Fix code error
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/fastllm_pytools/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index ff3e620b..bf012dfe 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -154,7 +154,7 @@ def build_tokenizer_decode_token_cache(self):
     def tokenizer_encode_string(self, content: str) -> List[int]:
         output_buffer_init_len = 1024
-        if "tokenizer_encode_string__output_buffer" not in self.thread_local_obj or self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
+        if "tokenizer_encode_string__output_buffer" not in dir(self.thread_local_obj) or self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
             self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()
 
         buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer

From 264c6647bdc7aff92d9b71234e785c7c961ce043 Mon Sep 17 00:00:00 2001
From: xuhaifeng
Date: Tue, 27 Feb 2024 18:01:26 +0800
Subject: [PATCH 4/4] support OpenBMB/MiniCPM

---
 CMakeLists.txt                    |    2 +-
 include/models/minicpm.h          |   71 ++
 src/model.cpp                     |    3 +
 src/models/minicpm.cpp            | 1110 +++++++++++++++++++++++++++++
 tools/fastllm_pytools/hf_model.py |    2 +
 5 files changed, 1187 insertions(+), 1 deletion(-)
 create mode 100644 include/models/minicpm.h
 create mode 100644 src/models/minicpm.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b1b4cc6c..4eb7a03b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,7 +33,7 @@ endif()
 message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
 set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
-        src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp
+        src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp src/models/minicpm.cpp
         src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp)
 
 include_directories(include)

diff --git a/include/models/minicpm.h b/include/models/minicpm.h
new file mode 100644
index 00000000..241e6c72
--- /dev/null
+++ b/include/models/minicpm.h
@@ -0,0 +1,71 @@
+//
+// Created by huangyuyang on 6/1/23.
+//
+
+#ifndef FASTLLM_MINICPM_H
+#define FASTLLM_MINICPM_H
+
+#include "basellm.h"
+#include "cmath"
+
+#include <iostream>
+
+namespace fastllm {
+    class MiniCpmModel: public basellm {
+    public:
+        MiniCpmModel(); // 构造函数
+
+        // 推理
+        virtual int Forward(
+                const Data &inputIds,
+                const Data &attentionMask,
+                const Data &positionIds,
+                std::vector <std::pair <Data, Data> > &pastKeyValues,
+                const GenerationConfig &generationConfig = GenerationConfig(),
+                const LastTokensManager &lastTokens = LastTokensManager(),
+                std::vector <float> *logits = nullptr);
+
+        std::vector <int> ForwardBatch(
+                int batch,
+                const Data &inputIds,
+                const Data &attentionMask,
+                const Data &positionIds,
+                std::vector <std::pair <Data, Data> > &pastKeyValues,
+                const GenerationConfig &generationConfig = GenerationConfig(),
+                const LastTokensManager &lastTokens = LastTokensManager(),
+                std::vector <std::vector <float>*> *logits = nullptr);
+
+        std::vector <int> ForwardBatch(
+                int batch,
+                const Data &inputIds,
+                const std::vector <Data*> &attentionMask,
+                const std::vector <Data*> &positionIds,
+                const std::vector <int> &seqLens,
+                std::vector <std::pair <Data*, Data*> > &pastKeyValues,
+                const std::vector <GenerationConfig> &generationConfigs,
+                const LastTokensManager &lastTokens = LastTokensManager(),
+                std::vector <std::vector <float>*> *logits = nullptr);
+
+        virtual std::string Response(const std::string& input,
+                                     RuntimeResult retCb,
+                                     const GenerationConfig &generationConfig = GenerationConfig()); // 根据给出的内容回复
+
+        virtual void ResponseBatch(const std::vector <std::string> &inputs,
+                                   std::vector <std::string> &outputs,
+                                   RuntimeResultBatch retCb,
+                                   const GenerationConfig &generationConfig = GenerationConfig());
+
+        virtual int LaunchResponseTokens(const std::vector <int> &inputTokens,
+                                         const GenerationConfig &generationConfig = GenerationConfig()); // 启动一个response任务,返回分配的handleId
+
+        virtual int FetchResponseTokens(int handelId); // 获取指定handle的输出, -1代表输出结束了
+
+        virtual void WarmUp(); // 预热
+
+        virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // 根据历史信息和当前输入生成prompt
+
+        virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // 根据当前回复更新history
+    };
+}
+
+#endif //FASTLLM_MINICPM_H
diff --git a/src/model.cpp b/src/model.cpp
index 5c919905..156899b6 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -9,6 +9,7 @@
 #include "llama.h"
 #include "qwen.h"
 #include "glm.h"
+#include "minicpm.h"
 
 namespace fastllm {
     void basellm::LoadFromFile(const std::string &fileName) {
@@ -103,6 +104,8 @@ namespace fastllm {
             model->model_type = "internlm";
         } else if (modelType == "llama") {
             model = (basellm*)(new LlamaModel());
+        } else if (modelType=="minicpm") {
+            model = (basellm*)(new MiniCpmModel());
         } else if (modelType == "qwen") {
             model = (basellm *) (new QWenModel());
             model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
diff --git a/src/models/minicpm.cpp b/src/models/minicpm.cpp
new file mode 100644
index 00000000..b9f89200
--- /dev/null
+++ b/src/models/minicpm.cpp
@@ -0,0 +1,1110 @@
+//
+// Created by huangyuyang on 6/1/23.
+// + +#include "utils.h" + +#include "minicpm.h" + +#include + +#include + +#include + +#ifdef USE_CUDA +#include "fastllm-cuda.cuh" +#endif + +namespace fastllm { + + std::vector GetInterLeavePowerOf3(int n) { + float start = powf(2, -powf(2, -(log2f(n) - 3))); + float ratio = start; + std::vector ret; + for (int i = 0; i < n; i++) { + ret.push_back(start * powf(ratio, i)); + } + return ret; + } + std::vector GetInterleave2(int n) { + int base = 1; + while (base < n) { + base <<= 1; + } + if (base == n) { + return GetInterLeavePowerOf3(n); + } else { + std::vector ret = GetInterLeavePowerOf3(base / 2); + std::vector part2 = GetInterLeavePowerOf3(base); + for (int i = 0; i < n - base / 2; i++) { + ret.push_back(part2[i * 2]); + } + return ret; + } + } + + MiniCpmModel::MiniCpmModel() { + this->model_type = "minicpm"; + + // 默认使用alpaca的提示词和instruction + /* + this->pre_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"; + this->user_role = "### Instruction:\n"; + this->bot_role = "\n\n### Response:"; + */ + this->history_sep = ""; + this->pre_prompt = ""; + this->user_role = ""; + this->bot_role = ""; + + block_cnt = 32; + rotary_dim = 128; + + sin.resize(max_positions); + cos.resize(max_positions); + std::vector invFreq; + for (int i = 0; i < rotary_dim; i += 2) { + invFreq.push_back(1.0 / pow(10000, (float)i / rotary_dim)); + } + for (int i = 0; i < max_positions; i++) { + sin[i].resize(rotary_dim); + cos[i].resize(rotary_dim); + for (int j = 0; j < invFreq.size(); j++) { + sin[i][j] = ::sin((float)i * invFreq[j]); + cos[i][j] = ::cos((float)i * invFreq[j]); + } + } + std::vector fsin, fcos; + for (int i = 0; i < sin.size(); i++) { + for (int j = 0; j < sin[0].size(); j++) { + fsin.push_back(sin[i][j]); + fcos.push_back(cos[i][j]); + } + } + sinData.CopyFrom(Data(DataType::FLOAT32, {(int)this->sin.size(), (int)this->sin[0].size()}, fsin)); + cosData.CopyFrom(Data(DataType::FLOAT32, {(int)this->cos.size(), (int)this->cos[0].size()}, fcos)); + weight.embeddingNames.insert("model.embed_tokens.weight"); + } + + int MiniCpmModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask, + const fastllm::Data &positionIds, std::vector> &pastKeyValues, + const GenerationConfig &generationConfig, const LastTokensManager &lastTokens, + std::vector *retLogits) { + Data alibiData; + if (this->weight.dicts["use_alibi"] == "1") { + std::vector alibi = GetInterleave2(num_attention_heads); + alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi)); + } + + int maxLen = inputIds.dims[1]; + Data hiddenStates; + Data attenInput; + Data q, k, v, qkv; + Data attenWeights, attenOutput; + Data attenLastOutput; + Data w1, w2, w3; + + float scale_emb = std::stof(this->weight.dicts["scale_emb"]); + float scale_depth = std::stof(this->weight.dicts["scale_depth"]); + int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]); + int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]); + int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]); + float rms_scale = 1.f / (dim_model / dim_model_base); + + Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates); + Mul(hiddenStates, scale_emb, hiddenStates); + for (int i = 0; i < block_cnt; i++) { + ApplyDeviceMap(this->deviceMap, i + 1, block_cnt); + RMSNorm(hiddenStates, this->weight["model.layers." 
+ std::to_string(i) + ".input_layernorm.weight"], + 1e-5, attenInput); + std::string qWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.weight"; + std::string kWeightName = "model.layers." + std::to_string(i) + ".self_attn.k_proj.weight"; + std::string vWeightName = "model.layers." + std::to_string(i) + ".self_attn.v_proj.weight"; + std::string qkvWeightName = "model.layers." + std::to_string(i) + ".self_attn.W_pack.weight"; + std::string oWeightName = "model.layers." + std::to_string(i) + ".self_attn.o_proj.weight"; + + // 1.1 Get q, k, v + int bsz = attenInput.dims[0], seqlen = attenInput.dims[1]; + if (weight.weight.find(qkvWeightName) != weight.weight.end()) { + Linear(attenInput, weight[qkvWeightName], Data(), qkv); + int per = qkv.dims.back() / 3; + Split(qkv, -1, 0, per, q); + Split(qkv, -1, per, per * 2, k); + Split(qkv, -1, per * 2, per * 3, v); + } else { + Linear(attenInput, weight[qWeightName], Data(), q); + Linear(attenInput, weight[kWeightName], Data(), k); + Linear(attenInput, weight[vWeightName], Data(), v); + } + + std::vector qkvSize = {bsz, seqlen, num_attention_heads, -1}; + q.Reshape(qkvSize); + k.Reshape(qkvSize); + v.Reshape(qkvSize); + + if (alibiData.dims.size() == 0) { + fastllm::LlamaRotatePosition2D(q, positionIds, sinData, cosData, rotary_dim); + fastllm::LlamaRotatePosition2D(k, positionIds, sinData, cosData, rotary_dim); + } + + qkvSize = {bsz * seqlen, num_attention_heads, -1}; + q.Reshape(qkvSize); + k.Reshape(qkvSize); + v.Reshape(qkvSize); + + PermuteSelf(q, {1, 0, 2}); + PermuteSelf(k, {1, 0, 2}); + PermuteSelf(v, {1, 0, 2}); + + Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second; + if (GetKVCacheInCPU()) { + pastKey.lockInCPU = true; + pastValue.lockInCPU = true; + } else { + pastKey.ToDevice(DataDevice::CUDA); + pastValue.ToDevice(DataDevice::CUDA); + } + + int unitLen = 64; +#ifdef USE_CUDA + unitLen = 128; +#endif + while ((pastKey.dims.size() == 0 && (pastKey.expansionDims.size() == 0 || k.dims[1] > pastKey.expansionDims[1])) + || (pastKey.dims.size() > 0 && pastKey.dims[1] + k.dims[1] > pastKey.expansionDims[1])) { + std::vector newDims; + if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) { + newDims = std::vector {k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]}; + } else { + newDims = pastKey.dims; + newDims[1] += ((k.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastKey.Expansion(newDims); + } + while ((pastValue.dims.size() == 0 && (pastValue.expansionDims.size() == 0 || v.dims[1] > pastValue.expansionDims[1])) + || (pastValue.dims.size() > 0 && pastValue.dims[1] + v.dims[1] > pastValue.expansionDims[1])) { + std::vector newDims; + if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) { + newDims = std::vector {v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]}; + } else { + newDims = pastValue.dims; + newDims[1] += ((v.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastValue.Expansion(newDims); + } + CatDirect(pastKey, k, 1); + CatDirect(pastValue, v, 1); + + // 1.2 Attention + // 1.2.0 q * k^T + MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim)); + attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]}); + if (alibiData.dims.size() != 0) { + AlibiMask(attenWeights, alibiData, -10000); + } else if (attentionMask.dims.size() != 0) { + AttentionMask(attenWeights, attentionMask, -10000); + } + + Softmax(attenWeights, attenWeights, -1); + MatMul(attenWeights, pastValue, attenOutput); + + 
attenOutput.Reshape({attenOutput.dims[1], attenOutput.dims[2], attenOutput.dims[3]}); + PermuteSelf(attenOutput, {1, 0, 2}); + attenOutput.Reshape({bsz, seqlen, -1}); + + Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput); + + Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput); + AddTo(hiddenStates, attenLastOutput); + + // 2. mlp + RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput); + Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1); + Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3); + Silu(w1, w1); + MulTo(w1, w3); + Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2); + Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2); + AddTo(hiddenStates, w2); + } + Data logits, topk; + Data tempHiddenStates; + Data *lastHiddenStates; + if (maxLen > 1) { + Split(hiddenStates, 1, maxLen - 1, maxLen, tempHiddenStates); + lastHiddenStates = &tempHiddenStates; + } else { + lastHiddenStates = &hiddenStates; + } + + int lastRet = -1; + { + auto &hiddenStates = *lastHiddenStates; + RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates); + Mul(hiddenStates, rms_scale, hiddenStates); + Linear(hiddenStates, weight["model.embed_tokens.weight"], Data(), logits); + if (generationConfig.output_logits && retLogits != nullptr) { + int size = logits.dims.back(); + logits.ToDevice(DataDevice::CPU); + retLogits->resize(size); + memcpy((float*)retLogits->data(), ((float*)logits.cpuData) + (logits.dims[1] - 1) * size, size * logits.unitSize); + } + if (generationConfig.IsSimpleGreedy()) { + TopK(logits, topk, 1); + topk.ToDevice(DataDevice::CPU); + lastRet = (int) (((float *) topk.cpuData)[0] + 1e-3); + } else if (!lastTokens.units.empty()) { + lastRet = LLMSampling(logits, logits.dims[1] - 1, generationConfig, lastTokens.units[0]); + } + } + + return lastRet; + } + + std::vector MiniCpmModel::ForwardBatch(int batch, const fastllm::Data &inputIds, const fastllm::Data &attentionMask, + const fastllm::Data &positionIds, std::vector> &pastKeyValues, + const GenerationConfig &generationConfig, const LastTokensManager &lastTokens, + std::vector *> *retLogits) { + Data alibiData; + if (this->weight.dicts["use_alibi"] == "1") { + std::vector alibi = GetInterleave2(num_attention_heads); + alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi)); + } + + int maxLen = inputIds.dims[1]; + Data hiddenStates; + Data attenInput; + Data q, k, v, qkv; + Data attenWeights, attenOutput; + Data attenLastOutput; + Data w1, w2, w3; + + float scale_emb = std::stof(this->weight.dicts["scale_emb"]); + float scale_depth = std::stof(this->weight.dicts["scale_depth"]); + int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]); + int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]); + int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]); + float rms_scale = 1.f / (dim_model / dim_model_base); + + Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates); + Mul(hiddenStates, scale_emb, hiddenStates); + int seqlen = hiddenStates.dims[1]; + for (int i = 0; i < block_cnt; i++) { + ApplyDeviceMap(this->deviceMap, i + 1, block_cnt); + RMSNorm(hiddenStates, this->weight["model.layers." 
+ std::to_string(i) + ".input_layernorm.weight"], + 1e-5, attenInput); + std::string qWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.weight"; + std::string kWeightName = "model.layers." + std::to_string(i) + ".self_attn.k_proj.weight"; + std::string vWeightName = "model.layers." + std::to_string(i) + ".self_attn.v_proj.weight"; + std::string qkvWeightName = "model.layers." + std::to_string(i) + ".self_attn.W_pack.weight"; + std::string oWeightName = "model.layers." + std::to_string(i) + ".self_attn.o_proj.weight"; + + // 1.1 Get q, k, v + int bsz = attenInput.dims[0], seqlen = attenInput.dims[1]; + if (weight.weight.find(qkvWeightName) != weight.weight.end()) { + Linear(attenInput, weight[qkvWeightName], Data(), qkv); + int per = qkv.dims.back() / 3; + Split(qkv, -1, 0, per, q); + Split(qkv, -1, per, per * 2, k); + Split(qkv, -1, per * 2, per * 3, v); + } else { + Linear(attenInput, weight[qWeightName], Data(), q); + Linear(attenInput, weight[kWeightName], Data(), k); + Linear(attenInput, weight[vWeightName], Data(), v); + } + + std::vector qkvSize = {bsz, seqlen, num_attention_heads, -1}; + q.Reshape(qkvSize); + k.Reshape(qkvSize); + v.Reshape(qkvSize); + + if (alibiData.dims.size() == 0) { + fastllm::LlamaRotatePosition2D(q, positionIds, sinData, cosData, rotary_dim); + fastllm::LlamaRotatePosition2D(k, positionIds, sinData, cosData, rotary_dim); + } + + PermuteSelf(q, {0, 2, 1, 3}); + PermuteSelf(k, {0, 2, 1, 3}); + PermuteSelf(v, {0, 2, 1, 3}); + + qkvSize = {bsz * num_attention_heads, seqlen, -1}; + q.Reshape(qkvSize); + k.Reshape(qkvSize); + v.Reshape(qkvSize); + + Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second; + if (GetKVCacheInCPU()) { + pastKey.lockInCPU = true; + pastValue.lockInCPU = true; + } else { + pastKey.ToDevice(DataDevice::CUDA); + pastValue.ToDevice(DataDevice::CUDA); + } + + int unitLen = 64; +#ifdef USE_CUDA + unitLen = 128; +#endif + while ((pastKey.dims.size() == 0 && (pastKey.expansionDims.size() == 0 || k.dims[1] > pastKey.expansionDims[1])) + || (pastKey.dims.size() > 0 && pastKey.dims[1] + k.dims[1] > pastKey.expansionDims[1])) { + std::vector newDims; + if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) { + newDims = std::vector {k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]}; + } else { + newDims = pastKey.dims; + newDims[1] += ((k.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastKey.Expansion(newDims); + } + while ((pastValue.dims.size() == 0 && (pastValue.expansionDims.size() == 0 || v.dims[1] > pastValue.expansionDims[1])) + || (pastValue.dims.size() > 0 && pastValue.dims[1] + v.dims[1] > pastValue.expansionDims[1])) { + std::vector newDims; + if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) { + newDims = std::vector {v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]}; + } else { + newDims = pastValue.dims; + newDims[1] += ((v.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastValue.Expansion(newDims); + } + + CatDirect(pastKey, k, 1); + CatDirect(pastValue, v, 1); + + // 1.2 Attention + // 1.2.0 q * k^T + MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim)); + attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]}); + if (alibiData.dims.size() != 0) { + attenWeights.Reshape({-1, num_attention_heads, attenWeights.dims[2], attenWeights.dims[3]}); + AlibiMask(attenWeights, alibiData, -10000); + attenWeights.Reshape({1, -1, attenWeights.dims[2], attenWeights.dims[3]}); + } else if (attentionMask.dims.size() != 0) { + 
AttentionMask(attenWeights, attentionMask, -10000); + } + Softmax(attenWeights, attenWeights, -1); + MatMul(attenWeights, pastValue, attenOutput); + + attenOutput.Reshape({attenOutput.dims[1], attenOutput.dims[2], attenOutput.dims[3]}); + PermuteSelf(attenOutput, {1, 0, 2}); + attenOutput.Reshape({seqlen, bsz, -1}); + PermuteSelf(attenOutput, {1, 0, 2}); + + Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput); + Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput); + AddTo(hiddenStates, attenLastOutput); + // 2. mlp + RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput); + Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1); + Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3); + Silu(w1, w1); + MulTo(w1, w3); + Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2); + Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2); + AddTo(hiddenStates, w2); + } + + Data logits, topk; + Data tempHiddenStates; + Data *lastHiddenStates; + if (maxLen > 1) { + Split(hiddenStates, 1, maxLen - 1, maxLen, tempHiddenStates); + lastHiddenStates = &tempHiddenStates; + } else { + lastHiddenStates = &hiddenStates; + } + + std::vector lastRet; + { + auto &hiddenStates = *lastHiddenStates; + RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates); + Mul(hiddenStates, rms_scale, hiddenStates); + Linear(hiddenStates, weight["lm_head.weight"], Data(), logits); + if (generationConfig.IsSimpleGreedy()) { + TopK(logits, topk, 1); + topk.ToDevice(DataDevice::CPU); + for (int b = 0; b < batch; b++) { + int base = b; + lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3)); + } + } else { + for (int b = 0; b < batch; b++) { + int base = b * logits.dims[1] + logits.dims[1] - 1; + lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b])); + } + } + } + + return lastRet; + } + + std::vector MiniCpmModel::ForwardBatch(int batch, + const Data &inputIds, + const std::vector &attentionMask, + const std::vector &positionIds, + const std::vector &seqLens, + std::vector > &pastKeyValues, + const std::vector &generationConfigs, + const LastTokensManager &lastTokens, + std::vector *> *retLogits) { + Data alibiData; + if (this->weight.dicts["use_alibi"] == "1") { + std::vector alibi = GetInterleave2(num_attention_heads); + alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi)); + } + Data hiddenStates; + Data attenInput; + Data q, k, v, qkv; + Data attenWeights, curAttenOutput; + Data attenLastOutput; + Data w1, w2, w3; + + float scale_emb = std::stof(this->weight.dicts["scale_emb"]); + float scale_depth = std::stof(this->weight.dicts["scale_depth"]); + int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]); + int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]); + int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]); + float rms_scale = 1.f / (dim_model / dim_model_base); + + Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates); + Mul(hiddenStates, scale_emb, hiddenStates); + int seqlen = hiddenStates.dims[1]; + for (int i = 0; i < block_cnt; i++) { + ApplyDeviceMap(this->deviceMap, i + 1, block_cnt); + RMSNorm(hiddenStates, this->weight["model.layers." 
+ std::to_string(i) + ".input_layernorm.weight"], + 1e-5, attenInput); + std::string qWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.weight"; + std::string kWeightName = "model.layers." + std::to_string(i) + ".self_attn.k_proj.weight"; + std::string vWeightName = "model.layers." + std::to_string(i) + ".self_attn.v_proj.weight"; + std::string qkvWeightName = "model.layers." + std::to_string(i) + ".self_attn.W_pack.weight"; + std::string oWeightName = "model.layers." + std::to_string(i) + ".self_attn.o_proj.weight"; + + // 1.1 Get q, k, v + int bsz = attenInput.dims[0], seqlen = attenInput.dims[1]; + if (weight.weight.find(qkvWeightName) != weight.weight.end()) { + Linear(attenInput, weight[qkvWeightName], Data(), qkv); + int per = qkv.dims.back() / 3; + Split(qkv, -1, 0, per, q); + Split(qkv, -1, per, per * 2, k); + Split(qkv, -1, per * 2, per * 3, v); + } else { + Linear(attenInput, weight[qWeightName], Data(), q); + Linear(attenInput, weight[kWeightName], Data(), k); + Linear(attenInput, weight[vWeightName], Data(), v); + } + + Data attenOutput = Data(DataType::FLOAT32); + int total = 0; + std::vector curKs, curVs, curQs; + curKs.resize(batch); + curVs.resize(batch); + curQs.resize(batch); + for (int b = 0; b < batch; b++) { + Split(k, 1, total, total + seqLens[b], curKs[b]); + Split(v, 1, total, total + seqLens[b], curVs[b]); + Split(q, 1, total, total + seqLens[b], curQs[b]); + total += seqLens[b]; + } + + for (int b = 0; b < batch; b++) { + auto &q = curQs[b], &k = curKs[b], &v = curVs[b]; + + std::vector qkvSize = {bsz, seqLens[b], num_attention_heads, -1}; + q.Reshape(qkvSize); + k.Reshape(qkvSize); + v.Reshape(qkvSize); + + if (alibiData.dims.size() == 0) { + fastllm::LlamaRotatePosition2D(q, *positionIds[b], sinData, cosData, rotary_dim); + fastllm::LlamaRotatePosition2D(k, *positionIds[b], sinData, cosData, rotary_dim); + } + + PermuteSelf(q, {0, 2, 1, 3}); + PermuteSelf(k, {0, 2, 1, 3}); + PermuteSelf(v, {0, 2, 1, 3}); + + qkvSize = {bsz * num_attention_heads, seqLens[b], -1}; + q.Reshape(qkvSize); + k.Reshape(qkvSize); + v.Reshape(qkvSize); + + Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second; + if (GetKVCacheInCPU()) { + pastKey.lockInCPU = true; + pastValue.lockInCPU = true; + } else { + pastKey.ToDevice(DataDevice::CUDA); + pastValue.ToDevice(DataDevice::CUDA); + } + + int unitLen = 64; +#ifdef USE_CUDA + unitLen = 128; +#endif + while ((pastKey.dims.size() == 0 && + (pastKey.expansionDims.size() == 0 || k.dims[1] > pastKey.expansionDims[1])) + || (pastKey.dims.size() > 0 && pastKey.dims[1] + k.dims[1] > pastKey.expansionDims[1])) { + std::vector newDims; + if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) { + newDims = std::vector{k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]}; + } else { + newDims = pastKey.dims; + newDims[1] += ((k.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastKey.Expansion(newDims); + } + while ((pastValue.dims.size() == 0 && + (pastValue.expansionDims.size() == 0 || v.dims[1] > pastValue.expansionDims[1])) + || (pastValue.dims.size() > 0 && pastValue.dims[1] + v.dims[1] > pastValue.expansionDims[1])) { + std::vector newDims; + if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) { + newDims = std::vector{v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]}; + } else { + newDims = pastValue.dims; + newDims[1] += ((v.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastValue.Expansion(newDims); + } + + CatDirect(pastKey, k, 1); + 
CatDirect(pastValue, v, 1); + + // 1.2 Attention + // 1.2.0 q * k^T + MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim)); + attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]}); + if (alibiData.dims.size() != 0) { + AlibiMask(attenWeights, alibiData, -10000); + } else if (attentionMask[b] != nullptr) { + AttentionMask(attenWeights, *attentionMask[b], -10000); + } + + Softmax(attenWeights, attenWeights, -1); + MatMul(attenWeights, pastValue, curAttenOutput); + curAttenOutput.Reshape({curAttenOutput.dims[1], curAttenOutput.dims[2], curAttenOutput.dims[3]}); + PermuteSelf(curAttenOutput, {1, 0, 2}); + curAttenOutput.Reshape({seqLens[b], bsz, -1}); + PermuteSelf(curAttenOutput, {1, 0, 2}); + if (attenOutput.dims.size() == 0) { + std::vector dims = curAttenOutput.dims; + dims[1] = total; + attenOutput.Expansion(dims); + } + CatDirect(attenOutput, curAttenOutput, 1); + } + + Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput); + Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput); + AddTo(hiddenStates, attenLastOutput); + + // 2. mlp + RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput); + Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1); + Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3); + Silu(w1, w1); + MulTo(w1, w3); + Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2); + Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2); + AddTo(hiddenStates, w2); + } + + Data logits, curLogit; + RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates); + Mul(hiddenStates, rms_scale, hiddenStates); + Linear(hiddenStates, weight["lm_head.weight"], Data(), logits); + std::vector lastRet; + int total = 0; + for (int b = 0; b < batch; b++) { + Split(logits, 1, total + seqLens[b] - 1, total + seqLens[b], curLogit); + if (generationConfigs[b].output_logits && retLogits != nullptr && (*retLogits)[b] != nullptr) { + curLogit.ToDevice(DataDevice::CPU); + (*retLogits)[b]->resize(curLogit.Count(0)); + memcpy((float*)(*retLogits)[b]->data(), (float*)curLogit.cpuData, curLogit.GetBytes()); + } + if (generationConfigs[b].IsSimpleGreedy()) { + Data topk; + TopK(curLogit, topk, 1); + topk.ToDevice(DataDevice::CPU); + lastRet.push_back((int) (((float *) topk.cpuData)[0] + 1e-3)); + } else { + lastRet.push_back(LLMSampling(curLogit, 0, generationConfigs[b], lastTokens.units[b])); + } + total += seqLens[b]; + } + return lastRet; + } + + std::string MiniCpmModel::Response(const std::string& input, RuntimeResult retCb, + const GenerationConfig &generationConfig) { +#ifdef USE_CUDA + FastllmCudaClearBigBuffer(); +#endif +//auto st = std::chrono::system_clock::now(); +#ifdef PY_API + size_t pos = input.rfind("time_stamp:"); + std::string prompt = (generationConfig.enable_hash_id && pos != -1)? 
input.substr(0, pos):input; + size_t hash_id = std::hash{}(input); + Data inputIds = this->weight.tokenizer.Encode(prompt); +#else + Data inputIds = this->weight.tokenizer.Encode(input); +#endif + std::vector ids; + for (int i = 0; i < inputIds.Count(0); i++) { + ids.push_back(((float*)inputIds.cpuData)[i]); + } + int seqLen = ids.size(); + inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, ids)); + + std::vector vmask = std::vector (seqLen * seqLen, 0); + std::vector vpids = std::vector (seqLen, 0); + for (int i = 0; i < seqLen; i++) { + vpids[i] = i; + for (int j = i + 1; j < seqLen; j++) { + vmask[i * seqLen + j] = 1; + } + } + + Data attentionMask = Data(DataType::FLOAT32, {seqLen, seqLen}, vmask); + Data positionIds = Data(DataType::FLOAT32, {1, seqLen}, vpids); + + std::vector > pastKeyValues; + for (int i = 0; i < block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), + Data(DataType::FLOAT32))); + } + + std::string retString = ""; + int len = seqLen; + std::vector results; + int index = 0; + + LastTokensManager tokens (1, generationConfig.last_n); + while (true) { + auto st = std::chrono::system_clock::now(); + + int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens); + tokens.units[0].Push(ret); + if (ret == eos_token_id) { + break; + } + + results.push_back(ret); + std::string curString = weight.tokenizer.Decode(Data(DataType::FLOAT32, {(int)results.size()}, results)).c_str(); + retString += curString; + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::stringstream ss; + ss << retString << "hash_id:" << hash_id; + retCb(index, pybind11::bytes(ss.str())); + } else { + retCb(index, pybind11::bytes(retString)); + } + } +#else + retCb(index, curString.c_str()); +#endif + index++; + + if (index == generationConfig.output_token_limit) { + break; + } + results.clear(); + + attentionMask.ToDevice(DataDevice::CPU); + positionIds.ToDevice(DataDevice::CPU); + inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float)ret})); + attentionMask = Data(); + positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float)len})); + //if (do_sample) { + // tokenPenaltyManager.InsertToken(ret); + //} + len++; + if (index == generationConfig.output_token_limit) { + break; + } + + //printf("spend %f s.\n", GetSpan(st, std::chrono::system_clock::now())); + } + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::stringstream ss; + ss << retString << "hash_id:" << hash_id; + retCb(-1, pybind11::bytes(ss.str())); + } else { + retCb(-1, pybind11::bytes(retString)); + } + } +#else + retCb(-1, retString.c_str()); +#endif + + return retString; + } + + void MiniCpmModel::ResponseBatch(const std::vector &inputs, std::vector &outputs, + RuntimeResultBatch retCb, + const GenerationConfig &generationConfig) { +#ifdef USE_CUDA + FastllmCudaClearBigBuffer(); +#endif +#ifdef PY_API + std::vector prompts; + std::vector < size_t > hash_ids; + for (auto _input: inputs){ + size_t hash_id = std::hash{}(_input); + hash_ids.push_back(hash_id); + + size_t pos = _input.rfind("time_stamp:"); + std::string prompt = (generationConfig.enable_hash_id && pos != -1) ? 
_input.substr(0, pos) : _input; + prompts.push_back(prompt); + } +#else + std::vector prompts = inputs; +#endif + int batch = prompts.size(); + outputs.clear(); + outputs.resize(batch, ""); + + std::vector inputTokens; + std::vector seqLens; + inputTokens.resize(batch); + seqLens.resize(batch); + int maxLen = 0; + for (int i = 0; i < batch; i++) { + inputTokens[i].CopyFrom(this->weight.tokenizer.Encode(prompts[i])); + maxLen = std::max(maxLen, (int)inputTokens[i].Count(0)); + seqLens[i] = (int)inputTokens[i].Count(0); + } + + std::vector ids = std::vector (batch * maxLen, 0); + std::vector vpids = std::vector (batch * maxLen, 0); + std::vector vmask = std::vector (batch * maxLen * maxLen, 0); + for (int i = 0; i < batch; i++) { + Data &tokens = inputTokens[i]; + int len = tokens.Count(0), base = maxLen - len; + for (int j = 0; j < len; j++) { + ids[i * maxLen + base + j] = ((float*)tokens.cpuData)[j]; + } + for (int j = 0; j < len; j++) { + vpids[i * maxLen + base + j] = j; + } + + std::fill(vmask.data() + i * maxLen * maxLen, + vmask.data() + i * maxLen * maxLen + (maxLen - len) * maxLen, 1.0); + for (int j = maxLen - len; j < maxLen; j++) { + std::fill(vmask.data() + i * maxLen * maxLen + j * maxLen, + vmask.data() + i * maxLen * maxLen + j * maxLen + maxLen - len, 1.0); + } + for (int j = 0; j < len; j++) { + for (int k = j + 1; k < len; k++) { + vmask[i * maxLen * maxLen + (base + j) * maxLen + base + k] = 1; + } + } + } + + Data inputIds = Data(DataType::FLOAT32, {batch, maxLen}, ids); + Data attentionMask = Data(DataType::FLOAT32, {batch, maxLen, maxLen}, vmask); + Data positionIds = Data(DataType::FLOAT32, {batch, maxLen}, vpids); + + std::vector > pastKeyValues; + for (int i = 0; i < block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), + Data(DataType::FLOAT32))); + } + + std::string retString = ""; + std::vector lens = seqLens; + std::vector isEnding = std::vector (batch, false); + std::vector results; + int index = 0; + + LastTokensManager tokensManager (batch, generationConfig.last_n); + while (true) { + auto st = std::chrono::system_clock::now(); + std::vector ret = ForwardBatch(batch, inputIds, attentionMask, positionIds, pastKeyValues, + generationConfig, tokensManager); + for (int i = 0; i < batch; i++) { + tokensManager.units[i].Push(ret[i]); + } + std::vector fret; + std::vector results; + int endingCount = 0; + std::vector curStrings; + for (int i = 0; i < batch; i++) { + fret.push_back(ret[i]); + if (ret[i] == eos_token_id) { + isEnding[i] = true; + } + if (isEnding[i]) { + curStrings.push_back(""); + endingCount++; + continue; + } + results.push_back(ret[i]); + std::string curString = weight.tokenizer.Decode( + Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str(); + outputs[i] += curString; + curStrings.push_back(curString); + results.clear(); + } + + if (endingCount == batch) { + break; + } + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::vector rtnStrings; + for (size_t i=0; i rtnStrings; + for (size_t i=0; i pids = std::vector (batch); + std::vector vmasks = std::vector (batch * maxLen, 0.0f); + for (int i = 0; i < batch; i++) { + pids[i] = lens[i]; + lens[i]++; + for (int j = 0; j < maxLen - lens[i]; j++) { + vmasks[i * maxLen + j] = 1.0f; + } + } + positionIds.ToDevice(DataDevice::CPU); + attentionMask.ToDevice(DataDevice::CPU); + attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, 1, maxLen}, vmasks)); + inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, fret)); + 
positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, pids)); + if (index == generationConfig.output_token_limit) { + break; + } + + //printf("spend %f s.\n", GetSpan(st, std::chrono::system_clock::now())); + } + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::vector rtnStrings; + for (size_t i=0; i rtnStrings; + for (size_t i=0; i > pastKeyValues; + for (int i = 0; i < block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), + Data(DataType::FLOAT32))); + } + Forward(inputIds, attentionMask, positionIds, pastKeyValues); + printf("finish.\n"); + } + + int MiniCpmModel::LaunchResponseTokens(const std::vector &inputTokens, + const GenerationConfig &generationConfig) { + mainLoopLocker.lock(); + + if (mainLoop == nullptr) { + if (mainLoop == nullptr) { + mainLoop = new std::thread([](MiniCpmModel *model) { + while (true) { + std::vector attentionMasks; + std::vector positionIds; + std::vector > pastKeyValues; + std::vector ids; + std::vector seqLens; + std::vector generationConfigs; + LastTokensManager tokensManager; + std::vector * > logits; + model->dictLocker.lock(); + for (auto &it: model->responseContextDict.dicts) { + if (it.second->isEnding) { + continue; + } + generationConfigs.push_back(it.second->generationConfig); + if (it.second->generationConfig.output_logits) { + it.second->resultLogits.push(new std::vector ()); + logits.push_back(it.second->resultLogits.back()); + } else { + logits.push_back(nullptr); + } + tokensManager.units.push_back(it.second->tokens); + if (it.second->preTokens == 0) { + int seqLen = it.second->currentTokens.size(); + for (int i = 0; i < it.second->currentTokens.size(); i++) { + ids.push_back(it.second->currentTokens[i]); + } + + seqLens.push_back(seqLen); + + std::vector vmask = std::vector (seqLen * seqLen, 0); + std::vector vpids = std::vector (seqLen, 0); + for (int i = 0; i < seqLen; i++) { + vpids[i] = i; + for (int j = i + 1; j < seqLen; j++) { + vmask[i * seqLen + j] = 1; + } + } + it.second->intParams["len"] = seqLen; + + attentionMasks.push_back(new Data(DataType::FLOAT32, {seqLen, seqLen}, vmask)); + positionIds.push_back(new Data(DataType::FLOAT32, {2, seqLen}, vpids)); + } else { + int ret = it.second->currentTokens[0]; + seqLens.push_back(1); + ids.push_back(ret); + attentionMasks.push_back(nullptr); + positionIds.push_back(new Data(DataType::FLOAT32, {1, 1}, {(float)it.second->intParams["len"]})); + it.second->intParams["len"]++; + } + + it.second->preTokens += seqLens.back(); + for (int i = 0; i < model->block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first, + &it.second->pastKeyValues[i].second)); + } + } + + if (seqLens.size() > 0) { + model->dictLocker.unlock(); +#ifdef USE_CUDA + FastllmCudaClearBigBuffer(); +#endif + Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids); + std::vector ret; + ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, + positionIds, seqLens, pastKeyValues, generationConfigs, + tokensManager, &logits); + model->dictLocker.lock(); + int idx = 0; + for (auto &it: model->responseContextDict.dicts) { + if (it.second->isEnding) { + continue; + } + int curRet = ret[idx++]; + if (curRet == model->eos_token_id) { + it.second->isEnding = true; + } else { + auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet); + if (itStopTk != it.second->generationConfig.stop_token_ids.end()) { + it.second->isEnding = true; + } + } + if (it.second->isEnding == false) { + 
it.second->currentTokens = std::vector{curRet}; + it.second->resultTokenQueue.push(curRet); + it.second->tokens.Push(curRet); + it.second->curTokens++; + if (it.second->curTokens == it.second->generationConfig.output_token_limit) { + it.second->isEnding = true; + } + } + } + } + + for (int i = 0; i < attentionMasks.size(); i++) { + delete attentionMasks[i]; + } + for (int i = 0; i < positionIds.size(); i++) { + delete positionIds[i]; + } + + model->dictLocker.unlock(); + MySleep(0); + } + }, this); + } + } + mainLoopLocker.unlock(); + + dictLocker.lock(); + int handleId = responseContextDict.CreateHandle(); + ResponseContext *context = responseContextDict.GetHandle(handleId); + context->Init(this->block_cnt); + + context->currentTokens = inputTokens; + //context->currentTokens.insert(context->currentTokens.begin(), this->bos_token_id); + context->generationConfig = generationConfig; + context->tokens = LastTokensUnit(generationConfig.last_n); + dictLocker.unlock(); + return handleId; + } + + int MiniCpmModel::FetchResponseTokens(int handleId) { + dictLocker.lock(); + ResponseContext *context = responseContextDict.GetHandle(handleId); + if (context == nullptr) { + dictLocker.unlock(); + return -1; + } else { + while (true) { + if (context->resultTokenQueue.size() > 0) { + int ret = context->resultTokenQueue.front(); + context->resultTokenQueue.pop(); + dictLocker.unlock(); + return ret; + } else { + if (context->isEnding) { + responseContextDict.RemoveHandle(handleId); + dictLocker.unlock(); + return -1; + } + } + dictLocker.unlock(); + MySleep(0); + dictLocker.lock(); + } + } + } +} diff --git a/tools/fastllm_pytools/hf_model.py b/tools/fastllm_pytools/hf_model.py index e50e8cb3..49daa040 100644 --- a/tools/fastllm_pytools/hf_model.py +++ b/tools/fastllm_pytools/hf_model.py @@ -41,6 +41,8 @@ def create(model, modelInfo["bot_role"] = bot_role if (history_sep): modelInfo["history_sep"] = history_sep + if modelInfo["architectures"] == ["MiniCPMForCausalLM"]: + modelInfo["model_type"] = "minicpm" if (modelInfo["model_type"] == "baichuan"): if (hasattr(model, "model") and hasattr(model.model, "get_alibi_mask")): # Baichuan / Baichuan2 13B
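
The series above adds MiniCPM support end to end: the CMake source list, the C++ model class and its dispatch in model.cpp, and the architectures mapping in hf_model.py. A minimal sketch of how that path might be exercised from fastllm_pytools once this branch is built follows; the checkpoint id, dtype and prompt are illustrative assumptions and are not part of the patch.

    # Hypothetical usage sketch; assumes fastllm_pytools was built from this branch.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from fastllm_pytools import llm

    path = "openbmb/MiniCPM-2B-sft-bf16"   # illustrative checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    hf = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).float()

    # hf_model.create() (reached via llm.from_hf) now maps
    # architectures == ["MiniCPMForCausalLM"] to model_type "minicpm".
    model = llm.from_hf(hf, tokenizer, dtype="float16")
    print(model.response("Hello"))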