From 8c61f189ad91608c52b20c563e7a39695903451f Mon Sep 17 00:00:00 2001
From: siemon
Date: Wed, 18 Oct 2023 18:06:29 +0800
Subject: [PATCH] Update llama.cpp, applying some optimizations from
 chatglm.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/models/llama.cpp | 138 +++++++++++++++++++++++++++----------------
 1 file changed, 87 insertions(+), 51 deletions(-)

diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index 92565cbb..c3446225 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -91,6 +91,7 @@ namespace fastllm {
             alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi));
         }
 
+        int maxLen = inputIds.dims[1];
         Data hiddenStates;
         Data attenInput;
         Data q, k, v, qkv;
@@ -143,6 +144,14 @@ namespace fastllm {
             PermuteSelf(v, {1, 0, 2});
 
             Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second;
+            if (GetKVCacheInCPU()) {
+                pastKey.lockInCPU = true;
+                pastValue.lockInCPU = true;
+            } else {
+                pastKey.ToDevice(DataDevice::CUDA);
+                pastValue.ToDevice(DataDevice::CUDA);
+            }
+
             int unitLen = 64;
 #ifdef USE_CUDA
             unitLen = 128;
@@ -200,28 +209,34 @@ namespace fastllm {
             Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
             AddTo(hiddenStates, w2);
         }
-
-        RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
-        Data logits;
-        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
-        logits.ToDevice(DataDevice::CPU);
+        Data logits, topk;
+        Data tempHiddenStates;
+        Data *lastHiddenStates;
+        if (maxLen > 1) {
+            Split(hiddenStates, 1, maxLen - 1, maxLen, tempHiddenStates);
+            lastHiddenStates = &tempHiddenStates;
+        } else {
+            lastHiddenStates = &hiddenStates;
+        }
 
         int lastRet = -1;
-        if (generationConfig.output_logits && retLogits != nullptr) {
-            int size = logits.dims.back();
-            logits.ToDevice(DataDevice::CPU);
-            retLogits->resize(size);
-            memcpy((float*)retLogits->data(), ((float*)logits.cpuData) + (logits.dims[1] - 1) * size, size * logits.unitSize);
-        }
-        if (generationConfig.IsSimpleGreedy()) {
-            std::pair <float, int> ret = std::make_pair(-1e9, -1);
-            int base = logits.dims[1] - 1;
-            for (int i = 0; i < logits.dims.back(); i++) {
-                ret = max(ret, std::make_pair(((float*)logits.cpuData)[base * logits.dims.back() + i], i));
-            }
-            lastRet = ret.second;
-        } else if (!lastTokens.units.empty()) {
-            lastRet = LLMSampling(logits, logits.dims[1] - 1, generationConfig, lastTokens.units[0]);
+        {
+            auto &hiddenStates = *lastHiddenStates;
+            RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
+            Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
+            if (generationConfig.output_logits && retLogits != nullptr) {
+                int size = logits.dims.back();
+                logits.ToDevice(DataDevice::CPU);
+                retLogits->resize(size);
+                memcpy((float*)retLogits->data(), ((float*)logits.cpuData) + (logits.dims[1] - 1) * size, size * logits.unitSize);
+            }
+            if (generationConfig.IsSimpleGreedy()) {
+                TopK(logits, topk, 1);
+                topk.ToDevice(DataDevice::CPU);
+                lastRet = (int) (((float *) topk.cpuData)[0] + 1e-3);
+            } else if (!lastTokens.units.empty()) {
+                lastRet = LLMSampling(logits, logits.dims[1] - 1, generationConfig, lastTokens.units[0]);
+            }
         }
 
         return lastRet;
@@ -237,6 +252,7 @@ namespace fastllm {
             alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi));
         }
 
+        int maxLen = inputIds.dims[1];
         Data hiddenStates;
         Data attenInput;
         Data q, k, v, qkv;
@@ -290,6 +306,14 @@ namespace fastllm {
             v.Reshape(qkvSize);
 
             Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second;
+            if (GetKVCacheInCPU()) {
+                pastKey.lockInCPU = true;
+                pastValue.lockInCPU = true;
+            } else {
+                pastKey.ToDevice(DataDevice::CUDA);
+                pastValue.ToDevice(DataDevice::CUDA);
+            }
+
             int unitLen = 64;
 #ifdef USE_CUDA
             unitLen = 128;
@@ -351,25 +375,33 @@ namespace fastllm {
             AddTo(hiddenStates, w2);
         }
 
-        RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
-        Data logits;
-        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
-        logits.ToDevice(DataDevice::CPU);
+        Data logits, topk;
+        Data tempHiddenStates;
+        Data *lastHiddenStates;
+        if (maxLen > 1) {
+            Split(hiddenStates, 1, maxLen - 1, maxLen, tempHiddenStates);
+            lastHiddenStates = &tempHiddenStates;
+        } else {
+            lastHiddenStates = &hiddenStates;
+        }
 
         std::vector <int> lastRet;
-        if (generationConfig.IsSimpleGreedy()) {
-            for (int b = 0; b < batch; b++) {
-                int base = b * logits.dims[1] + logits.dims[1] - 1;
-                std::pair <float, int> ret = std::make_pair(-1e9, -1);
-                for (int i = 0; i < logits.dims.back(); i++) {
-                    ret = max(ret, std::make_pair(((float *) logits.cpuData)[base * logits.dims.back() + i], i));
+        {
+            auto &hiddenStates = *lastHiddenStates;
+            RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
+            Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
+            if (generationConfig.IsSimpleGreedy()) {
+                TopK(logits, topk, 1);
+                topk.ToDevice(DataDevice::CPU);
+                for (int b = 0; b < batch; b++) {
+                    int base = b;
+                    lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3));
+                }
+            } else {
+                for (int b = 0; b < batch; b++) {
+                    int base = b * logits.dims[1] + logits.dims[1] - 1;
+                    lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
                 }
-                lastRet.push_back(ret.second);
-            }
-        } else {
-            for (int b = 0; b < batch; b++) {
-                int base = b * logits.dims[1] + logits.dims[1] - 1;
-                lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
             }
         }
 
@@ -460,6 +492,14 @@ namespace fastllm {
                 v.Reshape(qkvSize);
 
                 Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                if (GetKVCacheInCPU()) {
+                    pastKey.lockInCPU = true;
+                    pastValue.lockInCPU = true;
+                } else {
+                    pastKey.ToDevice(DataDevice::CUDA);
+                    pastValue.ToDevice(DataDevice::CUDA);
+                }
+
                 int unitLen = 64;
 #ifdef USE_CUDA
                 unitLen = 128;
@@ -528,31 +568,27 @@ namespace fastllm {
             AddTo(hiddenStates, w2);
         }
 
+        Data logits, curLogit;
         RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
-        Data logits;
         Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
-        logits.ToDevice(DataDevice::CPU);
         std::vector <int> lastRet;
         int total = 0;
         for (int b = 0; b < batch; b++) {
+            Split(logits, 1, total + seqLens[b] - 1, total + seqLens[b], curLogit);
             if (generationConfigs[b].output_logits && retLogits != nullptr && (*retLogits)[b] != nullptr) {
-                int base = (total + seqLens[b] - 1);
-                (*retLogits)[b]->resize(logits.dims.back());
-                memcpy((float*)(*retLogits)[b]->data(), (float*)(logits.cpuData + base * logits.dims.back() * logits.unitSize), logits.dims.back() * logits.unitSize);
+                curLogit.ToDevice(DataDevice::CPU);
+                (*retLogits)[b]->resize(curLogit.Count(0));
+                memcpy((float*)(*retLogits)[b]->data(), (float*)curLogit.cpuData, curLogit.GetBytes());
             }
             if (generationConfigs[b].IsSimpleGreedy()) {
-                std::pair <float, int> ret = std::make_pair(-1e9, -1);
-                int base = (total + seqLens[b] - 1);
-                total += seqLens[b];
-                for (int i = 0; i < logits.dims.back(); i++) {
-                    ret = max(ret, std::make_pair(((float *) logits.cpuData)[base * logits.dims.back() + i], i));
-                }
-                lastRet.push_back(ret.second);
+                Data topk;
+                TopK(curLogit, topk, 1);
+                topk.ToDevice(DataDevice::CPU);
+                lastRet.push_back((int) (((float *) topk.cpuData)[0] + 1e-3));
             } else {
-                int base = (total + seqLens[b] - 1);
-                total += seqLens[b];
-                lastRet.push_back(LLMSampling(logits, base, generationConfigs[b], lastTokens.units[b]));
+                lastRet.push_back(LLMSampling(curLogit, 0, generationConfigs[b], lastTokens.units[b]));
            }
+            total += seqLens[b];
         }
         return lastRet;
     }
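
A note on the main decoding change above, for readers outside the fastllm
codebase. The Split(hiddenStates, 1, maxLen - 1, maxLen, ...) step exploits
the fact that only the last position's hidden state is needed to pick the
next token, so during prefill the lm_head projection runs on a [1, hidden]
slice instead of [maxLen, hidden]. Below is a minimal standalone sketch of
that idea in plain C++; the GreedyNextToken name, row-major layouts, and toy
dimensions are illustrative assumptions, not fastllm API:

#include <cstddef>
#include <cstdio>
#include <vector>

// Greedy next-token selection that touches only the LAST row of the hidden
// states: the vocabulary projection becomes a 1 x vocab GEMV rather than a
// seqLen x vocab GEMM, which is the saving the Split(...) call buys.
static int GreedyNextToken(const std::vector<float> &hidden,  // [seqLen * hiddenDim]
                           const std::vector<float> &lmHead,  // [vocab * hiddenDim], row-major
                           std::size_t seqLen, std::size_t hiddenDim, std::size_t vocab) {
    const float *last = hidden.data() + (seqLen - 1) * hiddenDim;  // the "Split" step
    int best = -1;
    float bestScore = -1e30f;
    for (std::size_t v = 0; v < vocab; v++) {
        float score = 0.0f;
        for (std::size_t h = 0; h < hiddenDim; h++) {
            score += last[h] * lmHead[v * hiddenDim + h];
        }
        if (score > bestScore) {
            bestScore = score;
            best = (int) v;
        }
    }
    return best;
}

int main() {
    // Toy shapes: 3 positions, hidden size 4, vocab 5; only row 2 matters.
    std::vector<float> hidden = {0, 0, 0, 0,
                                 0, 0, 0, 0,
                                 1, 0, 0, 0};
    std::vector<float> lmHead(5 * 4, 0.0f);
    lmHead[3 * 4 + 0] = 2.0f;  // token 3 scores highest against the last row
    printf("%d\n", GreedyNextToken(hidden, lmHead, 3, 4, 5));  // prints 3
    return 0;
}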
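The greedy branches also change how the argmax itself is done: instead of
copying the whole logits tensor to the CPU and scanning it, they call
TopK(logits, topk, 1) on-device and copy back only the small result tensor.
The readback indexing ([base * 2] in the batch path) suggests two floats per
row with the token id stored first, as a float; that layout is inferred from
this diff alone, not confirmed against the TopK implementation. The "+ 1e-3"
before each (int) cast guards against an id that drifted just below its
integer value on the way through float storage and arithmetic. A
self-contained illustration of that idiom:

#include <cstdio>

int main() {
    // A token id that survived float storage exactly, and the same id after
    // picking up a little rounding error.
    float exact   = 4099.0f;
    float drifted = 4098.9995f;  // nearest float is ~4098.99951

    // A bare (int) cast truncates toward zero, so the drifted id comes back
    // off by one; nudging by 1e-3 first recovers it. This is the
    // (int)(x + 1e-3) idiom used when reading ids from topk.cpuData.
    printf("exact:  %d\n", (int) exact);             // 4099
    printf("plain:  %d\n", (int) drifted);           // 4098
    printf("nudged: %d\n", (int) (drifted + 1e-3));  // 4099
    return 0;
}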