tfacc: add some fused operators
黄宇扬 committed Apr 23, 2024
1 parent faf71da commit 1407747
Showing 6 changed files with 108 additions and 119 deletions.
3 changes: 2 additions & 1 deletion include/devices/tfacc/fastllm-tfacc.h
@@ -36,7 +36,8 @@ namespace fastllm {
void RunTfaccLinearU(int n, int m, int k, int group, int groupCnt,
fastllm::Data *weight, fastllm::Data *bias,
std::vector <LowBitConfig> *inputConfigs,
uint8_t *uinput, float *output);
uint8_t *uinput, float *output,
LinearExType exType);

void AppendKVCache(long long uid, Data *content);

8 changes: 8 additions & 0 deletions include/fastllm.h
@@ -483,6 +483,14 @@ namespace fastllm {

void Linear(Data &input, Data &weight, const Data &bias, Data &output);

enum LinearExType {
ExTypeNone = 0,
ExSwiglu = 1
};

void LinearEx(Data &input, Data &weight, const Data &bias, Data &output,
LinearExType exType); // extended Linear that can fuse a follow-up op

void Split(const Data &input, int axis, int start, int end, Data &output);

void Cat(const Data &input0, const Data &input1, int axis, Data &output);
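
Note on the new ExSwiglu mode: the fused kernel applies SwiGLU to the raw matmul result, so the tensor returned by LinearEx has half as many columns as weight.dims[0]. A minimal scalar sketch of the expected math follows, assuming the same [gate | value] split of the last dimension as the standalone Swiglu operator; SwigluReference is a hypothetical helper, not part of fastllm.

#include <cmath>
#include <vector>

// Hypothetical reference: given one row of the raw k-column matmul output
// (k = weight.dims[0], assumed even), return the fused k/2-column result.
std::vector<float> SwigluReference(const std::vector<float> &row) {
    int half = (int) row.size() / 2;
    std::vector<float> out(half);
    for (int j = 0; j < half; j++) {
        float gate = row[j];          // first half of the columns
        float value = row[half + j];  // second half of the columns
        float silu = gate / (1.0f + std::exp(-gate));  // silu(x) = x * sigmoid(x)
        out[j] = silu * value;
    }
    return out;
}
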
14 changes: 12 additions & 2 deletions src/devices/tfacc/fastllm-tfacc.cpp
@@ -203,8 +203,13 @@ namespace fastllm {
void TfaccClient::RunTfaccLinearU(int n, int m, int k, int group, int groupCnt,
fastllm::Data *weight, fastllm::Data *bias,
std::vector <LowBitConfig> *inputConfigs,
uint8_t *uinput, float *output) {
RegisterFastllmData(weight, "linear");
uint8_t *uinput, float *output,
LinearExType exType) {
std::string linearType = "linear";
if (exType == LinearExType::ExSwiglu) {
linearType = "linearSwiglu";
}
RegisterFastllmData(weight, linearType);
RegisterFastllmData(bias, "bias");

int opType = ComputeTaskType::LinearInt4NoZero;
@@ -232,6 +237,7 @@ namespace fastllm {
((int32_t*)buf)[4] = groupCnt;
((int32_t*)buf)[5] = weight->name.size();
((int32_t*)buf)[6] = biasName.size();
((int32_t*)buf)[7] = exType;

volatile uint8_t *cur = (uint8_t*)buf + 10 * sizeof(int32_t);
for (int i = 0; i < curN * group; i++) {
@@ -247,6 +253,10 @@

this->Launch(opType);
this->Wait();

if (exType == LinearExType::ExSwiglu) {
k /= 2;
}
memcpy(((uint8_t*) output) + baseN * k * sizeof(int32_t),
(uint8_t*) result,
curN * k * sizeof(int32_t));
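
Two things change on the RPC path: the request header gains an int32 slot (index 7) carrying exType, and on the reply side k is halved before the fused output is copied back. A small sketch of that reply-side bookkeeping, assuming result holds curN contiguous rows of floats (CopyLinearResult is a hypothetical helper; sizeof(int32_t) in the hunk above equals sizeof(float)):

#include <cstring>

// Hypothetical mirror of the memcpy above: the device returns curN rows of
// outK floats, where outK is k for a plain linear and k/2 when SwiGLU was
// fused on the device side.
void CopyLinearResult(float *output, const float *result,
                      int baseN, int curN, int k, bool fusedSwiglu) {
    int outK = fusedSwiglu ? k / 2 : k;  // fused kernel halves the output width
    std::memcpy(output + (size_t) baseN * outK,
                result,
                (size_t) curN * outK * sizeof(float));
}
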
196 changes: 81 additions & 115 deletions src/devices/tfacc/tfaccdevice.cpp
@@ -22,6 +22,59 @@
#include "utils.h"

namespace fastllm {
void GetArrayMinMax(float *a, int len, float &minValue, float &maxValue) {
int j = 0;
minValue = 1e100;
maxValue = -1e100;
#ifdef __aarch64__
float32x4_t mins = vdupq_n_f32(1e100);
float32x4_t maxs = vdupq_n_f32(-1e100);
for (; j + 3 < len; j += 4) {
float32x4_t v = vld1q_f32(a + j);
mins = vminq_f32(mins, v);
maxs = vmaxq_f32(maxs, v);
}
for (int l = 0; l < 4; l++) {
minValue = std::min(minValue, mins[l]);
maxValue = std::max(maxValue, maxs[l]);
}
#endif
for (; j < len; j++) {
minValue = std::min(minValue, a[j]);
maxValue = std::max(maxValue, a[j]);
}
}

void QuantizationAll(float *fValue, uint8_t *uValue, int len, LowBitConfig *config) {
float scale = config->scale;
float zeroPoint = config->zeroPoint;
int j = 0;
#ifdef __aarch64__
float32x4_t scales = vdupq_n_f32(scale);
float32x4_t zeros = vdupq_n_f32(zeroPoint + 0.5);
int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
for (; j + 7 < len; j += 8) {
float32x4_t fin1 = vld1q_f32(fValue + j);
float32x4_t fin2 = vld1q_f32(fValue + j + 4);
fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
int32x4_t out1 = vcvtq_s32_f32(fin1);
int32x4_t out2 = vcvtq_s32_f32(fin2);
out1 = vmaxq_s32(out1, minds);
out1 = vminq_s32(out1, maxds);
out2 = vmaxq_s32(out2, minds);
out2 = vminq_s32(out2, maxds);
uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
uint8x8_t out = vmovn_u16(out3);
vst1_u8(uValue + j, out);
}
#endif
for (; j < len; j++) {
uValue[j] = (uint8_t) (std::min(255., (double) std::max(fValue[j] / scale + zeroPoint + 0.5, 0.0)));
}
}

static TfaccClient tfaccClient;

TfaccDevice::TfaccDevice() {
@@ -70,12 +123,19 @@ namespace fastllm {
std::vector <int> dims = input.dims;
dims.back() = weight.dims[0];

if (intParams.find("exType") != intParams.end()) {
LinearExType type = (LinearExType)intParams.find("exType")->second;
if (type == LinearExType::ExSwiglu) {
dims.back() /= 2;
}
}

output.dataType = input.dataType;
output.Resize(dims);
}

void TfaccLinearOp::Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams) {
//auto st = std::chrono::system_clock::now();
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
Data &weight = *(datas.find("weight")->second);
@@ -85,117 +145,44 @@ namespace fastllm {
int n = input.Count(0) / input.dims.back();
int m = input.dims.back();
int k = output.dims.back();
LinearExType exType = LinearExType::ExTypeNone;
if (intParams.find("exType") != intParams.end()) {
exType = (LinearExType)intParams.find("exType")->second;
if (exType == LinearExType::ExSwiglu) {
k *= 2;
}
}

if (input.dataType == DataType::FLOAT32 && output.dataType == DataType::FLOAT32) {
if (weight.dataType == DataType::FLOAT32) {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
} else if (weight.dataType == DataType::INT4 ||
weight.dataType == DataType::INT4_NOZERO ||
weight.dataType == DataType::INT4_GROUP ||
weight.dataType == DataType::INT8) {
float *inputData = (float *) input.cpuData;
uint8_t *weightData = (uint8_t *) weight.cpuData;
float *outputData = (float *) output.cpuData;
float *biasData = bias.dims.size() > 0 ? (float *) bias.cpuData : nullptr;
weight.CalcWeightSum();

std::vector<LowBitConfig> inputConfigs;
for (int i = 0; i < n; i++) {
float minValue = 1e9, maxValue = -1e9;
int j = 0;
#ifdef __aarch64__
float32x4_t mins = vdupq_n_f32(1e100);
float32x4_t maxs = vdupq_n_f32(-1e100);
for (; j + 3 < m; j += 4) {
float32x4_t v = vld1q_f32(inputData + i * m + j);
mins = vminq_f32(mins, v);
maxs = vmaxq_f32(maxs, v);
}
for (int l = 0; l < 4; l++) {
minValue = std::min(minValue, mins[l]);
maxValue = std::max(maxValue, maxs[l]);
}
#endif
for (; j < m; j++) {
minValue = std::min(minValue, inputData[i * m + j]);
maxValue = std::max(maxValue, inputData[i * m + j]);
}
inputConfigs.push_back(LowBitConfig(minValue, maxValue, 8, 0));
}
std::vector<uint8_t> uinput;
uinput.resize(n * m);

for (int i = 0; i < n; i++) {
float scale = inputConfigs[i].scale;
float zeroPoint = inputConfigs[i].zeroPoint;
float *cur = inputData + i * m;
uint8_t *u = uinput.data() + i * m;
int j = 0;
#ifdef __aarch64__
float32x4_t scales = vdupq_n_f32(scale);
float32x4_t zeros = vdupq_n_f32(zeroPoint + 0.5);
int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
for (; j + 7 < m; j += 8) {
float32x4_t fin1 = vld1q_f32(cur + j);
float32x4_t fin2 = vld1q_f32(cur + j + 4);
fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
int32x4_t out1 = vcvtq_s32_f32(fin1);
int32x4_t out2 = vcvtq_s32_f32(fin2);
out1 = vmaxq_s32(out1, minds);
out1 = vminq_s32(out1, maxds);
out2 = vmaxq_s32(out2, minds);
out2 = vminq_s32(out2, maxds);
uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
uint8x8_t out = vmovn_u16(out3);
vst1_u8(u + j, out);
}
#endif
for (; j < m; j++) {
u[j] = (uint8_t) (std::min(255., (double) std::max(cur[j] / scale + zeroPoint + 0.5, 0.0)));
}
}

if (weight.dataType == DataType::INT4) {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
} else if (weight.dataType == DataType::INT8 || weight.dataType == DataType::INT4_NOZERO) {
tfaccClient.RunTfaccLinearU(n, m, k, 1, 1, &weight, &bias, &inputConfigs, uinput.data(), outputData);
}
} else if (weight.dataType == DataType::INT4_GROUP) {
float *inputData = (float *) input.cpuData;
uint8_t *weightData = (uint8_t *) weight.cpuData;
float *outputData = (float *) output.cpuData;
float *biasData = bias.dims.size() > 0 ? (float *) bias.cpuData : nullptr;
int group = weight.group, groupCnt = weight.groupCnt;
weight.CalcWeightSum();
if (weight.dataType != DataType::INT4_GROUP) {
group = 1;
groupCnt = m;
}

std::vector<LowBitConfig> inputConfigs;
for (int i = 0; i < n; i++) {
for (int g = 0; g < group; g++) {
int st = g * groupCnt;
int end = std::min(m, (g + 1) * groupCnt);
float minValue = 1e9, maxValue = -1e9;
int j = st;
#ifdef __aarch64__
float32x4_t mins = vdupq_n_f32(1e100);
float32x4_t maxs = vdupq_n_f32(-1e100);
for (; j + 3 < end; j += 4) {
float32x4_t v = vld1q_f32(inputData + i * m + j);
mins = vminq_f32(mins, v);
maxs = vmaxq_f32(maxs, v);
}
for (int l = 0; l < 4; l++) {
minValue = std::min(minValue, mins[l]);
maxValue = std::max(maxValue, maxs[l]);
}
#endif
for (; j < end; j++) {
minValue = std::min(minValue, inputData[i * m + j]);
maxValue = std::max(maxValue, inputData[i * m + j]);
}
GetArrayMinMax(inputData + i * m + st, end - st, minValue, maxValue);
inputConfigs.push_back(LowBitConfig(minValue, maxValue, 8, 0));
}
}

std::vector<uint8_t> uinput;
uinput.resize(n * m);
for (int i = 0; i < n; i++) {
@@ -204,36 +191,15 @@ namespace fastllm {
for (int g = 0; g < group; g++) {
int st = g * groupCnt;
int end = std::min(m, (g + 1) * groupCnt);
int j = st;
auto &config = inputConfigs[i * group + g];
#ifdef __aarch64__
float32x4_t scales = vdupq_n_f32(config.scale);
float32x4_t zeros = vdupq_n_f32(config.zeroPoint + 0.5);
int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
for (; j + 7 < end; j += 8) {
float32x4_t fin1 = vld1q_f32(cur + j);
float32x4_t fin2 = vld1q_f32(cur + j + 4);
fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
int32x4_t out1 = vcvtq_s32_f32(fin1);
int32x4_t out2 = vcvtq_s32_f32(fin2);
out1 = vmaxq_s32(out1, minds);
out1 = vminq_s32(out1, maxds);
out2 = vmaxq_s32(out2, minds);
out2 = vminq_s32(out2, maxds);
uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
uint8x8_t out = vmovn_u16(out3);
vst1_u8(u + j, out);
}
#endif
for (; j < end; j++) {
uinput[i * m + j] = config.quantization(inputData[i * m + j]);
}
QuantizationAll(cur + st, u + st, end - st, &inputConfigs[i * group + g]);
}
}

tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), outputData);
if (weight.dataType == DataType::INT4) {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
} else if (weight.dataType == DataType::INT8 || weight.dataType == DataType::INT4_NOZERO || weight.dataType == DataType::INT4_GROUP) {
tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), outputData, exType);
}
}
} else if (input.dataType == DataType::FLOAT16 && output.dataType == DataType::FLOAT16) {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
6 changes: 5 additions & 1 deletion src/models/chatglm.cpp
@@ -293,13 +293,17 @@ namespace fastllm {
// 1.4 MLP
std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h";
std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h";
#ifdef USE_TFACC
LinearEx(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle2, LinearExType::ExSwiglu);
#else
Linear(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle);
Swiglu(middle, middle2);
#endif
Linear(middle2, weight[fcOutKeyName + ".weight"], weight[fcOutKeyName + ".bias"], hiddenStates);
AddTo(hiddenStates, temp);
}
}

Data logits, topk;
Data tempHiddenStates;
Data *lastHiddenStates;
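
With USE_TFACC defined, the ChatGLM MLP now asks the device to fuse dense_h_to_4h with SwiGLU instead of running Linear followed by Swiglu on the host, which avoids materializing the full 2x-wide intermediate. Both branches are expected to fill middle2 with the same values; a hypothetical equivalence check is sketched below (buffer names and the tolerance are illustrative, not part of fastllm).

#include <cmath>

// Hypothetical check that the fused and unfused MLP paths agree.
// fused / reference are assumed to be float buffers of equal length holding
// the outputs of LinearEx(..., ExSwiglu) and Linear + Swiglu respectively.
bool SwigluPathsMatch(const float *fused, const float *reference,
                      int count, float tol = 1e-3f) {
    for (int i = 0; i < count; i++) {
        if (std::fabs(fused[i] - reference[i]) > tol) {
            return false;
        }
    }
    return true;
}
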
Binary file modified third_party/tfacc/server
Binary file not shown.
