diff --git a/features/feature_case/cublasLt/matmul.cu b/features/feature_case/cublasLt/matmul.cu index 0cf382f7a..f0677c353 100644 --- a/features/feature_case/cublasLt/matmul.cu +++ b/features/feature_case/cublasLt/matmul.cu @@ -722,6 +722,150 @@ bool test6() { return !error; } +void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, + const float *A, const float *B, const float *C, float *D, + float *alpha, float *beta, + int lda, int ldb, int ldc, int ldd, + cublasLtMatrixLayout_t Adesc, + cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatrixLayout_t Ddesc, + float *amax_d) { + cublasLtMatmulDesc_t matmulDesc = NULL; + cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + + float *scale_a; + float *scale_b; + float *scale_d; + cudaMallocManaged(&scale_a, sizeof(float)); + cudaMallocManaged(&scale_b, sizeof(float)); + cudaMallocManaged(&scale_d, sizeof(float)); + scale_a[0] = 3; + scale_b[0] = 5; + scale_d[0] = 7; + + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(scale_a)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(scale_b)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scale_d, sizeof(scale_d)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &amax_d, sizeof(amax_d)); + + cublasLtEpilogue_t ep = CUBLASLT_EPILOGUE_RELU; + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &ep, sizeof(ep)); + + cublasLtMatmul(ltHandle, matmulDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, NULL, NULL, 0, 0); + + cudaStreamSynchronize(0); + cublasLtMatmulDescDestroy(matmulDesc); +} + +// clang-format off +// A (4*3) B (3*2) +// 6 10 14 5 4 +// 7 11 15 -3 -2 +// 8 12 16 1 0 +// 9 13 17 p p +// +// alpha * A * B + C = alpha * A*B + C = D +// 2*3*5 6 10 14 5 4 -10000 -5000 30 14 4 -10000 -5000 -9580 -4880 +// 7 11 15 -3 -2 2000 6000 17 6 2000 6000 2510 6180 +// 8 12 16 1 0 3000 7000 20 8 3000 7000 3600 7240 +// 9 13 17 p p 4000 8000 23 10 4000 8000 4690 8300 +// scale_d * D = D +// 7 * -9580 -4880 -67060 -34160 +// 2510 6180 17570 43260 +// 3600 7240 25200 50680 +// 4690 8300 32830 58100 +// clang-format on + +bool test7() { + cublasLtHandle_t ltHandle; + cublasLtCreate(<Handle); + const constexpr int m = 4; + const constexpr int n = 2; + const constexpr int k = 3; + const constexpr int lda = m; + const constexpr int ldb = m; + const constexpr int ldc = m; + const constexpr int ldd = m; + void *Adev; + void *Bdev; + void *Cdev; + void *Ddev; + cudaMalloc(&Adev, lda * k * sizeof(float)); + cudaMalloc(&Bdev, ldb * n * sizeof(float)); + cudaMalloc(&Cdev, ldc * n * sizeof(float)); + cudaMalloc(&Ddev, ldd * n * sizeof(float)); + + float Ahost[lda * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}; + float Bhost[ldb * n] = {5, -3, 1, 99, 4, -2, 0, 99}; + float Chost[ldc * n] = {-1000, 2000, 3000, 4000, -5000, 6000, 7000, 8000}; + + cudaMemcpy(Adev, Ahost, lda * k * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(Bdev, Bhost, ldb * n * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(Cdev, Chost, ldc * n * sizeof(float), cudaMemcpyHostToDevice); + + cublasLtMatrixLayout_t Adesc_col_major = NULL, + Bdesc_col_major = NULL, + Cdesc_col_major = NULL, + Ddesc_col_major = NULL; + cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_32F, m, k, lda); + cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_32F, k, n, ldb); + cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_32F, m, n, ldc); + cublasLtMatrixLayoutCreate(&Ddesc_col_major, CUDA_R_32F, m, n, ldd); + + float alpha = 2; + float beta = 1; + + // Matmul + + float *amax_d; + cudaMallocManaged(&amax_d, sizeof(float)); + + fgemmlt(ltHandle, m, n, k, (const float *)Adev, (const float *)Bdev, (const float *)Cdev, (float *)Ddev, + &alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major, amax_d); + cudaStreamSynchronize(0); + + // Check result + float Dhost[ldd * n]; + cudaMemcpy(Dhost, Ddev, ldd * n * sizeof(float), cudaMemcpyDeviceToHost); + + bool error = false; + float D_ref[ldd * n] = {0, 17570, 25200, 32830, 0, 43260, 50680, 58100}; + for (int i = 0; i < ldd * n; i++) { + if (Dhost[i] != D_ref[i]) { + error = true; + break; + } + } + if (*amax_d != 8300) + error = true; + + printf("d:\n"); + for (int i = 0; i < ldd * n; i++) + printf("%f, ", Dhost[i]); + printf("\n"); + printf("amax_d:%f\n", *amax_d); + + if (error) { + printf("error\n"); + } else { + printf("success\n"); + } + + cublasLtDestroy(ltHandle); + cublasLtMatrixLayoutDestroy(Adesc_col_major); + cublasLtMatrixLayoutDestroy(Bdesc_col_major); + cublasLtMatrixLayoutDestroy(Cdesc_col_major); + cublasLtMatrixLayoutDestroy(Ddesc_col_major); + cudaFree(Adev); + cudaFree(Bdev); + cudaFree(Ddev); + cudaFree(amax_d); + + return !error; +} + + // clang-format off // A (4*3) B (2*3) // 6 10 14 5 -3 1 @@ -750,5 +894,6 @@ int main() { pass = test4() && pass; pass = test5() && pass; pass = test6() && pass; + pass = test7() && pass; return pass ? 0 : 1; }