Skip to content

Commit

Permalink
int4g neon优化
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Apr 22, 2024
1 parent 6eb96e8 commit faf71da
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions src/devices/tfacc/tfaccdevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,21 @@ namespace fastllm {
int st = g * groupCnt;
int end = std::min(m, (g + 1) * groupCnt);
float minValue = 1e9, maxValue = -1e9;
for (int j = st; j < end; j++) {
// Scan one quantization group, columns [st, end) of row i, and fold its
// smallest / largest element into minValue / maxValue.
// j is shared by the SIMD loop and the scalar tail so the tail resumes
// exactly where the vector loop stopped.
int j = st;
#ifdef __aarch64__
// Seed the lane accumulators from the scalar sentinels instead of the double
// literal 1e100: that value is far outside the range of float, so narrowing
// it is an out-of-range conversion. Because every lane is folded back into
// minValue / maxValue below, seeding with them cannot change the result.
float32x4_t mins = vdupq_n_f32(minValue);
float32x4_t maxs = vdupq_n_f32(maxValue);
for (; j + 3 < end; j += 4) {
    float32x4_t v = vld1q_f32(inputData + i * m + j);
    mins = vminq_f32(mins, v);
    maxs = vmaxq_f32(maxs, v);
}
// Horizontal reduction: merge the four lanes into the scalar results.
for (int l = 0; l < 4; l++) {
    minValue = std::min(minValue, mins[l]);
    maxValue = std::max(maxValue, maxs[l]);
}
#endif
// Scalar tail for the remaining (< 4) elements; on non-aarch64 builds this
// loop handles the whole group.
for (; j < end; j++) {
    minValue = std::min(minValue, inputData[i * m + j]);
    maxValue = std::max(maxValue, inputData[i * m + j]);
}
Expand All @@ -185,8 +199,37 @@ namespace fastllm {
std::vector<uint8_t> uinput;
uinput.resize(n * m);
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
uinput[i * m + j] = inputConfigs[i * group + j / groupCnt].quantization(inputData[i * m + j]);
float *cur = inputData + i * m;
uint8_t *u = uinput.data() + i * m;
for (int g = 0; g < group; g++) {
int st = g * groupCnt;
int end = std::min(m, (g + 1) * groupCnt);
int j = st;
auto &config = inputConfigs[i * group + g];
#ifdef __aarch64__
float32x4_t scales = vdupq_n_f32(config.scale);
float32x4_t zeros = vdupq_n_f32(config.zeroPoint + 0.5);
int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
for (; j + 7 < end; j += 8) {
float32x4_t fin1 = vld1q_f32(cur + j);
float32x4_t fin2 = vld1q_f32(cur + j + 4);
fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
int32x4_t out1 = vcvtq_s32_f32(fin1);
int32x4_t out2 = vcvtq_s32_f32(fin2);
out1 = vmaxq_s32(out1, minds);
out1 = vminq_s32(out1, maxds);
out2 = vmaxq_s32(out2, minds);
out2 = vminq_s32(out2, maxds);
uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
uint8x8_t out = vmovn_u16(out3);
vst1_u8(u + j, out);
}
#endif
for (; j < end; j++) {
uinput[i * m + j] = config.quantization(inputData[i * m + j]);
}
}
}

Expand Down

0 comments on commit faf71da

Please sign in to comment.