diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8b33235f..9742089f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.5)
 
 project(fastllm LANGUAGES CXX)
 
-option(USE_CUDA "use cuda" OFF)
+option(USE_CUDA "use CUDA" OFF)
+option(CUDA_NO_TENSOR_CORE "Optimize for legacy CUDA GPUs without Tensor Cores." OFF)
 
 option(PY_API "python api" OFF)
 
@@ -14,6 +15,8 @@ option(USE_IVCOREX "use iluvatar corex gpu" OFF)
 
 message(STATUS "USE_CUDA: ${USE_CUDA}")
+message(STATUS "CUDA_NO_TENSOR_CORE: ${CUDA_NO_TENSOR_CORE}")
+
 message(STATUS "PYTHON_API: ${PY_API}")
 message(STATUS "USE_SENTENCEPIECE: ${USE_SENTENCEPIECE}")
@@ -56,6 +59,9 @@ endif()
 if (USE_CUDA)
     enable_language(CUDA)
     add_compile_definitions(USE_CUDA)
+    if (CUDA_NO_TENSOR_CORE)
+        add_compile_definitions(CUDA_NO_TENSOR_CORE)
+    endif()
     include_directories(include/devices/cuda)
     #message(${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
     set(FASTLLM_CUDA_SOURCES src/devices/cuda/cudadevice.cpp src/devices/cuda/cudadevicebatch.cpp src/devices/cuda/fastllm-cuda.cu)
diff --git a/docs/faq.md b/docs/faq.md
index 5ce9a868..17472d5f 100755
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -43,6 +43,21 @@ cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
 if (PY_API)
 ```
 
+### identifier "__hdiv" is undefined
+
+**Symptom:**
+
+> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "hexp" is undefined
+> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "__hdiv" is undefined
+> ...
+
+**Cause:** GPUs with a [Compute Capability](https://developer.nvidia.com/cuda-gpus) below 5.3 do not support half-precision arithmetic.
+
+**Workaround:** To support these GPUs, pass the `CUDA_NO_TENSOR_CORE` option when running cmake:
+
+```shell
+cmake .. -DUSE_CUDA=ON -DCUDA_NO_TENSOR_CORE=ON
+```
 
 ## Windows
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 1a9bd44c..f89152d7 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -108,7 +108,11 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
     half *now = a + blockIdx.x * k;
     int stride = blockDim.x;
     for (int i = threadIdx.x; i < k; i += stride) {
+#ifdef CUDA_NO_TENSOR_CORE
+        now[i] = __float2half(__half2float(now[i]) + __half2float(bias[i]));
+#else
         now[i] = __hadd(now[i], bias[i]);
+#endif
     }
 }
@@ -141,8 +145,13 @@ __global__ void FastllmSwigluKernel(half* a, half *b, int len, int spatial, int
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
         int id = idx / mid * spatial + idx % mid;
+#ifdef CUDA_NO_TENSOR_CORE
+        float x = __half2float(a[id]), y = __half2float(a[id + mid]);
+        b[idx] = __float2half((x / (1.0f + expf(-x))) * y);
+#else
         half x = a[id], y = a[id + mid];
         b[idx] = __hmul(__hdiv(x, __hadd(__float2half(1.0), hexp(-x))), y);
+#endif
     }
 }
@@ -156,7 +165,11 @@ __global__ void FastllmMulKernel(float* a, float *b, float v, int len) {
 __global__ void FastllmMulKernel(half* a, half *b, half v, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
+#ifdef CUDA_NO_TENSOR_CORE
+        b[idx] = __float2half(__half2float(a[idx]) * __half2float(v));
+#else
         b[idx] = __hmul(a[idx], v);
+#endif
     }
 }
@@ -180,7 +193,11 @@ __global__ void FastllmAddToKernel(float* a, float *b, float alpha, int len) {
 __global__ void FastllmAddToKernel(half* a, half *b, half alpha, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
+#ifdef CUDA_NO_TENSOR_CORE
+        a[idx] = __float2half(__half2float(a[idx]) + __half2float(b[idx]) * __half2float(alpha));
+#else
         a[idx] = __hadd(a[idx], __hmul(b[idx], alpha));
+#endif
     }
 }
@@ -418,6 +435,18 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
     }
 }
 
+__device__ half FastllmHalfMaxFunc(const __half a, const __half b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
+    return __half2float(a) >= __half2float(b) ? a : b;
+#else
+#if defined(CUDART_VERSION) && CUDART_VERSION > 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    return __hmax(a, b);
+#else
+    return __hge(a, b) ? a : b;
+#endif
+#endif
+}
+
 template <int THREAD_PER_BLOCK>
 __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels) {
     __shared__ half sdata[THREAD_PER_BLOCK];
@@ -427,7 +456,7 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     unsigned int tid = threadIdx.x;
     half maxValue = input[tid];
     for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
-        maxValue = __hmax(maxValue, input[i]);
+        maxValue = FastllmHalfMaxFunc(maxValue, input[i]);
     }
     sdata[tid] = maxValue;
     __syncthreads();
@@ -435,7 +464,7 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     // 2. compute the max
     for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
         if (tid < s) {
-            sdata[tid] = __hmax(sdata[tid], sdata[tid + s]);
+            sdata[tid] = FastllmHalfMaxFunc(sdata[tid], sdata[tid + s]);
         }
         __syncthreads();
     }
@@ -447,12 +476,38 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     __syncthreads();
 
     // 4. compute the sum
-    half sum = 0;
+    float sum = 0;
+#ifdef CUDA_NO_TENSOR_CORE
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        float outF = exp(__half2float(input[i]) - __half2float(maxV));
+        sum = sum + outF;
+        output[i] = __float2half(outF);
+    }
+    sdata[tid] = __float2half(sum);
+    __syncthreads();
+
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] = __float2half(__half2float(sdata[tid]) + __half2float(sdata[tid + s]));
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        if (fabs(__half2float(sdata[0])) < 1e-6) {
+            sdata[0] = __float2half(0.1);
+        }
+    }
+    __syncthreads();
+
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        output[i] = __float2half(__half2float(output[i]) / __half2float(sdata[0]));
+    }
+#else
     for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
         output[i] = hexp(__hsub(input[i], maxV));
-        sum = __hadd(sum, output[i]);
+        sum = sum + __half2float(output[i]);
     }
-    sdata[tid] = sum;
+    sdata[tid] = __float2half(sum);
     __syncthreads();
 
     for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
@@ -471,6 +526,7 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
         output[i] = __hdiv(output[i], sdata[0]);
     }
+#endif
 }
 
 template <int THREAD_PER_BLOCK>
@@ -1340,12 +1396,19 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
     if (n >= 8) {
         auto fastllmCublasHandle = getFastllmCublasHandle();
         half *cudaFp16Input, *cudaFp16Output, *cudaFp16Weight;
+#ifdef CUDA_NO_TENSOR_CORE
+        cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
+        cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));
+        float h_alpha = 1.0, h_beta = 0.0;
+        cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_32F, ComputeType = CUDA_R_32F;
+#else
         cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
         cudaFp16Output = (half *) FastllmCudaMalloc(n * k * sizeof(half));
         cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));
 
         __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
         cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
+#endif
         cublasStatus_t status;
 
         int len = n * m;
@@ -1358,6 +1421,16 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
                                                                               cudaZeropoints,
                                                                               cudaFp16Weight, len, m);
+#ifdef CUDA_NO_TENSOR_CORE
+        status = cublasGemmEx(fastllmCublasHandle,
+                              CUBLAS_OP_T, CUBLAS_OP_N,
+                              k, n, m,
+                              &h_alpha, cudaFp16Weight, AType,
+                              m, cudaFp16Input, BType,
+                              m, &h_beta,
+                              cudaOutput, CType,
+                              k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#else
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1366,6 +1439,7 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
                               m, &h_beta,
                               cudaFp16Output, CType,
                               k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#endif
         if (status != CUBLAS_STATUS_SUCCESS) {
             printf("Error: cublas error.\n");
             throw("cublas error");
@@ -1373,12 +1447,18 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
         }
 
         len = n * k;
+#ifdef CUDA_NO_TENSOR_CORE
+        FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
+        FastllmCudaFree(cudaFp16Input);
+        FastllmCudaFree(cudaFp16Weight);
+#else
         FastllmCudaHalf2FlotaKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, cudaOutput, len);
         FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
 
         FastllmCudaFree(cudaFp16Input);
         FastllmCudaFree(cudaFp16Output);
         FastllmCudaFree(cudaFp16Weight);
+#endif
     } else {
         for (int i = 0; i < n; i++) {
             FastllmGemvInt8Kernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
@@ -1582,12 +1662,20 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
     if (n >= 8) {
         auto fastllmCublasHandle = getFastllmCublasHandle();
         half *cudaFp16Input, *cudaFp16Output, *cudaFp16Weight;
+#ifdef CUDA_NO_TENSOR_CORE
+        cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
+        cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));
+
+        float h_alpha = 1.0, h_beta = 0.0;
+        cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_32F, ComputeType = CUDA_R_32F;
+#else
         cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
         cudaFp16Output = (half *) FastllmCudaMalloc(n * k * sizeof(half));
         cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));
 
         __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
         cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
+#endif
         cublasStatus_t status;
 
         int len = n * m;
@@ -1601,6 +1689,16 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
                                                                                cudaMins,
                                                                                cudaFp16Weight, len, m);
+#ifdef CUDA_NO_TENSOR_CORE
+        status = cublasGemmEx(fastllmCublasHandle,
+                              CUBLAS_OP_T, CUBLAS_OP_N,
+                              k, n, m,
+                              &h_alpha, cudaFp16Weight, AType,
+                              m, cudaFp16Input, BType,
+                              m, &h_beta,
+                              cudaOutput, CType,
+                              k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#else
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1609,6 +1707,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
                               m, &h_beta,
                               cudaFp16Output, CType,
                               k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#endif
         if (status != CUBLAS_STATUS_SUCCESS) {
             printf("Error: cublas error.\n");
             throw("cublas error");
@@ -1616,6 +1715,11 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
         }
 
         len = n * k;
+#ifdef CUDA_NO_TENSOR_CORE
+        FastllmCudaBiasKernel <<< n, 256 >>>(cudaOutput, cudaBiasData, k);
+        FastllmCudaFree(cudaFp16Input);
+        FastllmCudaFree(cudaFp16Weight);
+#else
         FastllmCudaHalf2FlotaKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, cudaOutput, len);
         FastllmCudaBiasKernel <<< n, 256 >>>(cudaOutput, cudaBiasData, k);
 
@@ -1623,6 +1727,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
         FastllmCudaFree(cudaFp16Input);
         FastllmCudaFree(cudaFp16Output);
         FastllmCudaFree(cudaFp16Weight);
+#endif
     } else {
         for (int i = 0; i < n; i++) {
             FastllmGemvInt4NoZeroKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
@@ -1708,14 +1813,21 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
     float *cudaOutput = (float*)FastllmCudaPrepareOutput(output);
 
     if (n > 1) {
+        auto fastllmCublasHandle = getFastllmCublasHandle();
+        //cudaDeviceSynchronize();
         half *cudaFp16Input, *cudaFp16Output;
+#ifdef CUDA_NO_TENSOR_CORE
+        cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
+
+        float h_alpha = 1.0, h_beta = 0.0;
+        cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_32F, ComputeType = CUDA_R_32F;
+#else
         cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
         cudaFp16Output = (half *) FastllmCudaMalloc(n * k * sizeof(half));
 
         __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
-        auto fastllmCublasHandle = getFastllmCublasHandle();
-        //cudaDeviceSynchronize();
         cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
+#endif
         cublasStatus_t status;
 
         int len = n * m;
@@ -1723,6 +1835,16 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
         FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input, len);
 
+#ifdef CUDA_NO_TENSOR_CORE
+        status = cublasGemmEx(fastllmCublasHandle,
+                              CUBLAS_OP_T, CUBLAS_OP_N,
+                              k, n, m,
+                              &h_alpha, (half *) weight.cudaData, AType,
+                              m, cudaFp16Input, BType,
+                              m, &h_beta,
+                              cudaOutput, CType,
+                              k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#else
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1731,6 +1853,7 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
                               m, &h_beta,
                               cudaFp16Output, CType,
                               k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#endif
         if (status != CUBLAS_STATUS_SUCCESS) {
             printf("Error: cublas error.\n");
             throw("cublas error");
@@ -1738,6 +1861,10 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
         }
 
         len = n * k;
+#ifdef CUDA_NO_TENSOR_CORE
+        FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, (float*)weight.extraCudaData[0], k);
+        FastllmCudaFree(cudaFp16Input);
+#else
         FastllmCudaHalf2FlotaKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, cudaOutput, len);
         FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, (float*)weight.extraCudaData[0], k);
 
@@ -1745,6 +1872,7 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
         FastllmCudaFree(cudaFp16Input);
         FastllmCudaFree(cudaFp16Output);
+#endif
     } else {
         FastllmGemvFp32Fp16Kernel2<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
     }
@@ -3031,7 +3159,7 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei
         FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock>>>(tempBiasData, cudaBiasData, k);
         state = cudaFree(tempBiasData);
     } else {
-        state = cudaMemset(cudaBiasData, __float2half_rn(0.0), k * sizeof(half));
+        state = cudaMemset(cudaBiasData, 0, k * sizeof(half));
     }
     checkCudaErrors("Error: CUDA error when moving bias to device!", state);
     weight.extraCudaHalfData.push_back((void *) cudaBiasData);
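
The patch applies one idiom throughout fastllm-cuda.cu: on devices with `__CUDA_ARCH__ < 530` the half-precision intrinsics (`__hadd`, `__hmul`, `__hdiv`, `hexp`, `__hmax`) do not exist, so each half operation is widened to float, computed, and narrowed back with `__float2half`. On the cuBLAS side, `CUDA_NO_TENSOR_CORE` keeps the FP16 `A`/`B` inputs but switches `CType` and `ComputeType` to `CUDA_R_32F`, so `cublasGemmEx` accumulates in FP32 and writes straight into the float output buffer, which is why the `cudaFp16Output` buffer and the `FastllmCudaHalf2FlotaKernel` conversion pass are dropped in that branch. Below is a minimal, self-contained sketch of the kernel-side idiom; the kernel and variable names are illustrative only, not part of the patch.

```cpp
// sketch.cu - illustrates the CUDA_NO_TENSOR_CORE fallback idiom used in this
// patch. Kernel and variable names are hypothetical; only the #ifdef pattern
// and the float round-trip mirror the actual change.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void AddBiasHalfSketch(half *a, const half *bias, int k) {
    half *row = a + blockIdx.x * k;  // one block per row of an [n, k] matrix
    for (int i = threadIdx.x; i < k; i += blockDim.x) {
#ifdef CUDA_NO_TENSOR_CORE
        // Portable path: widen to float, add, narrow back; compiles on any arch.
        row[i] = __float2half(__half2float(row[i]) + __half2float(bias[i]));
#else
        // Fast path: native half add, requires compute capability >= 5.3.
        row[i] = __hadd(row[i], bias[i]);
#endif
    }
}

int main() {
    const int n = 2, k = 256;
    half *a, *bias;
    cudaMallocManaged(&a, n * k * sizeof(half));
    cudaMallocManaged(&bias, k * sizeof(half));
    for (int i = 0; i < n * k; i++) a[i] = __float2half(1.0f);
    for (int i = 0; i < k; i++) bias[i] = __float2half(0.5f);

    AddBiasHalfSketch<<<n, 128>>>(a, bias, k);
    cudaDeviceSynchronize();
    printf("a[0] = %.2f (expected 1.50)\n", __half2float(a[0]));

    cudaFree(a);
    cudaFree(bias);
    return 0;
}
```

Compiled as `nvcc -arch=sm_52 -DCUDA_NO_TENSOR_CORE sketch.cu`, only the float path is emitted and the file builds for pre-5.3 devices; without the define, `__hadd` requires `-arch=sm_53` or newer, which is exactly the build failure the FAQ entry above describes.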