Commit

Add vectorized memory access to the int8 linear path to improve inference speed.
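
The change vectorizes the int8-to-FP16 dequantization used on the CUDA_NO_TENSOR_CORE path: each thread now loads eight packed int8 weights with a single 64-bit read and writes the eight converted half values back with a single 128-bit store, instead of touching global memory once per element. A minimal standalone sketch of that access pattern follows; the demo_char8 union, the kernel name, and the assumption that every 8-element chunk falls inside one quantization group are illustrative only, and differ from the repository's union_char8/union_half4 helpers and its group-boundary handling.

    // Sketch only: shows the 128-bit vectorized load/store pattern, not the repository code.
    #include <cuda_fp16.h>
    #include <cstdint>

    // Hypothetical stand-in for the repository's union_char8:
    // 8 packed int8 values read with one uint2 (64-bit) load.
    typedef union __align__(8) _demo_char8 {
        uint2 in;
        uint8_t out[8];
    } demo_char8;

    __global__ void DemoInt82HalfVectorized(const uint8_t *a, const float *scales,
                                            const uint8_t *zeros, half *b,
                                            int len, int per) {
        const int VEC = 8;                                    // elements handled per thread
        int idx = (threadIdx.x + blockIdx.x * blockDim.x) * VEC;
        if (idx + VEC > len) return;                          // sketch assumes len % 8 == 0

        demo_char8 aBuf;
        aBuf.in = *reinterpret_cast<const uint2 *>(a + idx);  // one 64-bit load of 8 int8 weights

        // Sketch assumes the whole 8-element chunk shares one quantization group (per % 8 == 0).
        float scale = scales[idx / per];
        float zero  = (float)zeros[idx / per];

        __align__(16) half bBuf[VEC];
        for (int i = 0; i < VEC; i++) {
            bBuf[i] = __float2half(scale * ((float)aBuf.out[i] - zero));
        }
        // One 128-bit store writes all 8 converted half values.
        *reinterpret_cast<uint4 *>(b + idx) = *reinterpret_cast<uint4 *>(bBuf);
    }

Because each thread covers eight elements, the launch site in the diff sizes the grid as (len - 1) / (threadPerBlock * ST128_FP16_COUNT) + 1 rather than (len - 1) / threadPerBlock + 1.
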
cgli committed Oct 18, 2023
1 parent 6da45df commit d730e63
Showing 1 changed file with 43 additions and 5 deletions.
48 changes: 43 additions & 5 deletions src/devices/cuda/fastllm-cuda.cu
@@ -32,6 +32,8 @@ typedef union __align__(16) _union_half_4 {
}
} union_half4;

const size_t ST128_FP16_COUNT = 8;

static std::map<int, cublasHandle_t> s_fastllmCublasHandleMap;
cublasHandle_t getFastllmCublasHandle() {
int id = -1;
@@ -70,10 +72,40 @@ __global__ void FastllmCudaFloat2HalfKernel(float* a, half *b, int len) {
}

__global__ void FastllmCudaInt82HalfKernel(uint8_t* a, float *scales, uint8_t *zeros, half *b, int len, int per) {
#ifdef CUDA_NO_TENSOR_CORE
float scalesBuffer[2];
uint8_t zerosBuffer[2];
int threshold = ST128_FP16_COUNT;
int index = (threadIdx.x + blockIdx.x * blockDim.x) * ST128_FP16_COUNT;
for (int idx = index; idx < len; idx += (gridDim.x * blockDim.x) * ST128_FP16_COUNT) {
int startIdx = idx / per;
int endIdx = (idx + ST128_FP16_COUNT - 1) / per;
scalesBuffer[1] = scalesBuffer[0] = scales[startIdx];
zerosBuffer[1] = zerosBuffer[0] = zeros[startIdx];
if (endIdx > startIdx) {
threshold = per - idx % per; // number of elements in this 8-wide chunk still inside the first quantization group
scalesBuffer[1] = scales[endIdx];
zerosBuffer[1] = zeros[endIdx];
}
// Load: one 64-bit read fetches 8 packed int8 values
union_char8 aBuffer[2];
half bBuffer[ST128_FP16_COUNT];
aBuffer[0].in = *reinterpret_cast<const uint2 *>(a + idx);
// Dequantize each int8 value to half
for (int i=0; i<ST128_FP16_COUNT; i++) {
if (idx + i < len) {
int scaleIdx = i < threshold ? 0 : 1;
bBuffer[i] = __float2half(scalesBuffer[scaleIdx] * ((float)aBuffer[0].out[i] - zerosBuffer[scaleIdx]));
}
}
// Store: one 128-bit write outputs all 8 converted half values
reinterpret_cast<uint4 *>(b)[idx / ST128_FP16_COUNT] = *reinterpret_cast<uint4 *>(bBuffer);
}
#else
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
b[idx] = __float2half(scales[idx / per] * ((float)a[idx] - zeros[idx / per]));
}
#endif
}

__global__ void FastllmCudaInt42HalfKernel(uint8_t* a, float *scales, float *mins, half *b, int len, int per) {
@@ -1230,12 +1262,13 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input, len);

len = k * m;
FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaZeropoints,
cudaFp16Weight, len, m);

#ifdef CUDA_NO_TENSOR_CORE
int gridSize = (len - 1) / (threadPerBlock * ST128_FP16_COUNT) + 1;
FastllmCudaInt82HalfKernel <<< gridSize, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaZeropoints,
cudaFp16Weight, len, m);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
@@ -1245,6 +1278,11 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#else
FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaZeropoints,
cudaFp16Weight, len, m);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
