Skip to content

Commit

Permalink
修复增加fp16算子后,在Maxwell等老平台的编译问题
Browse files Browse the repository at this point in the history
  • Loading branch information
cgli authored and TylunasLi committed Apr 10, 2024
1 parent f6bfb82 commit 9ecf89b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
15 changes: 15 additions & 0 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
if (PY_API)
```

### identifier "__hdiv" is undefined

**现象:**

> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "hexp" is undefined
> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "__hdiv" is undefined
> ...

**原因:** [计算能力(Compute Capability)](https://developer.nvidia.com/cuda-gpus) < 5.3 的GPU不支持半精度计算(半精度算术指令要求计算能力 >= 5.3)。

**解决办法:** 如需要支持这些GPU,执行cmake时使用编译选项`CUDA_NO_TENSOR_CORE`

```shell
cmake .. -DUSE_CUDA=ON -DCUDA_NO_TENSOR_CORE=ON
```

## Windows

Expand Down
50 changes: 47 additions & 3 deletions src/devices/cuda/fastllm-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
half *now = a + blockIdx.x * k;
int stride = blockDim.x;
for (int i = threadIdx.x; i < k; i += stride) {
#ifdef CUDA_NO_TENSOR_CORE
now[i] = __float2half(__half2float(now[i]) + __half2float(bias[i]));
#else
now[i] = __hadd(now[i], bias[i]);
#endif
}
}

Expand Down Expand Up @@ -141,8 +145,13 @@ __global__ void FastllmSwigluKernel(half* a, half *b, int len, int spatial, int
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
int id = idx / mid * spatial + idx % mid;
#ifdef CUDA_NO_TENSOR_CORE
float x = __half2float(a[id]), y = __half2float(a[id + mid]);
b[idx] = __float2half((x / (1.0 + expf(-x))) * y);
#else
half x = a[id], y = a[id + mid];
b[idx] = __hmul(__hdiv(x, __hadd(__float2half(1.0), hexp(-x))), y);
#endif
}
}

Expand All @@ -156,7 +165,11 @@ __global__ void FastllmMulKernel(float* a, float *b, float v, int len) {
// Element-wise scale: b[i] = a[i] * v for every i in [0, len).
// Launch with one thread per element (grid covers at least `len` threads).
__global__ void FastllmMulKernel(half* a, half *b, half v, int len) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= len) {
        return;
    }
#ifndef CUDA_NO_TENSOR_CORE
    // Native half multiply on GPUs with half-precision arithmetic support.
    b[i] = __hmul(a[i], v);
#else
    // Fallback for builds targeting GPUs without native half arithmetic
    // (CUDA_NO_TENSOR_CORE): convert to float, multiply, convert back.
    b[i] = __float2half(__half2float(a[i]) * __half2float(v));
#endif
}

Expand All @@ -180,7 +193,11 @@ __global__ void FastllmAddToKernel(float* a, float *b, float alpha, int len) {
// Scaled accumulation: a[i] += b[i] * alpha for every i in [0, len).
// Launch with one thread per element (grid covers at least `len` threads).
__global__ void FastllmAddToKernel(half* a, half *b, half alpha, int len) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= len) {
        return;
    }
#ifndef CUDA_NO_TENSOR_CORE
    // Native half multiply-add on GPUs with half-precision arithmetic support.
    a[i] = __hadd(a[i], __hmul(b[i], alpha));
#else
    // Fallback for builds targeting GPUs without native half arithmetic
    // (CUDA_NO_TENSOR_CORE): do the arithmetic in float, convert back.
    a[i] = __float2half(__half2float(a[i]) + __half2float(b[i]) * __half2float(alpha));
#endif
}

Expand Down Expand Up @@ -459,12 +476,38 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
__syncthreads();

// 4. 求和
half sum = 0;
float sum = 0;
#ifdef CUDA_NO_TENSOR_CORE
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
float outF = exp(__half2float(input[i]) - __half2float(maxV));
sum = sum + outF;
output[i] = __float2half(outF);
}
sdata[tid] = __float2half(sum);
__syncthreads();

for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] = __float2half(__half2float(sdata[tid]) + __half2float(sdata[tid + s]));
}
__syncthreads();
}
if (tid == 0) {
if (fabs(__half2float(sdata[0])) < 1e-6) {
sdata[0] = __float2half(0.1);
}
}
__syncthreads();

for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = __float2half(__half2float(output[i]) / __half2float(sdata[0]));
}
#else
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = hexp(__hsub(input[i], maxV));
sum = __hadd(sum, output[i]);
sum = sum + __half2float(output[i]);
}
sdata[tid] = sum;
sdata[tid] = __float2half(sum);
__syncthreads();

for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
Expand All @@ -483,6 +526,7 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = __hdiv(output[i], sdata[0]);
}
#endif
}

template <int THREAD_PER_BLOCK>
Expand Down

0 comments on commit 9ecf89b

Please sign in to comment.