From 9ecf89b4f922303c8d0b51fb58f67afed416f4fe Mon Sep 17 00:00:00 2001
From: cgli
Date: Wed, 10 Apr 2024 23:05:38 +0800
Subject: [PATCH] Fix the compilation failure on older platforms such as
 Maxwell after adding the fp16 operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/faq.md                      | 15 ++++++++++
 src/devices/cuda/fastllm-cuda.cu | 50 ++++++++++++++++++++++++++++++--
 2 files changed, 62 insertions(+), 3 deletions(-)

diff --git a/docs/faq.md b/docs/faq.md
index 5ce9a868..17472d5f 100755
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -43,6 +43,21 @@ cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
 if (PY_API)
 ```
 
+### identifier "__hdiv" is undefined
+
+**Symptom:**
+
+> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "hexp" is undefined
+> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "__hdiv" is undefined
+> ...
+
+**Cause:** GPUs with a [Compute Capability](https://developer.nvidia.com/cuda-gpus) lower than 5.3 do not support half-precision arithmetic.
+
+**Solution:** To support these GPUs, pass the `CUDA_NO_TENSOR_CORE` option when running cmake:
+
+```shell
+cmake .. -DUSE_CUDA=ON -DCUDA_NO_TENSOR_CORE=ON
+```
 
 ## Windows
 
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index f476b4fa..f89152d7 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -108,7 +108,11 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
     half *now = a + blockIdx.x * k;
     int stride = blockDim.x;
     for (int i = threadIdx.x; i < k; i += stride) {
+#ifdef CUDA_NO_TENSOR_CORE
+        now[i] = __float2half(__half2float(now[i]) + __half2float(bias[i]));
+#else
         now[i] = __hadd(now[i], bias[i]);
+#endif
     }
 }
 
@@ -141,8 +145,13 @@ __global__ void FastllmSwigluKernel(half* a, half *b, int len, int spatial, int
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
         int id = idx / mid * spatial + idx % mid;
+#ifdef CUDA_NO_TENSOR_CORE
+        float x = __half2float(a[id]), y = __half2float(a[id + mid]);
+        b[idx] = __float2half((x / (1.0 + expf(-x))) * y);
+#else
         half x = a[id], y = a[id + mid];
         b[idx] = __hmul(__hdiv(x, __hadd(__float2half(1.0), hexp(-x))), y);
+#endif
     }
 }
 
@@ -156,7 +165,11 @@ __global__ void FastllmMulKernel(float* a, float *b, float v, int len) {
 __global__ void FastllmMulKernel(half* a, half *b, half v, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
+#ifdef CUDA_NO_TENSOR_CORE
+        b[idx] = __float2half(__half2float(a[idx]) * __half2float(v));
+#else
         b[idx] = __hmul(a[idx], v);
+#endif
     }
 }
 
@@ -180,7 +193,11 @@ __global__ void FastllmAddToKernel(float* a, float *b, float alpha, int len) {
 __global__ void FastllmAddToKernel(half* a, half *b, half alpha, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
+#ifdef CUDA_NO_TENSOR_CORE
+        a[idx] = __float2half(__half2float(a[idx]) + __half2float(b[idx]) * __half2float(alpha));
+#else
         a[idx] = __hadd(a[idx], __hmul(b[idx], alpha));
+#endif
     }
 }
 
@@ -459,12 +476,38 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     __syncthreads();
 
     // 4. Sum
-    half sum = 0;
+    float sum = 0;
+#ifdef CUDA_NO_TENSOR_CORE
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        float outF = exp(__half2float(input[i]) - __half2float(maxV));
+        sum = sum + outF;
+        output[i] = __float2half(outF);
+    }
+    sdata[tid] = __float2half(sum);
+    __syncthreads();
+
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] = __float2half(__half2float(sdata[tid]) + __half2float(sdata[tid + s]));
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        if (fabs(__half2float(sdata[0])) < 1e-6) {
+            sdata[0] = __float2half(0.1);
+        }
+    }
+    __syncthreads();
+
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        output[i] = __float2half(__half2float(output[i]) / __half2float(sdata[0]));
+    }
+#else
     for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
         output[i] = hexp(__hsub(input[i], maxV));
-        sum = __hadd(sum, output[i]);
+        sum = sum + __half2float(output[i]);
     }
-    sdata[tid] = sum;
+    sdata[tid] = __float2half(sum);
     __syncthreads();
 
     for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
@@ -483,6 +526,7 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
         output[i] = __hdiv(output[i], sdata[0]);
     }
+#endif
 }
 
 template
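
Every kernel touched by this patch uses the same guard: when `CUDA_NO_TENSOR_CORE` is defined, `half` operands are widened to `float`, the arithmetic is done in ordinary float, and the result is narrowed back, so no sm_53+ half intrinsics (`__hadd`, `__hmul`, `__hdiv`, `hexp`) are emitted. A minimal standalone sketch of that pattern follows; the kernel name and the element-wise operation are illustrative only and are not code from this patch:

```cuda
// Illustrative sketch of the CUDA_NO_TENSOR_CORE fallback pattern, applied to
// a simple element-wise scale kernel (hypothetical example, not patch code).
#include <cuda_fp16.h>

__global__ void ScaleHalfKernel(half *data, half scale, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < len) {
#ifdef CUDA_NO_TENSOR_CORE
        // Pre-sm_53 path: __half2float/__float2half conversions exist on all
        // architectures, so compute in float and convert back to half.
        data[idx] = __float2half(__half2float(data[idx]) * __half2float(scale));
#else
        // Native half arithmetic; requires compute capability >= 5.3.
        data[idx] = __hmul(data[idx], scale);
#endif
    }
}
```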