Commit 8e9ede2

Merge pull request ztxz16#351 from siemonchan/master
Update llama.cpp, applying some of the optimizations from chatglm.cpp
ztxz16 authored Oct 19, 2023
2 parents c4df37d + 8c61f18 commit 8e9ede2
Showing 1 changed file with 87 additions and 51 deletions.
138 changes: 87 additions & 51 deletions src/models/llama.cpp
@@ -91,6 +91,7 @@ namespace fastllm {
alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi));
}

int maxLen = inputIds.dims[1];
Data hiddenStates;
Data attenInput;
Data q, k, v, qkv;
@@ -143,6 +144,14 @@ namespace fastllm {
PermuteSelf(v, {1, 0, 2});

Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second;
if (GetKVCacheInCPU()) {
pastKey.lockInCPU = true;
pastValue.lockInCPU = true;
} else {
pastKey.ToDevice(DataDevice::CUDA);
pastValue.ToDevice(DataDevice::CUDA);
}

int unitLen = 64;
#ifdef USE_CUDA
unitLen = 128;
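The added block above picks where each layer's KV cache lives: with GetKVCacheInCPU() set, pastKey/pastValue are locked into host memory to save VRAM at the cost of extra transfers; otherwise both are moved to CUDA. The unitLen constant right below, 64 on CPU and 128 under USE_CUDA, is the chunk size the cache grows in, so appending one token per decode step does not reallocate on every call. A minimal sketch of that chunked-growth idea, using an illustrative Cache type and appendRows helper rather than fastllm's actual Data API:

#include <algorithm>
#include <vector>

// Illustrative stand-in for the KV-cache growth policy suggested by unitLen:
// capacity is expanded in unitLen-row chunks so per-token appends are cheap.
struct Cache {
    std::vector<float> buf;
    int rows = 0, cols = 0, capacityRows = 0;

    void appendRows(const float *src, int n, int width, int unitLen) {
        cols = width;
        if (rows + n > capacityRows) {
            // round the new capacity up to a multiple of unitLen
            capacityRows = ((rows + n + unitLen - 1) / unitLen) * unitLen;
            buf.resize((size_t) capacityRows * cols);
        }
        std::copy(src, src + (size_t) n * cols, buf.begin() + (size_t) rows * cols);
        rows += n;
    }
};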
@@ -200,28 +209,34 @@ namespace fastllm {
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
AddTo(hiddenStates, w2);
}

RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
Data logits;
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
logits.ToDevice(DataDevice::CPU);
Data logits, topk;
Data tempHiddenStates;
Data *lastHiddenStates;
if (maxLen > 1) {
Split(hiddenStates, 1, maxLen - 1, maxLen, tempHiddenStates);
lastHiddenStates = &tempHiddenStates;
} else {
lastHiddenStates = &hiddenStates;
}

int lastRet = -1;
if (generationConfig.output_logits && retLogits != nullptr) {
int size = logits.dims.back();
logits.ToDevice(DataDevice::CPU);
retLogits->resize(size);
memcpy((float*)retLogits->data(), ((float*)logits.cpuData) + (logits.dims[1] - 1) * size, size * logits.unitSize);
}
if (generationConfig.IsSimpleGreedy()) {
std::pair <float, int> ret = std::make_pair(-1e9, -1);
int base = logits.dims[1] - 1;
for (int i = 0; i < logits.dims.back(); i++) {
ret = max(ret, std::make_pair(((float*)logits.cpuData)[base * logits.dims.back() + i], i));
}
lastRet = ret.second;
} else if (!lastTokens.units.empty()) {
lastRet = LLMSampling(logits, logits.dims[1] - 1, generationConfig, lastTokens.units[0]);
{
auto &hiddenStates = *lastHiddenStates;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
if (generationConfig.output_logits && retLogits != nullptr) {
int size = logits.dims.back();
logits.ToDevice(DataDevice::CPU);
retLogits->resize(size);
memcpy((float*)retLogits->data(), ((float*)logits.cpuData) + (logits.dims[1] - 1) * size, size * logits.unitSize);
}
if (generationConfig.IsSimpleGreedy()) {
TopK(logits, topk, 1);
topk.ToDevice(DataDevice::CPU);
lastRet = (int) (((float *) topk.cpuData)[0] + 1e-3);
} else if (!lastTokens.units.empty()) {
lastRet = LLMSampling(logits, logits.dims[1] - 1, generationConfig, lastTokens.units[0]);
}
}

return lastRet;
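The rewritten tail of Forward() applies two of the chatglm.cpp optimizations named in the commit message: only the last token's hidden state (Split at maxLen - 1) goes through the final RMSNorm and lm_head Linear, since earlier prompt positions are never sampled, and the greedy path now takes TopK(logits, topk, 1) on the device instead of copying the full vocabulary-sized logits to the CPU and scanning them in a loop. A rough sketch of why slicing first is cheaper, in plain C++ with illustrative names (greedyLastToken is not fastllm API):

#include <cstddef>
#include <vector>

// Compute lm_head logits for the final position only, then take the argmax.
// hidden: [seqLen, hiddenSize] activations; W: [vocab, hiddenSize] lm_head.
int greedyLastToken(const std::vector<float> &hidden, int seqLen, int hiddenSize,
                    const std::vector<float> &W, int vocab) {
    // equivalent of Split(hiddenStates, 1, maxLen - 1, maxLen, ...)
    const float *last = hidden.data() + (size_t) (seqLen - 1) * hiddenSize;
    int best = 0;
    float bestScore = -1e30f;
    for (int t = 0; t < vocab; t++) {          // one lm_head row per token id
        float dot = 0.0f;
        for (int h = 0; h < hiddenSize; h++) {
            dot += W[(size_t) t * hiddenSize + h] * last[h];
        }
        if (dot > bestScore) { bestScore = dot; best = t; }  // TopK(..., 1)
    }
    return best;  // vocab * hiddenSize work instead of seqLen * vocab * hiddenSize
}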
@@ -237,6 +252,7 @@ namespace fastllm {
alibiData.CopyFrom(Data(DataType::FLOAT32, {(int) alibi.size()}, alibi));
}

int maxLen = inputIds.dims[1];
Data hiddenStates;
Data attenInput;
Data q, k, v, qkv;
@@ -290,6 +306,14 @@ namespace fastllm {
v.Reshape(qkvSize);

Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second;
if (GetKVCacheInCPU()) {
pastKey.lockInCPU = true;
pastValue.lockInCPU = true;
} else {
pastKey.ToDevice(DataDevice::CUDA);
pastValue.ToDevice(DataDevice::CUDA);
}

int unitLen = 64;
#ifdef USE_CUDA
unitLen = 128;
@@ -351,25 +375,33 @@ namespace fastllm {
AddTo(hiddenStates, w2);
}

RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
Data logits;
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
logits.ToDevice(DataDevice::CPU);
Data logits, topk;
Data tempHiddenStates;
Data *lastHiddenStates;
if (maxLen > 1) {
Split(hiddenStates, 1, maxLen - 1, maxLen, tempHiddenStates);
lastHiddenStates = &tempHiddenStates;
} else {
lastHiddenStates = &hiddenStates;
}

std::vector <int> lastRet;
if (generationConfig.IsSimpleGreedy()) {
for (int b = 0; b < batch; b++) {
int base = b * logits.dims[1] + logits.dims[1] - 1;
std::pair <float, int> ret = std::make_pair(-1e9, -1);
for (int i = 0; i < logits.dims.back(); i++) {
ret = max(ret, std::make_pair(((float *) logits.cpuData)[base * logits.dims.back() + i], i));
{
auto &hiddenStates = *lastHiddenStates;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
if (generationConfig.IsSimpleGreedy()) {
TopK(logits, topk, 1);
topk.ToDevice(DataDevice::CPU);
for (int b = 0; b < batch; b++) {
int base = b;
lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3));
}
} else {
for (int b = 0; b < batch; b++) {
int base = b * logits.dims[1] + logits.dims[1] - 1;
lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
}
lastRet.push_back(ret.second);
}
} else {
for (int b = 0; b < batch; b++) {
int base = b * logits.dims[1] + logits.dims[1] - 1;
lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
}
}
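Note the indexing in this batched greedy path: topk.cpuData is read with a stride of two floats per batch row, which suggests TopK(logits, topk, 1) packs an (index, value) pair per row with the token index stored first as a float; the + 1e-3 then guards the float-to-int truncation against round-off. A hedged sketch of reading such a layout (the pair layout is inferred from this diff, not from fastllm documentation):

#include <vector>

// Assumed layout for k = 1: one {index, value} float pair per batch row,
// flattened as [idx0, val0, idx1, val1, ...].
std::vector<int> readTop1Indices(const float *topkData, int batch) {
    std::vector<int> tokens;
    tokens.reserve(batch);
    for (int b = 0; b < batch; b++) {
        // + 1e-3f so an index stored as e.g. 41.9999f still truncates to 42
        tokens.push_back((int) (topkData[b * 2] + 1e-3f));
    }
    return tokens;
}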

@@ -460,6 +492,14 @@ namespace fastllm {
v.Reshape(qkvSize);

Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second;
if (GetKVCacheInCPU()) {
pastKey.lockInCPU = true;
pastValue.lockInCPU = true;
} else {
pastKey.ToDevice(DataDevice::CUDA);
pastValue.ToDevice(DataDevice::CUDA);
}

int unitLen = 64;
#ifdef USE_CUDA
unitLen = 128;
@@ -528,31 +568,27 @@ namespace fastllm {
AddTo(hiddenStates, w2);
}

Data logits, curLogit;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
Data logits;
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
logits.ToDevice(DataDevice::CPU);
std::vector <int> lastRet;
int total = 0;
for (int b = 0; b < batch; b++) {
Split(logits, 1, total + seqLens[b] - 1, total + seqLens[b], curLogit);
if (generationConfigs[b].output_logits && retLogits != nullptr && (*retLogits)[b] != nullptr) {
int base = (total + seqLens[b] - 1);
(*retLogits)[b]->resize(logits.dims.back());
memcpy((float*)(*retLogits)[b]->data(), (float*)(logits.cpuData + base * logits.dims.back() * logits.unitSize), logits.dims.back() * logits.unitSize);
curLogit.ToDevice(DataDevice::CPU);
(*retLogits)[b]->resize(curLogit.Count(0));
memcpy((float*)(*retLogits)[b]->data(), (float*)curLogit.cpuData, curLogit.GetBytes());
}
if (generationConfigs[b].IsSimpleGreedy()) {
std::pair<float, int> ret = std::make_pair(-1e9, -1);
int base = (total + seqLens[b] - 1);
total += seqLens[b];
for (int i = 0; i < logits.dims.back(); i++) {
ret = max(ret, std::make_pair(((float *) logits.cpuData)[base * logits.dims.back() + i], i));
}
lastRet.push_back(ret.second);
Data topk;
TopK(curLogit, topk, 1);
topk.ToDevice(DataDevice::CPU);
lastRet.push_back((int) (((float *) topk.cpuData)[0] + 1e-3));
} else {
int base = (total + seqLens[b] - 1);
total += seqLens[b];
lastRet.push_back(LLMSampling(logits, base, generationConfigs[b], lastTokens.units[b]));
lastRet.push_back(LLMSampling(curLogit, 0, generationConfigs[b], lastTokens.units[b]));
}
total += seqLens[b];
}
return lastRet;
}
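In this variable-length batch path the sequences are packed back to back along the time axis, so sequence b's final position is row total + seqLens[b] - 1 of the packed logits; Split() carves that single row out into curLogit, which is why both the retLogits copy and LLMSampling now operate on a [1, vocab] slice with base index 0. A small sketch of the offset bookkeeping (lastRowOffsets is an illustrative helper, not fastllm API):

#include <vector>

// Row index of each sequence's last token in a ragged, packed batch --
// the only rows whose logits are actually sampled.
std::vector<int> lastRowOffsets(const std::vector<int> &seqLens) {
    std::vector<int> offsets;
    offsets.reserve(seqLens.size());
    int total = 0;  // running start row of the current sequence
    for (int len : seqLens) {
        offsets.push_back(total + len - 1);
        total += len;
    }
    return offsets;
}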
