diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 8ba6e019..6b694230 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -1817,32 +1817,58 @@ __global__ void FastllmHalfMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
     int input1Stride = (int)((size_t)pointer[id * 8 + 7]);
     int tid = threadIdx.x;
-/*
-    const int pera = 8, perb = 8;
-    __shared__ float sa[pera][128], sb[perb][128], sc[pera][perb];
-    for (int sta = 0; sta < n; sta += pera) {
-        for (int stb = 0; stb < k; stb += perb) {
-            for (int i = 0; i < pera; i++) {
-                if (sta + i < n) {
-                    sa[i][tid] = (float)input0[(sta + i) * input0Stride + tid];
-                } else {
-                    sa[i][tid] = 0;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+    if (m == 128) {
+        int wid = tid >> 5;
+        int perN = 8, perK = 128;
+
+        const int BN = 8, BK = 128;
+        __shared__ float curC[BN][BK];
+        half hscale = (half)alpha;
+
+        for (int stN = 0; stN < n; stN += perN) {
+            int endN = min(n, stN + perN);
+            for (int stK = 0; stK < k; stK += perK) {
+                int endK = min(k, stK + perK);
+                wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> frag_a[8];
+                wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> frag_b[8];
+                wmma::fragment<wmma::accumulator, 8, 32, 16, float> frag_c;
+
+                wmma::fill_fragment(frag_c, 0.0);
+                __syncthreads();
+
+                #pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_a[j], &input0[(stN) * input0Stride + j * 16], input0Stride);
                 }
-            }
-            for (int i = 0; i < perb; i++) {
-                if (stb + i < k) {
-                    sb[i][tid] = (float)input1[(stb + i) * input1Stride + tid];
-                } else {
-                    sb[i][tid] = 0;
+                __syncthreads();
+
+                #pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_b[j], &input1[(stK + wid * 32) * input1Stride + j * 16], input1Stride);
+                }
+                __syncthreads();
+
+                #pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::mma_sync(frag_c, frag_a[j], frag_b[j], frag_c);
+                }
+                __syncthreads();
+
+                wmma::store_matrix_sync(&curC[0][wid * 32], frag_c, BK, wmma::mem_row_major);
+                __syncthreads();
+
+                if (stK + tid < endK) {
+                    for (int i = 0; stN + i < endN; i++) {
+                        output[(stN + i) * k + stK + tid] = (half)(curC[i][tid] * alpha);
+                    }
                 }
+                __syncthreads();
             }
-            __syncthreads();
-
-            __syncthreads();
         }
+        return;
     }
-*/
-
+#endif
     int pera = 4, perb = 4;
     half cura[4][4], curb[4][4];
     float curc[4][4];
@@ -1999,8 +2025,69 @@ __global__ void FastllmHalfMatMulKernel(uint8_t** pointer, float alpha) {
     int k = (int)((size_t)pointer[id * 8 + 5]);
     int input0Stride = (int)((size_t)pointer[id * 8 + 6]);
     int input1Stride = (int)((size_t)pointer[id * 8 + 7]);
-    int tid = threadIdx.x;
+    int tid = threadIdx.x;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+    if (k == 128) {
+        int wid = tid >> 5;
+        int perN = 8, perM = 128;
+        for (int i = 0; i < n; i++) {
+            output[i * k + tid] = (half)0;
+        }
+
+        __shared__ half curA[8][128];
+        __shared__ float curC[8][128];
+
+        for (int stN = 0; stN < n; stN += perN) {
+            int endN = min(stN + perN, n);
+            wmma::fragment<wmma::accumulator, 8, 32, 16, float> frag_c;
+            wmma::fill_fragment(frag_c, 0.0);
+
+            for (int stM = 0; stM < m; stM += perM) {
+                int endM = min(stM + perM, m);
+                if (stM + tid < m) {
+                    for (int i = 0; stN + i < endN; i++) {
+                        curA[i][tid] = input0[(stN + i) * input0Stride + stM + tid];
+                    }
+                } else {
+                    for (int i = 0; stN + i < endN; i++) {
+                        curA[i][tid] = (half)0.0;
+                    }
+                }
+
+                wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> frag_a[8];
+                wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::row_major> frag_b[8];
+                __syncthreads();
+
+                #pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_a[j], &curA[0][16 * j], 128);
+                }
+                __syncthreads();
+
+                #pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_b[j], &input1[(stM + 16 * j) * input1Stride + wid * 32], input1Stride);
+                }
+                __syncthreads();
+
+                #pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::mma_sync(frag_c, frag_a[j], frag_b[j], frag_c);
+                }
+                __syncthreads();
+            }
+            wmma::store_matrix_sync(&curC[0][wid * 32], frag_c, 128, wmma::mem_row_major);
+            __syncthreads();
+
+            for (int i = 0; stN + i < endN; i++) {
+                output[(stN + i) * k + tid] = (half)((float)output[(stN + i) * k + tid] + (float)curC[i][tid] * alpha);
+            }
+            __syncthreads();
+        }
+        return;
+    }
+#endif
     int pera = 4, perb = 4;
     float cura[4][4], curb[4][4], curc[4][4];
     int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
@@ -2057,7 +2144,6 @@ __global__ void FastllmHalfMatMulKernel(uint8_t** pointer, float alpha) {
         }
     }
 /*
-    int tid = threadIdx.x;
     for (int i = 0; i < n; i++) {
         half *curInput0 = input0 + i * input0Stride;
         for (int j = tid; j < k; j += THREAD_PER_BLOCK) {