From 5adbec040d004c469376ec9c60bc10588a6f1be3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Tue, 4 Jun 2024 16:30:52 +0800
Subject: [PATCH] =?UTF-8?q?deepseekv2=20int4=E5=8A=A0=E9=80=9F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/devices/cpu/cpudevice.cpp | 292 ++++++++++++++++++----------------
 src/models/deepseekv2.cpp     |   2 +-
 2 files changed, 155 insertions(+), 139 deletions(-)

diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index 7e4f9ed2..364ac279 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -508,6 +508,141 @@ namespace fastllm {
             }
         }
     };
+
+    struct MultiThreadLinearInt4NoZeroOp : MultiThreadBaseOp {
+        uint8_t *a, *b;
+        int32_t *c;
+        int n, m, k, kstride;
+        int *weightSums;
+        float *weightMins, *scales, *bias;
+        LowBitConfig *config;
+        float *inputSums;
+
+        MultiThreadLinearInt4NoZeroOp(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int kstride,
+                                      int *weightSums, float *weightMins, float *scales, float *bias, LowBitConfig *config,
+                                      float *inputSums) :
+            a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
+            weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias), config(config), inputSums(inputSums) {}
+
+#ifdef __ARM_FEATURE_DOTPROD
+        inline static void RunSomeBlock(uint8_t *weightWalk, uint8_t *inputStart, int32_t *c,
+                                        int curBlock, uint32x2_t *sum, uint8x8x2_t *vi,
+                                        int block, int k, int m, int kstride) {
+            uint8x8_t maskHigh = vdup_n_u8(0xF0);
+            uint8x8_t maskLow = vdup_n_u8(0xF);
+            for (int i = 0; i < k; i++) {
+                std::vector <int> values = std::vector <int> (curBlock, 0);
+                uint8_t *inputWalk = inputStart;
+                int j = 0;
+
+                for (int j = 0; j < curBlock; j++) {
+                    sum[j][0] = sum[j][1] = 0;
+                }
+                for (; j + 15 < m; j += 16) {
+                    for (int x = 0; x < curBlock; x++) {
+                        vi[x] = vld2_u8(inputWalk + j + m * x);
+                    }
+                    uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
+                    uint8x8_t va = vand_u8(ori, maskLow);
+                    uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
+                    for (int x = 0; x < curBlock; x++) {
+                        sum[x] = vdot_u32(sum[x], va, vi[x].val[1]);
+                        sum[x] = vdot_u32(sum[x], vb, vi[x].val[0]);
+                    }
+                }
+                for (int x = 0; x < curBlock; x++) {
+                    values[x] += sum[x][0] + sum[x][1];
+                }
+
+                for (; j + 1 < m; j += 2) {
+                    int id = (i * m + j) / 2;
+                    for (int x = 0; x < curBlock; x++) {
+                        values[x] += (weightWalk[id] >> 4) * inputWalk[j + x * m];
+                        values[x] += (weightWalk[id] & 0xF) * inputWalk[j + 1 + x * m];
+                    }
+                }
+
+                for (int x = 0; x < curBlock; x++) {
+                    c[(block + x) * kstride + i] = values[x];
+                }
+            }
+        }
+#endif
+        void Run() {
+#ifdef __ARM_FEATURE_DOTPROD
+#define RUNBLOCK(x) for (; block + (x - 1) < n; block += (x)) RunSomeBlock(b, a + block * m, c, (x), sum, vi, block, k, m, kstride);
+            int block = 0;
+            uint32x2_t sum[16];
+            uint8x8x2_t vi[16];
+            RUNBLOCK(16);
+            RUNBLOCK(8);RUNBLOCK(7);RUNBLOCK(6);RUNBLOCK(5);
+            RUNBLOCK(4);RUNBLOCK(3);RUNBLOCK(2);RUNBLOCK(1);
+#undef RUNBLOCK
+#else
+            int block = 0;
+
+            for (; block < n; block++) {
+                uint8_t *weightWalk = b;
+                uint8_t *inputStart = a + block * m;
+
+                for (int i = 0; i < k; i++) {
+                    int value = 0;
+                    uint8_t *inputWalk = inputStart;
+                    int j = 0;
+#ifdef __ARM_FEATURE_DOTPROD
+                    uint8x8_t maskHigh = vdup_n_u8(0xF0);
+                    uint8x8_t maskLow = vdup_n_u8(0xF);
+                    uint32x2_t sum0 = {0, 0};
+
+                    for (; j + 15 < m; j += 16) {
+                        uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
+                        uint8x8x2_t in = vld2_u8(inputWalk + j);
+                        uint8x8_t va = vand_u8(ori, maskLow);
+                        uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
+                        sum0 = vdot_u32(sum0, va, in.val[1]);
+                        sum0 = vdot_u32(sum0, vb, in.val[0]);
+                    }
+                    value += sum0[0] + sum0[1];
+#elif defined(__aarch64__)
+                    uint8x8_t maskHigh = vdup_n_u8(0xF0);
+                    uint8x8_t maskLow = vdup_n_u8(0xF);
+                    uint32x4_t sum0 = {0, 0, 0, 0};
+
+                    for (; j + 15 < m; j += 16) {
+                        uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
+                        uint8x8x2_t in = vld2_u8(inputWalk + j);
+                        uint8x8_t va = vand_u8(ori, maskLow);
+                        uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
+                        sum0 = vpadalq_u16(sum0, vmull_u8(va, in.val[1]));
+                        sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
+                    }
+                    value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
+#elif defined(__AVX2__)
+                    value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
+                    j += m;
+#endif
+
+                    for (; j + 1 < m; j += 2) {
+                        int id = (i * m + j) / 2;
+                        value += (weightWalk[id] >> 4) * inputWalk[j];
+                        value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
+                    }
+
+                    c[block * kstride + i] = value;
+                }
+            }
+#endif
+            for (int block = 0; block < n; block++) {
+                for (int i = 0; i < k; i++) {
+                    int value = c[block * kstride + i];
+                    value -= weightSums[i] * config[block].zeroPoint;
+                    ((float*)c)[block * kstride + i] = scales[i] * config[block].scale * value +
+                            weightMins[i] * ((float)inputSums[block] - (int)config[block].zeroPoint * m) * config[block].scale +
+                            (bias == nullptr ? 0.0 : bias[i]);
+                }
+            }
+        }
+    };
 
     struct MultiThreadLinearInt4GroupOp : MultiThreadBaseOp {
         uint8_t *a, *b;
@@ -652,7 +787,9 @@ namespace fastllm {
 
             float routeScale = floatParams.find("routeScale") != floatParams.end() ? floatParams.find("routeScale")->second : 1.0f;
             output.Allocate();
-            if (input.dataType == DataType::FLOAT32 && weights[0]->dataType == DataType::INT4_GROUP && input.dims[0] == 1) {
+            if (input.dataType == DataType::FLOAT32 &&
+                (weights[0]->dataType == DataType::INT4_GROUP || weights[0]->dataType == DataType::INT4_NOZERO)
+                && input.dims[0] == 1) {
                 int dimsLen = logits.dims.size();
                 int outer = logits.Count(0) / logits.Count(dimsLen - 1);
                 int channels = logits.dims[dimsLen - 1];
@@ -780,6 +917,10 @@ namespace fastllm {
                         int mid = weights[idx * 2]->dims[0] / 2;
                         Data *weightDown = weights[idx * 2 + 1];
                         int groupDown = weightDown->group, groupCntDown = weightDown->groupCnt;
+                        if (weightDown->dataType != DataType::INT4_GROUP) {
+                            groupDown = 1;
+                            groupCntDown = mid;
+                        }
                         auto &inputConfigs = inputConfigsDown[l];
                         auto &inputSums = inputSumsDown[l];
                         auto &iscales = iscalesDown[l];
@@ -816,6 +957,10 @@ namespace fastllm {
                         auto &izeros = izerosDown[l];
                         auto &uinputDown = uinputsDown[l];
                         int curThread = (curK / k) * base;
+                        if (weightDown->dataType != DataType::INT4_GROUP) {
+                            groupDown = 1;
+                            groupCntDown = mid;
+                        }
                         MultiplyInt4GroupMultiThreadLaunch(uinputDown.data(), (uint8_t*)weightDown->cpuData, (int32_t *) results[l], 1, mid, m,
                                                            weightDown->weightSum.data(), weightDown->mins.data(), weightDown->scales.data(),
                                                            nullptr, inputSums, iscales, izeros,
@@ -1823,141 +1968,6 @@ namespace fastllm {
         }
     };
 
-    struct MultiThreadLinearInt4NoZeroOp : MultiThreadBaseOp {
-        uint8_t *a, *b;
-        int32_t *c;
-        int n, m, k, kstride;
-        int *weightSums;
-        float *weightMins, *scales, *bias;
-        LowBitConfig *config;
-        int *inputSums;
-
-        MultiThreadLinearInt4NoZeroOp(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int kstride,
-                                      int *weightSums, float *weightMins, float *scales, float *bias, LowBitConfig *config,
-                                      int *inputSums) :
-            a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
-            weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias), config(config), inputSums(inputSums) {}
-
-#ifdef __ARM_FEATURE_DOTPROD
-        inline static void RunSomeBlock(uint8_t *weightWalk, uint8_t *inputStart, int32_t *c,
-                                        int curBlock, uint32x2_t *sum, uint8x8x2_t *vi,
-                                        int block, int k, int m, int kstride) {
-            uint8x8_t maskHigh = vdup_n_u8(0xF0);
-            uint8x8_t maskLow = vdup_n_u8(0xF);
-            for (int i = 0; i < k; i++) {
-                std::vector <int> values = std::vector <int> (curBlock, 0);
-                uint8_t *inputWalk = inputStart;
-                int j = 0;
-
-                for (int j = 0; j < curBlock; j++) {
-                    sum[j][0] = sum[j][1] = 0;
-                }
-                for (; j + 15 < m; j += 16) {
-                    for (int x = 0; x < curBlock; x++) {
-                        vi[x] = vld2_u8(inputWalk + j + m * x);
-                    }
-                    uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
-                    uint8x8_t va = vand_u8(ori, maskLow);
-                    uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
-                    for (int x = 0; x < curBlock; x++) {
-                        sum[x] = vdot_u32(sum[x], va, vi[x].val[1]);
-                        sum[x] = vdot_u32(sum[x], vb, vi[x].val[0]);
-                    }
-                }
-                for (int x = 0; x < curBlock; x++) {
-                    values[x] += sum[x][0] + sum[x][1];
-                }
-
-                for (; j + 1 < m; j += 2) {
-                    int id = (i * m + j) / 2;
-                    for (int x = 0; x < curBlock; x++) {
-                        values[x] += (weightWalk[id] >> 4) * inputWalk[j + x * m];
-                        values[x] += (weightWalk[id] & 0xF) * inputWalk[j + 1 + x * m];
-                    }
-                }
-
-                for (int x = 0; x < curBlock; x++) {
-                    c[(block + x) * kstride + i] = values[x];
-                }
-            }
-        }
-#endif
-        void Run() {
-#ifdef __ARM_FEATURE_DOTPROD
-#define RUNBLOCK(x) for (; block + (x - 1) < n; block += (x)) RunSomeBlock(b, a + block * m, c, (x), sum, vi, block, k, m, kstride);
-            int block = 0;
-            uint32x2_t sum[16];
-            uint8x8x2_t vi[16];
-            RUNBLOCK(16);
-            RUNBLOCK(8);RUNBLOCK(7);RUNBLOCK(6);RUNBLOCK(5);
-            RUNBLOCK(4);RUNBLOCK(3);RUNBLOCK(2);RUNBLOCK(1);
-#undef RUNBLOCK
-#else
-            int block = 0;
-
-            for (; block < n; block++) {
-                uint8_t *weightWalk = b;
-                uint8_t *inputStart = a + block * m;
-
-                for (int i = 0; i < k; i++) {
-                    int value = 0;
-                    uint8_t *inputWalk = inputStart;
-                    int j = 0;
-#ifdef __ARM_FEATURE_DOTPROD
-                    uint8x8_t maskHigh = vdup_n_u8(0xF0);
-                    uint8x8_t maskLow = vdup_n_u8(0xF);
-                    uint32x2_t sum0 = {0, 0};
-
-                    for (; j + 15 < m; j += 16) {
-                        uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
-                        uint8x8x2_t in = vld2_u8(inputWalk + j);
-                        uint8x8_t va = vand_u8(ori, maskLow);
-                        uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
-                        sum0 = vdot_u32(sum0, va, in.val[1]);
-                        sum0 = vdot_u32(sum0, vb, in.val[0]);
-                    }
-                    value += sum0[0] + sum0[1];
-#elif defined(__aarch64__)
-                    uint8x8_t maskHigh = vdup_n_u8(0xF0);
-                    uint8x8_t maskLow = vdup_n_u8(0xF);
-                    uint32x4_t sum0 = {0, 0, 0, 0};
-
-                    for (; j + 15 < m; j += 16) {
-                        uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
-                        uint8x8x2_t in = vld2_u8(inputWalk + j);
-                        uint8x8_t va = vand_u8(ori, maskLow);
-                        uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
-                        sum0 = vpadalq_u16(sum0, vmull_u8(va, in.val[1]));
-                        sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
-                    }
-                    value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
-#elif defined(__AVX2__)
-                    value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
-                    j += m;
-#endif
-
-                    for (; j + 1 < m; j += 2) {
-                        int id = (i * m + j) / 2;
-                        value += (weightWalk[id] >> 4) * inputWalk[j];
-                        value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
-                    }
-
-                    c[block * kstride + i] = value;
-                }
-            }
-#endif
-            for (int block = 0; block < n; block++) {
-                for (int i = 0; i < k; i++) {
-                    int value = c[block * kstride + i];
-                    value -= weightSums[i] * config[block].zeroPoint;
-                    ((float*)c)[block * kstride + i] = scales[i] * config[block].scale * value +
-                            weightMins[i] * ((float)inputSums[block] - (int)config[block].zeroPoint * m) * config[block].scale +
-                            (bias == nullptr ? 0.0 : bias[i]);
-                }
-            }
-        }
-    };
-
     //a = [n, m], b = [k, m], c = aT(b') = [n, k]
     void MultiplyMultiThread(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int threadNum) {
         auto *pool = GetAlivePool();
@@ -2029,7 +2039,7 @@
     void MultiplyInt4NoZeroMultiThread(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k,
                                        int *weightSums, float *weightMins, float *scales, float *bias,
                                        std::vector <LowBitConfig> &configs, int threadNum) {
-        std::vector <int> inputSums;
+        std::vector <float> inputSums;
         for (int i = 0; i < n; i++) {
             int sum = 0;
             for (int j = 0; j < m; j++) {
@@ -2078,10 +2088,16 @@
 
         for (int i = 0; i < threadNum; i++) {
             int end = (i == threadNum - 1 ? k : cur + per + (cur + per * (threadNum - i) < k));
-            ops[startTid + i] = new MultiThreadLinearInt4GroupOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
+            if (group > 1) {
+                ops[startTid + i] = new MultiThreadLinearInt4GroupOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
                                                                  weightSums + cur * group, weightMins + cur * group, scales + cur * group,
                                                                  (bias == nullptr ? (float *) nullptr : bias + cur), iscales.data(), izeros.data(),
                                                                  inputSums.data(), group, groupCnt);
+            } else {
+                ops[startTid + i] = new MultiThreadLinearInt4NoZeroOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
+                                                                 weightSums + cur * group, weightMins + cur * group, scales + cur * group,
+                                                                 (bias == nullptr ? (float *) nullptr : bias + cur), configs.data(), inputSums.data());
+            }
             cur = end;
         }
         for (int i = 0; i < threadNum; i++) {
diff --git a/src/models/deepseekv2.cpp b/src/models/deepseekv2.cpp
index 4ec94b51..5eb61365 100644
--- a/src/models/deepseekv2.cpp
+++ b/src/models/deepseekv2.cpp
@@ -378,7 +378,7 @@ namespace fastllm {
                 pastValue.ToDevice(DataDevice::CUDA);
             }
 
-            int unitLen = 64;
+            int unitLen = 128;
 #ifdef USE_CUDA
             unitLen = 128;
 #endif
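
Note on the dequantization epilogue in MultiThreadLinearInt4NoZeroOp::Run: with
INT4_NOZERO weights stored as w = min + scale * q (4-bit code q in [0, 15], no
per-weight zero point) and inputs quantized as x = config.scale * (u - zeroPoint),
the float result expands to

    sum_j w_j * x_j = scale * config.scale * (dot(q, u) - zeroPoint * sum(q))
                    + min * config.scale * (sum(u) - zeroPoint * m)

which is exactly what the epilogue computes from the precomputed weightSums[i]
and inputSums[block]. The snippet below is a minimal standalone check of that
identity, not fastllm code: every name in it is local to the sketch and the
constants are arbitrary.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int m = 16;                                  // reduction length
    const float weightMin = -0.8f, weightScale = 0.1f; // per-row weight params (mins/scales in the patch)
    const float inputScale = 0.02f;                    // per-token input scale (config.scale)
    const int zeroPoint = 128;                         // per-token input zero point (config.zeroPoint)

    std::vector<uint8_t> q(m), u(m);                   // 4-bit weight codes, 8-bit input codes
    for (int j = 0; j < m; j++) {
        q[j] = (uint8_t)(j % 16);
        u[j] = (uint8_t)(100 + (7 * j) % 156);
    }

    // Reference path: dequantize first, then multiply in float.
    float ref = 0.0f;
    for (int j = 0; j < m; j++) {
        float w = weightMin + weightScale * q[j];
        float x = inputScale * ((int)u[j] - zeroPoint);
        ref += w * x;
    }

    // Integer path, as in the kernel: one uint8 dot product plus the
    // precomputed code sums (weightSums / inputSums in the real code).
    int32_t dot = 0, qSum = 0, uSum = 0;
    for (int j = 0; j < m; j++) {
        dot  += (int32_t)q[j] * u[j];
        qSum += q[j];
        uSum += u[j];
    }
    float out = weightScale * inputScale * (dot - qSum * zeroPoint)
              + weightMin * inputScale * (uSum - zeroPoint * m);

    printf("ref = %f  int-path = %f  diff = %g\n", ref, out, std::fabs(ref - out));
    return 0;
}

The two paths agree up to float rounding, which is why the kernel can keep the
hot loop entirely in uint8/int32 (vdot_u32 on ARM, DotU4U8 on AVX2) and defer
all float work to this per-row epilogue.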