Merge pull request ztxz16#444 from TylunasLi/no_tensor_core
Fix compilation failures on older NVIDIA GPU architectures and add initial performance optimizations for those GPUs
ztxz16 authored Apr 16, 2024
2 parents d3dfc0a + 9ecf89b commit f3cfc63
Showing 3 changed files with 158 additions and 9 deletions.
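
At its core, the change adds a `CUDA_NO_TENSOR_CORE` compile definition that replaces half-precision arithmetic intrinsics, which require compute capability 5.3 or newer, with plain float math, and switches the cuBLAS GEMM paths to FP32 output and accumulation. A minimal sketch of the element-wise pattern applied throughout `fastllm-cuda.cu` below (the kernel name and indexing here are illustrative, not taken from the diff):

```cuda
#include <cuda_fp16.h>

// Illustrative kernel: add a bias vector to a half-precision vector in place.
__global__ void AddBiasHalfSketch(half *x, const half *bias, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
#ifdef CUDA_NO_TENSOR_CORE
        // Legacy path (compute capability < 5.3): no native half arithmetic,
        // so round-trip through float for the add.
        x[i] = __float2half(__half2float(x[i]) + __half2float(bias[i]));
#else
        // Modern path: native half add.
        x[i] = __hadd(x[i], bias[i]);
#endif
    }
}
```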
8 changes: 7 additions & 1 deletion CMakeLists.txt
@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.5)

project(fastllm LANGUAGES CXX)

option(USE_CUDA "use cuda" OFF)
option(USE_CUDA "use CUDA" OFF)
option(CUDA_NO_TENSOR_CORE "Optimize for legacy CUDA GPUs without Tensor Cores." OFF)

option(PY_API "python api" OFF)

@@ -14,6 +15,8 @@ option(USE_IVCOREX "use iluvatar corex gpu" OFF)

message(STATUS "USE_CUDA: ${USE_CUDA}")

message(STATUS "For legacy CUDA GPUs: ${CUDA_NO_TENSOR_CORE}")

message(STATUS "PYTHON_API: ${PY_API}")

message(STATUS "USE_SENTENCEPIECE: ${USE_SENTENCEPIECE}")
@@ -56,6 +59,9 @@ endif()
if (USE_CUDA)
enable_language(CUDA)
add_compile_definitions(USE_CUDA)
if (CUDA_NO_TENSOR_CORE)
add_compile_definitions(CUDA_NO_TENSOR_CORE)
endif()
include_directories(include/devices/cuda)
#message(${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
set(FASTLLM_CUDA_SOURCES src/devices/cuda/cudadevice.cpp src/devices/cuda/cudadevicebatch.cpp src/devices/cuda/fastllm-cuda.cu)
15 changes: 15 additions & 0 deletions docs/faq.md
@@ -43,6 +43,21 @@ cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
if (PY_API)
```

### identifier "__hdiv" is undefined

**Symptom:**

> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "hexp" is undefined
> src/devices/cuda/fastllm-cuda.cu(247): error: identifier "__hdiv" is undefined
> ...

**Cause:** GPUs with a [Compute Capability](https://developer.nvidia.com/cuda-gpus) below 5.3 do not support half-precision arithmetic.

**Solution:** To build for these GPUs, enable the `CUDA_NO_TENSOR_CORE` option when running cmake:

```shell
cmake .. -DUSE_CUDA=ON -DCUDA_NO_TENSOR_CORE=ON
```

## Windows

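To check whether a card needs this fallback, its compute capability can be queried at runtime; a small host-side sketch (device index 0 assumed, not part of the fastllm sources):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaError_t err = cudaGetDeviceProperties(&prop, 0);
    if (err != cudaSuccess) {
        std::printf("cudaGetDeviceProperties failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    std::printf("%s: compute capability %d.%d\n", prop.name, prop.major, prop.minor);
    // Half-precision arithmetic intrinsics require compute capability 5.3 or newer.
    bool needsFallback = (prop.major * 10 + prop.minor) < 53;
    std::printf("CUDA_NO_TENSOR_CORE recommended: %s\n", needsFallback ? "yes" : "no");
    return 0;
}
```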
144 changes: 136 additions & 8 deletions src/devices/cuda/fastllm-cuda.cu
@@ -108,7 +108,11 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
half *now = a + blockIdx.x * k;
int stride = blockDim.x;
for (int i = threadIdx.x; i < k; i += stride) {
#ifdef CUDA_NO_TENSOR_CORE
now[i] = __float2half(__half2float(now[i]) + __half2float(bias[i]));
#else
now[i] = __hadd(now[i], bias[i]);
#endif
}
}

@@ -141,8 +145,13 @@ __global__ void FastllmSwigluKernel(half* a, half *b, int len, int spatial, int
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
int id = idx / mid * spatial + idx % mid;
#ifdef CUDA_NO_TENSOR_CORE
float x = __half2float(a[id]), y = __half2float(a[id + mid]);
b[idx] = __float2half((x / (1.0 + expf(-x))) * y);
#else
half x = a[id], y = a[id + mid];
b[idx] = __hmul(__hdiv(x, __hadd(__float2half(1.0), hexp(-x))), y);
#endif
}
}

@@ -156,7 +165,11 @@ __global__ void FastllmMulKernel(float* a, float *b, float v, int len) {
__global__ void FastllmMulKernel(half* a, half *b, half v, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
#ifdef CUDA_NO_TENSOR_CORE
b[idx] = __float2half(__half2float(a[idx]) * __half2float(v));
#else
b[idx] = __hmul(a[idx], v);
#endif
}
}

@@ -180,7 +193,11 @@ __global__ void FastllmAddToKernel(float* a, float *b, float alpha, int len) {
__global__ void FastllmAddToKernel(half* a, half *b, half alpha, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
#ifdef CUDA_NO_TENSOR_CORE
a[idx] = __float2half(__half2float(a[idx]) + __half2float(b[idx]) * __half2float(alpha));
#else
a[idx] = __hadd(a[idx], __hmul(b[idx], alpha));
#endif
}
}

@@ -418,6 +435,18 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
}
}

__device__ half FastllmHalfMaxFunc(const __half a, const __half b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
return __half2float(a) >= __half2float(b) ? a : b;
#else
#if defined(CUDART_VERSION) && CUDART_VERSION > 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmax(a, b);
#else
return __hge(a, b) ? a : b;
#endif
#endif
}

template <int THREAD_PER_BLOCK>
__device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels) {
__shared__ half sdata[THREAD_PER_BLOCK];
@@ -427,15 +456,15 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
unsigned int tid = threadIdx.x;
half maxValue = input[tid];
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
maxValue = __hmax(maxValue, input[i]);
maxValue = FastllmHalfMaxFunc(maxValue, input[i]);
}
sdata[tid] = maxValue;
__syncthreads();

// 2. block-wide max reduction
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] = __hmax(sdata[tid], sdata[tid + s]);
sdata[tid] = FastllmHalfMaxFunc(sdata[tid], sdata[tid + s]);
}
__syncthreads();
}
@@ -447,12 +476,38 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
__syncthreads();

// 4. sum reduction
half sum = 0;
float sum = 0;
#ifdef CUDA_NO_TENSOR_CORE
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
float outF = exp(__half2float(input[i]) - __half2float(maxV));
sum = sum + outF;
output[i] = __float2half(outF);
}
sdata[tid] = __float2half(sum);
__syncthreads();

for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] = __float2half(__half2float(sdata[tid]) + __half2float(sdata[tid + s]));
}
__syncthreads();
}
if (tid == 0) {
if (fabs(__half2float(sdata[0])) < 1e-6) {
sdata[0] = __float2half(0.1);
}
}
__syncthreads();

for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = __float2half(__half2float(output[i]) / __half2float(sdata[0]));
}
#else
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = hexp(__hsub(input[i], maxV));
sum = __hadd(sum, output[i]);
sum = sum + __half2float(output[i]);
}
sdata[tid] = sum;
sdata[tid] = __float2half(sum);
__syncthreads();

for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
@@ -471,6 +526,7 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = __hdiv(output[i], sdata[0]);
}
#endif
}

template <int THREAD_PER_BLOCK>
@@ -1340,12 +1396,19 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
if (n >= 8) {
auto fastllmCublasHandle = getFastllmCublasHandle();
half *cudaFp16Input, *cudaFp16Output, *cudaFp16Weight;
#ifdef CUDA_NO_TENSOR_CORE
cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));
float h_alpha = 1.0, h_beta = 0.0;
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_32F, ComputeType = CUDA_R_32F;
#else
cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
cudaFp16Output = (half *) FastllmCudaMalloc(n * k * sizeof(half));
cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));

__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
#endif
cublasStatus_t status;

int len = n * m;
@@ -1358,6 +1421,16 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
cudaZeropoints,
cudaFp16Weight, len, m);

#ifdef CUDA_NO_TENSOR_CORE
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
&h_alpha, cudaFp16Weight, AType,
m, cudaFp16Input, BType,
m, &h_beta,
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#else
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
@@ -1366,19 +1439,26 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
m, &h_beta,
cudaFp16Output, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#endif
if (status != CUBLAS_STATUS_SUCCESS) {
printf("Error: cublas error.\n");
throw("cublas error");
exit(0);
}

len = n * k;
#ifdef CUDA_NO_TENSOR_CORE
FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
FastllmCudaFree(cudaFp16Input);
FastllmCudaFree(cudaFp16Weight);
#else
FastllmCudaHalf2FlotaKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, cudaOutput, len);
FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);

FastllmCudaFree(cudaFp16Input);
FastllmCudaFree(cudaFp16Output);
FastllmCudaFree(cudaFp16Weight);
#endif
} else {
for (int i = 0; i < n; i++) {
FastllmGemvInt8Kernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
@@ -1582,12 +1662,20 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
if (n >= 8) {
auto fastllmCublasHandle = getFastllmCublasHandle();
half *cudaFp16Input, *cudaFp16Output, *cudaFp16Weight;
#ifdef CUDA_NO_TENSOR_CORE
cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));

float h_alpha = 1.0, h_beta = 0.0;
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_32F, ComputeType = CUDA_R_32F;
#else
cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
cudaFp16Output = (half *) FastllmCudaMalloc(n * k * sizeof(half));
cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));

__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
#endif
cublasStatus_t status;

int len = n * m;
@@ -1601,6 +1689,16 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
cudaMins,
cudaFp16Weight, len, m);

#ifdef CUDA_NO_TENSOR_CORE
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
&h_alpha, cudaFp16Weight, AType,
m, cudaFp16Input, BType,
m, &h_beta,
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#else
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
@@ -1609,20 +1707,27 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
m, &h_beta,
cudaFp16Output, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#endif
if (status != CUBLAS_STATUS_SUCCESS) {
printf("Error: cublas error.\n");
throw("cublas error");
exit(0);
}

len = n * k;
#ifdef CUDA_NO_TENSOR_CORE
FastllmCudaBiasKernel <<< n, 256 >>>(cudaOutput, cudaBiasData, k);
FastllmCudaFree(cudaFp16Input);
FastllmCudaFree(cudaFp16Weight);
#else
FastllmCudaHalf2FlotaKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, cudaOutput,
len);
FastllmCudaBiasKernel <<< n, 256 >>>(cudaOutput, cudaBiasData, k);

FastllmCudaFree(cudaFp16Input);
FastllmCudaFree(cudaFp16Output);
FastllmCudaFree(cudaFp16Weight);
#endif
} else {
for (int i = 0; i < n; i++) {
FastllmGemvInt4NoZeroKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
@@ -1708,21 +1813,38 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
float *cudaOutput = (float*)FastllmCudaPrepareOutput(output);

if (n > 1) {
auto fastllmCublasHandle = getFastllmCublasHandle();
//cudaDeviceSynchronize();
half *cudaFp16Input, *cudaFp16Output;
#ifdef CUDA_NO_TENSOR_CORE
cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));

float h_alpha = 1.0, h_beta = 0.0;
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_32F, ComputeType = CUDA_R_32F;
#else
cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half));
cudaFp16Output = (half *) FastllmCudaMalloc(n * k * sizeof(half));

__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
auto fastllmCublasHandle = getFastllmCublasHandle();
//cudaDeviceSynchronize();
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
#endif
cublasStatus_t status;

int len = n * m;
int threadPerBlock = std::min(256, len);
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input,
len);

#ifdef CUDA_NO_TENSOR_CORE
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
&h_alpha, (half *) weight.cudaData, AType,
m, cudaFp16Input, BType,
m, &h_beta,
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#else
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
@@ -1731,20 +1853,26 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
m, &h_beta,
cudaFp16Output, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#endif
if (status != CUBLAS_STATUS_SUCCESS) {
printf("Error: cublas error.\n");
throw("cublas error");
exit(0);
}

len = n * k;
#ifdef CUDA_NO_TENSOR_CORE
FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, (float*)weight.extraCudaData[0], k);
FastllmCudaFree(cudaFp16Input);
#else
FastllmCudaHalf2FlotaKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, cudaOutput,
len);
FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, (float*)weight.extraCudaData[0], k);
//cudaDeviceSynchronize();

FastllmCudaFree(cudaFp16Input);
FastllmCudaFree(cudaFp16Output);
#endif
} else {
FastllmGemvFp32Fp16Kernel2<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
}
@@ -3031,7 +3159,7 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei
FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock>>>(tempBiasData, cudaBiasData, k);
state = cudaFree(tempBiasData);
} else {
state = cudaMemset(cudaBiasData, __float2half_rn(0.0), k * sizeof(half));
state = cudaMemset(cudaBiasData, 0, k * sizeof(half));
}
checkCudaErrors("Error: CUDA error when moving bias to device!", state);
weight.extraCudaHalfData.push_back((void *) cudaBiasData);
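Taken together, the `CUDA_NO_TENSOR_CORE` GEMM branches above share one call shape: FP16 inputs, but FP32 output and FP32 accumulation, which writes the result straight into `cudaOutput` and removes the FP16 output buffer and the separate half-to-float conversion kernel. A sketch that mirrors that call shape (the wrapper function, parameter names, and shape comments are illustrative; handle creation and error handling are omitted):

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>

// FP16 weight (k x m, row-major) times FP16 input (n x m, row-major),
// written directly as an FP32 result (n x k, row-major).
cublasStatus_t GemmFp16InFp32OutSketch(cublasHandle_t handle,
                                       const half *fp16Weight,
                                       const half *fp16Input,
                                       float *fp32Output,
                                       int n, int m, int k) {
    float alpha = 1.0f, beta = 0.0f;
    return cublasGemmEx(handle,
                        CUBLAS_OP_T, CUBLAS_OP_N,
                        k, n, m,
                        &alpha,
                        fp16Weight, CUDA_R_16F, m,
                        fp16Input,  CUDA_R_16F, m,
                        &beta,
                        fp32Output, CUDA_R_32F, k,  // FP32 result, no conversion kernel needed
                        CUDA_R_32F,                 // accumulate in FP32, as in the diff
                        CUBLAS_GEMM_DEFAULT);
}
```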
