Skip to content

Commit

Permalink
aliveThreadPool移动到cpuDevice里,部分batch操作改用多线程
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed May 8, 2024
1 parent 36c463e commit 0d9765a
Show file tree
Hide file tree
Showing 8 changed files with 466 additions and 205 deletions.
174 changes: 174 additions & 0 deletions include/devices/cpu/alivethreadpool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
//
// Created by huangyuyang on 11/4/24.
//

#ifndef ALIVETHREAD_H
#define ALIVETHREAD_H

#include <unistd.h>

#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>

namespace fastllm {
// Abstract unit of work that can be handed to a worker thread.
// Concrete operations derive from this type and implement Run().
struct MultiThreadBaseOp {
    // Virtual so that destroying a derived operation through a
    // MultiThreadBaseOp* runs the derived destructor instead of being
    // undefined behaviour.
    virtual ~MultiThreadBaseOp() = default;

    // Performs the operation. AliveThreadLoop calls this on its worker thread.
    virtual void Run() = 0;
};

struct AliveThreadTask {
int signal;
MultiThreadBaseOp *op;

AliveThreadTask () {
signal = 0;
op = nullptr;
}
};

// Body of one long-lived worker thread.
//
// A producer publishes an operation with PushOp() and blocks on Wait();
// the worker (operator()) spins on task->signal, runs the operation and
// resets the signal to 0 when it is done.
struct AliveThreadLoop {
    int id;
    AliveThreadTask realTask;
    // Points at realTask of the object this loop was constructed as. The
    // implicit copy constructor copies the pointer as-is, so a copy of this
    // functor (std::thread stores one) still talks to the original mailbox.
    volatile AliveThreadTask *task;

    AliveThreadLoop(int id) {
        this->id = id;
        this->task = &this->realTask;
    }

    // Full memory fence. Replaces the AArch64-only `dmb ish` inline assembly
    // so the header also builds on other architectures.
    static void Barrier() {
        std::atomic_thread_fence(std::memory_order_seq_cst);
    }

    // Worker entry point: never returns.
    void operator()() {
        // steady_clock: idle time must not be affected by wall-clock jumps.
        auto lastRunTime = std::chrono::steady_clock::now();
        while (true) {
            Barrier();
            if (task->signal == 1) {
                // Read `op` (and whatever it points to) only after the
                // signal has been observed.
                Barrier();
                task->op->Run();
                // Make everything Run() wrote visible before the waiter can
                // observe signal == 0.
                Barrier();
                task->signal = 0;
                Barrier();
                lastRunTime = std::chrono::steady_clock::now();
            }

            // After more than 3 seconds without work, call sleep(0) on every
            // iteration instead of spinning flat out.
            const std::chrono::duration<double> idleSeconds =
                std::chrono::steady_clock::now() - lastRunTime;
            if (idleSeconds.count() > 3) {
                sleep(0);
            }
        }
    }

    // Publishes `op` to the worker. The caller keeps ownership of `op` and
    // must keep it alive until Wait() has returned.
    void PushOp(MultiThreadBaseOp *op) {
        this->task->op = op;
        Barrier();  // `op` must be visible before the signal is raised
        this->task->signal = 1;
        Barrier();
    }

    // Busy-waits until the worker has reset the signal to 0.
    void Wait() {
        while (task->signal != 0) {
        }
        // Do not let later reads of the operation's results be ordered
        // before the signal check above.
        Barrier();
    }
};

struct AliveThreadPool {
std::vector <AliveThreadLoop*> loops;
std::vector <std::thread*> threads;

AliveThreadPool (int threadNum) {
for (int i = 0; i < threadNum; i++) {
this->loops.push_back(new AliveThreadLoop(i));
this->threads.push_back(new std::thread(*(this->loops[i])));
}
}

void PushOp(int tid, MultiThreadBaseOp *op) {
this->loops[tid]->PushOp(op);
}

void Wait(int tid) {
this->loops[tid]->Wait();
}

void Shutdown() {
/// TODO: shutdown
}
};

struct MultiThreadMemcpyOp : MultiThreadBaseOp {
uint8_t *input, *output;
int len;

MultiThreadMemcpyOp (uint8_t *output, uint8_t *input, int len) : input(input), output(output), len(len) {}

void Run() {
memcpy(output, input, len);
}
};

static void RunMultiThreadMemcpy(uint8_t *output, uint8_t *input, int len, AliveThreadPool *pool) {
if (len < 256 * 1024) {
memcpy(output, input, len);
return;
}
int threadNum = pool->threads.size();
int per = len / pool->threads.size();
int cur = 0;
std::vector<fastllm::MultiThreadMemcpyOp*> ops;
for (int i = 0; i < threadNum; i++) {
int end = (i == threadNum - 1 ? len : cur + per + (cur + per * (threadNum - i) < len));
ops.push_back(new MultiThreadMemcpyOp(output + cur, input + cur, end - cur));
cur = end;
}
for (int i = 0; i < threadNum; i++) {
pool->PushOp(i, ops[i]);
}
for (int i = 0; i < threadNum; i++) {
pool->Wait(i);
delete ops[i];
}
}

// [n, m, k] -> [m, n, k]: transposes the two leading dimensions, moving
// data in units of k bytes. Each op handles the flattened (i, j) indices in
// the half-open range [st, end).
struct MultiThreadTransposeByLineOp : MultiThreadBaseOp {
    uint8_t *input, *output;
    int n, m, k, st, end;

    MultiThreadTransposeByLineOp (uint8_t *output, uint8_t *input, int n, int m, int k, int st, int end) :
        input(input), output(output), n(n), m(m), k(k), st(st), end(end) {}

    // Copies unit (i, j) of the input to position (j, i) of the output for
    // every flattened index x = i * m + j in [st, end).
    void Run() {
        // Byte offsets are computed in size_t: (j * n + i) * k exceeds the
        // range of int once the buffer is larger than 2 GiB.
        const size_t rows = n, cols = m, unit = k;
        for (int x = st; x < end; x++) {
            const size_t i = x / m, j = x % m;
            memcpy(output + (j * rows + i) * unit, input + (i * cols + j) * unit, unit);
        }
    }
};

// [n, m, k] -> [m, n, k], 以k个元素为单位做转置
static void RunMultiThreadTransposeByLine(uint8_t *output, uint8_t *input, int n, int m, int k, AliveThreadPool *pool) {
/*if (len < 256 * 1024) {
memcpy(output, input, len);
return;
}*/
int threadNum = pool->threads.size();
int len = n * m;
int per = len / pool->threads.size();
int cur = 0;
std::vector<fastllm::MultiThreadTransposeByLineOp*> ops;
for (int i = 0; i < threadNum; i++) {
int end = (i == threadNum - 1 ? len : cur + per + (cur + per * (threadNum - i) < len));
ops.push_back(new MultiThreadTransposeByLineOp(output, input, n, m, k, cur, end));
cur = end;
}
for (int i = 0; i < threadNum; i++) {
pool->PushOp(i, ops[i]);
}
for (int i = 0; i < threadNum; i++) {
pool->Wait(i);
delete ops[i];
}
}
}

#endif
1 change: 1 addition & 0 deletions include/devices/cpu/cpudevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "device.h"
#include "cputhreadpool.h"
#include "alivethreadpool.h"

namespace fastllm {
class CpuDevice : BaseDevice {
Expand Down
94 changes: 0 additions & 94 deletions include/devices/tfacc/alivethreadpool.h

This file was deleted.

2 changes: 2 additions & 0 deletions include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <locale>
#include <codecvt>
#include "devices/cpu/cputhreadpool.h"
#include "devices/cpu/alivethreadpool.h"

#ifdef USE_SENTENCEPIECE
#include <sentencepiece_processor.h>
Expand All @@ -36,6 +37,7 @@ namespace fastllm {
int GetThreads();
bool GetKVCacheInCPU();
ThreadPool *GetPool();
AliveThreadPool *GetAlivePool();

struct GenerationConfig {
int output_token_limit = -1; // 最多输出多少, <= 0代表无限制
Expand Down
Loading

0 comments on commit 0d9765a

Please sign in to comment.