Skip to content

Commit

Permalink
Merge pull request ztxz16#364 from fluxlinkage/master
Browse files: browse the repository at this point in the history
修复GLM模型内存溢出问题 (Fix a memory overflow in the GLM model)
  • Loading branch information
ztxz16 authored Nov 8, 2023
2 parents 737a58a + 37735b7 commit c0c9630
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
10 changes: 9 additions & 1 deletion example/apiserver/apiserver.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -252,7 +252,11 @@ struct WorkQueue {
while (true) {
std::unique_lock <std::mutex> lock(ts->locker);
if (ts->activateQueryNumber >= ts->maxActivateQueryNumber) {
+#ifdef WIN32
+Sleep(0);
+#else
sleep(0);
+#endif
continue;
}
if (ts->q.empty()) {
Expand Down Expand Up @@ -435,7 +439,11 @@ int main(int argc, char** argv) {
buff[size] = 0;

while (workQueue.q.size() > workQueue.maxActivateQueryNumber) {
-sleep(0);
+#ifdef WIN32
+Sleep(0);
+#else
+sleep(0);
+#endif
}
workQueue.Push(buff, client);
}
Expand Down
8 changes: 3 additions & 5 deletions src/models/glm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,7 +78,7 @@ namespace fastllm {
Data mlpOutput;
Data middle, middle2;
Data toSave;
-Data mem2;
+Data mem2,mem3;
std::vector<int> lastRet;
// GLMBlock
std::string weightPre, weightMiddle;
Expand Down Expand Up @@ -131,8 +131,8 @@ namespace fastllm {
Split(qkv, -1, per * 2, per * 3, v);
}else{
LayerNorm(mem, weight[inputLNWeightName], weight[inputLNBiasName], -1, mem2);
-CatDirect(mem2,attenInput,1);
-Linear(mem2, weight[qkvWeightName], weight[qkvBiasName], qkv);
+Cat(mem2,attenInput,1,mem3);
+Linear(mem3, weight[qkvWeightName], weight[qkvBiasName], qkv);
int per = qkv.dims.back() / 3;
Split(qkv, -1, 0, per, q0);
Split(qkv, -1, per, per * 2, k);
Expand Down Expand Up @@ -296,10 +296,8 @@ namespace fastllm {
for(unsigned int i=0;i<hexString.length();i+=2){
decoded.push_back(std::stoi(hexString.substr(i,2),nullptr,16));
}
-printf("%lu\n",decoded.length());
weight.tokenizer.spProcessor=std::make_unique<sentencepiece::SentencePieceProcessor>();
weight.tokenizer.spProcessor->LoadFromSerializedProto(decoded);
-printf("GetPieceSize=%d\n",weight.tokenizer.spProcessor->GetPieceSize());
}
}
#endif
Expand Down

0 comments on commit c0c9630

Please sign in to comment.