From da8cd10dc2e25ff388a26d59da3f4eba8e93ea07 Mon Sep 17 00:00:00 2001
From: "Jiang, Zhiwei"
Date: Thu, 27 Jun 2024 16:42:59 +0800
Subject: [PATCH] [SYCLomatic #1993] Add test for cublasLt API migration (#729)

Signed-off-by: Jiang, Zhiwei
---
 features/config/TEMPLATE_cublasLt.xml       |  13 +
 features/feature_case/cublasLt/matmul.cu    | 763 ++++++++++++++++++
 features/feature_case/cublasLt/transform.cu | 602 ++++++++++++++
 features/features.xml                       |   2 +
 features/test_feature.py                    |   4 +-
 help_function/help_function.xml             |   1 +
 .../src/blas_gemm_utils_interface.cpp       | 152 ++++
 help_function/test_help.py                  |   2 +-
 8 files changed, 1536 insertions(+), 3 deletions(-)
 create mode 100644 features/config/TEMPLATE_cublasLt.xml
 create mode 100644 features/feature_case/cublasLt/matmul.cu
 create mode 100644 features/feature_case/cublasLt/transform.cu
 create mode 100644 help_function/src/blas_gemm_utils_interface.cpp

diff --git a/features/config/TEMPLATE_cublasLt.xml b/features/config/TEMPLATE_cublasLt.xml
new file mode 100644
index 000000000..cb15eeed1
--- /dev/null
+++ b/features/config/TEMPLATE_cublasLt.xml
@@ -0,0 +1,13 @@
+
+
+  test
+
+
+
+
+
+
+
+
+
+
diff --git a/features/feature_case/cublasLt/matmul.cu b/features/feature_case/cublasLt/matmul.cu
new file mode 100644
index 000000000..0cf382f7a
--- /dev/null
+++ b/features/feature_case/cublasLt/matmul.cu
@@ -0,0 +1,763 @@
+// ===------------ matmul.cu ----------------------------- *- CUDA -* ----=== //
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+
+#include <cublasLt.h>
+#include <cstdint>
+#include <cstdio>
+
+const constexpr int COL_TURING = 0;
+const constexpr int COL_AMPERE = 1;
+
+// The original source of below two functions was under the license below:
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+// +// Repo: https://github.com/TimDettmers/bitsandbytes.git +inline int checkCublasStatus(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + printf("cuBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +{ + int has_error = 0; + cublasLtMatmulDesc_t matmulDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasOperation_t opT = CUBLAS_OP_T; + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; + + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb)); + + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(FORMATB == COL_TURING) + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + else + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + if(DTYPE_OUT == 32) + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + } + + cudaStreamSynchronize(0); + + if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) has_error |= 
+    if(has_error == 1)
+      printf("error detected");
+
+    return has_error;
+}
+
+void transform(cublasLtHandle_t ltHandle, const void *in, int ld_in,
+               cublasLtMatrixLayout_t layout_in, void *out, int ld_out,
+               cublasLtMatrixLayout_t layout_out) {
+  cublasLtMatrixTransformDesc_t transform_desc = NULL;
+  cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F);
+  float alpha = 1.0f, beta = 0.0f;
+  cublasLtMatrixTransform(ltHandle, transform_desc, &alpha, in, layout_in,
+                          &beta, NULL, NULL, out, layout_out, 0);
+  cublasLtMatrixTransformDescDestroy(transform_desc);
+}
+
+// igemmlt
+bool test1() {
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  const constexpr int m = 4;
+  const constexpr int n = 2;
+  const constexpr int k = 3;
+  int lda = m;
+  int ldb = n;
+  int ldc = m;
+  void *Adev;
+  void *Bdev;
+  void *Cdev;
+  cudaMalloc(&Adev, m * k * sizeof(int8_t));
+  cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+  cudaMalloc(&Cdev, m * n * sizeof(int32_t));
+
+  int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+  cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+  cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+                         Cdesc_col_major = NULL;
+  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_32I, m, n, ldc);
+
+  // Convert A and B
+  cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col4_4r2_8c = NULL,
+                         Cdesc_col32 = NULL;
+  int8_t *A_col32, *B_col4_4r2_8c;
+  int32_t *C_col32;
+  cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+  cudaMalloc(&B_col4_4r2_8c, ((n + 8 - 1) / 8) * 8 * 32 * sizeof(std::int8_t));
+  cudaMalloc(&C_col32, m * 32 * sizeof(std::int32_t));
+  cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+  cublasLtMatrixLayoutCreate(&Bdesc_col4_4r2_8c, CUDA_R_8I, k, n,
+                             ((n + 8 - 1) / 8) * 8 * 32);
+  cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_32I, m, n, m * 32);
+  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+  cublasLtOrder_t col4_4r2_8c = CUBLASLT_ORDER_COL4_4R2_8C;
+  cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+  cublasLtMatrixLayoutSetAttribute(Bdesc_col4_4r2_8c,
+                                   CUBLASLT_MATRIX_LAYOUT_ORDER, &col4_4r2_8c,
+                                   sizeof(col4_4r2_8c));
+  cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+
+  transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+  transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col4_4r2_8c, 8 * 32,
+            Bdesc_col4_4r2_8c);
+
+  // Matmul
+  igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A_col32, B_col4_4r2_8c,
+                             C_col32, nullptr, m * 32,
+                             ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+  // Convert C
+  transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+  cudaStreamSynchronize(0);
+
+  // Check result
+  int32_t Chost[m * n];
+  cudaMemcpy(Chost, Cdev, m * n * sizeof(int32_t), cudaMemcpyDeviceToHost);
+
+  bool error = false;
+  int32_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+  for (int i = 0; i < m * n; i++) {
+    if (Chost[i] != C_ref[i]) {
+      error = true;
+      break;
+    }
+  }
+  printf("c:\n");
+  for (int i = 0; i < m * n; i++)
+    printf("%d, ", Chost[i]);
+  printf("\n");
+
+  if (error) {
+    printf("error\n");
+  } else {
+    printf("success\n");
+  }
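+  // The reference values can be checked by hand: C = A * op(B) with
+  // column-major A (4x3) and B (2x3), e.g. C(0,0) = 6*5 + 10*(-3) + 14*1 = 14
+  // and C(1,1) = 7*4 + 11*(-2) + 15*0 = 6, matching C_ref above.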
+
+  cublasLtDestroy(ltHandle);
+  cublasLtMatrixLayoutDestroy(Adesc_col32);
+  cublasLtMatrixLayoutDestroy(Bdesc_col4_4r2_8c);
+  cublasLtMatrixLayoutDestroy(Cdesc_col32);
+  cublasLtMatrixLayoutDestroy(Adesc_col_major);
+  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+  cudaFree(Adev);
+  cudaFree(Bdev);
+  cudaFree(Cdev);
+
+  return !error;
+}
+
+// igemmlt
+bool test2() {
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  const constexpr int m = 4;
+  const constexpr int n = 2;
+  const constexpr int k = 3;
+  int lda = m;
+  int ldb = n;
+  int ldc = m;
+  void *Adev;
+  void *Bdev;
+  void *Cdev;
+  cudaMalloc(&Adev, m * k * sizeof(int8_t));
+  cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+  cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+  int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+  cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+  cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+                         Cdesc_col_major = NULL;
+  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+  // Convert A and B
+  cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col4_4r2_8c = NULL,
+                         Cdesc_col32 = NULL;
+  int8_t *A_col32, *B_col4_4r2_8c;
+  int8_t *C_col32;
+  cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+  cudaMalloc(&B_col4_4r2_8c, ((n + 8 - 1) / 8) * 8 * 32 * sizeof(std::int8_t));
+  cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+  cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+  cublasLtMatrixLayoutCreate(&Bdesc_col4_4r2_8c, CUDA_R_8I, k, n,
+                             ((n + 8 - 1) / 8) * 8 * 32);
+  cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+  cublasLtOrder_t col4_4r2_8c = CUBLASLT_ORDER_COL4_4R2_8C;
+  cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+  cublasLtMatrixLayoutSetAttribute(Bdesc_col4_4r2_8c,
+                                   CUBLASLT_MATRIX_LAYOUT_ORDER, &col4_4r2_8c,
+                                   sizeof(col4_4r2_8c));
+  cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+
+  transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+  transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col4_4r2_8c, 8 * 32,
+            Bdesc_col4_4r2_8c);
+
+  // Matmul
+  igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A_col32, B_col4_4r2_8c,
+                            C_col32, nullptr, m * 32,
+                            ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+  // Convert C
+  transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+  cudaStreamSynchronize(0);
+
+  // Check result
+  int8_t Chost[m * n];
+  cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+  bool error = false;
+  int8_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+  for (int i = 0; i < m * n; i++) {
+    if (Chost[i] != C_ref[i]) {
+      error = true;
+      break;
+    }
+  }
+  printf("c:\n");
+  for (int i = 0; i < m * n; i++)
+    printf("%d, ", Chost[i]);
+  printf("\n");
+
+  if (error) {
+    printf("error\n");
+  } else {
+    printf("success\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+  cublasLtMatrixLayoutDestroy(Adesc_col32);
+  cublasLtMatrixLayoutDestroy(Bdesc_col4_4r2_8c);
+  cublasLtMatrixLayoutDestroy(Cdesc_col32);
+  cublasLtMatrixLayoutDestroy(Adesc_col_major);
+  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+  cudaFree(Adev);
+  cudaFree(Bdev);
+  cudaFree(Cdev);
+
+  return !error;
+}
+
+// igemmlt
+bool test3() {
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  const constexpr int m = 4;
+  const constexpr int n = 2;
+  const constexpr int k = 3;
+  int lda = m;
+  int ldb = n;
+  int ldc = m;
+  void *Adev;
+  void *Bdev;
+  void *Cdev;
+  cudaMalloc(&Adev, m * k * sizeof(int8_t));
+  cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+  cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+  int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+  cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+  cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+                         Cdesc_col_major = NULL;
+  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+  // Convert A and B
+  cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col4_4r2_8c = NULL,
+                         Cdesc_col32 = NULL;
+  int8_t *A_col32, *B_col4_4r2_8c;
+  int8_t *C_col32;
+  cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+  cudaMalloc(&B_col4_4r2_8c, ((n + 8 - 1) / 8) * 8 * 32 * sizeof(std::int8_t));
+  cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+  cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+  cublasLtMatrixLayoutCreate(&Bdesc_col4_4r2_8c, CUDA_R_8I, k, n,
+                             ((n + 8 - 1) / 8) * 8 * 32);
+  cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+  cublasLtOrder_t col4_4r2_8c = CUBLASLT_ORDER_COL4_4R2_8C;
+  cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+  cublasLtMatrixLayoutSetAttribute(Bdesc_col4_4r2_8c,
+                                   CUBLASLT_MATRIX_LAYOUT_ORDER, &col4_4r2_8c,
+                                   sizeof(col4_4r2_8c));
+  cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+
+  transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+  transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col4_4r2_8c, 8 * 32,
+            Bdesc_col4_4r2_8c);
+
+  float *alpha;
+  cudaMallocManaged(&alpha, 4 * sizeof(float));
+  alpha[0] = 0;
+  alpha[1] = 1;
+  alpha[2] = 2;
+  alpha[3] = 3;
+
+  // Matmul
+  igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A_col32, B_col4_4r2_8c,
+      C_col32, alpha, m * 32, ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+  // Convert C
+  transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+  cudaStreamSynchronize(0);
+
+  // Check result
+  int8_t Chost[m * n];
+  cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+  bool error = false;
+  int8_t C_ref[m * n] = {0, 17, 40, 69, 0, 6, 16, 30};
+  for (int i = 0; i < m * n; i++) {
+    if (Chost[i] != C_ref[i]) {
+      error = true;
+      break;
+    }
+  }
+  printf("c:\n");
+  for (int i = 0; i < m * n; i++)
+    printf("%d, ", Chost[i]);
+  printf("\n");
+
+  if (error) {
+    printf("error\n");
+  } else {
+    printf("success\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+  cublasLtMatrixLayoutDestroy(Adesc_col32);
+  cublasLtMatrixLayoutDestroy(Bdesc_col4_4r2_8c);
+  cublasLtMatrixLayoutDestroy(Cdesc_col32);
+  cublasLtMatrixLayoutDestroy(Adesc_col_major);
+  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+  cudaFree(Adev);
+  cudaFree(Bdev);
+  cudaFree(Cdev);
+  cudaFree(alpha);
+
+  return !error;
+}
+
+// igemmlt
+bool test4() {
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  const constexpr int m = 4;
+  const constexpr int n = 2;
+  const constexpr int k = 3;
+  int lda = m;
+  int ldb = n;
+  int ldc = m;
+  void *Adev;
+  void *Bdev;
+  void *Cdev;
+  cudaMalloc(&Adev, m * k * sizeof(int8_t));
+  cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+  cudaMalloc(&Cdev, m * n * sizeof(int32_t));
+
+  int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+  cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+  cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+                         Cdesc_col_major = NULL;
+  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_32I, m, n, ldc);
+
+  // Convert A and B
+  cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col32_2r_4r4 = NULL,
+                         Cdesc_col32 = NULL;
+  int8_t *A_col32, *B_col32_2r_4r4;
+  int32_t *C_col32;
+  cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+  cudaMalloc(&B_col32_2r_4r4,
+             ((n + 32 - 1) / 32) * 32 * 32 * sizeof(std::int8_t));
+  cudaMalloc(&C_col32, m * 32 * sizeof(std::int32_t));
+  cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+  cublasLtMatrixLayoutCreate(&Bdesc_col32_2r_4r4, CUDA_R_8I, k, n,
+                             ((n + 32 - 1) / 32) * 32 * 32);
+  cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_32I, m, n, m * 32);
+  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+  cublasLtOrder_t col32_2r_4r4 = CUBLASLT_ORDER_COL32_2R_4R4;
+  cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+  cublasLtMatrixLayoutSetAttribute(Bdesc_col32_2r_4r4,
+                                   CUBLASLT_MATRIX_LAYOUT_ORDER, &col32_2r_4r4,
+                                   sizeof(col32_2r_4r4));
+  cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+
+  transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+  transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col32_2r_4r4, 8 * 32,
+            Bdesc_col32_2r_4r4);
+
+  // Matmul
+  igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A_col32, B_col32_2r_4r4,
+                             C_col32, nullptr, m * 32,
+                             ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+  // Convert C
+  transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+  cudaStreamSynchronize(0);
+
+  // Check result
+  int32_t Chost[m * n];
+  cudaMemcpy(Chost, Cdev, m * n * sizeof(int32_t), cudaMemcpyDeviceToHost);
+
+  bool error = false;
+  int32_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+  for (int i = 0; i < m * n; i++) {
+    if (Chost[i] != C_ref[i]) {
+      error = true;
+      break;
+    }
+  }
+  printf("c:\n");
+  for (int i = 0; i < m * n; i++)
+    printf("%d, ", Chost[i]);
+  printf("\n");
+
+  if (error) {
+    printf("error\n");
+  } else {
+    printf("success\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+  cublasLtMatrixLayoutDestroy(Adesc_col32);
+  cublasLtMatrixLayoutDestroy(Bdesc_col32_2r_4r4);
+  cublasLtMatrixLayoutDestroy(Cdesc_col32);
+  cublasLtMatrixLayoutDestroy(Adesc_col_major);
+  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+  cudaFree(Adev);
+  cudaFree(Bdev);
+  cudaFree(Cdev);
+
+  return !error;
+}
+
+// igemmlt
+bool test5() {
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  const constexpr int m = 4;
+  const constexpr int n = 2;
+  const constexpr int k = 3;
+  int lda = m;
+  int ldb = n;
+  int ldc = m;
+  void *Adev;
+  void *Bdev;
+  void *Cdev;
+  cudaMalloc(&Adev, m * k * sizeof(int8_t));
+  cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+  cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+  int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+  cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+  cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+                         Cdesc_col_major = NULL;
+  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+  // Convert A and B
+  cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col32_2r_4r4 = NULL,
+                         Cdesc_col32 = NULL;
+  int8_t *A_col32, *B_col32_2r_4r4;
+  int8_t *C_col32;
+  cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+  cudaMalloc(&B_col32_2r_4r4,
+             ((n + 32 - 1) / 32) * 32 * 32 * sizeof(std::int8_t));
+  cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+  cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+  cublasLtMatrixLayoutCreate(&Bdesc_col32_2r_4r4, CUDA_R_8I, k, n,
+                             ((n + 32 - 1) / 32) * 32 * 32);
+  cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+  cublasLtOrder_t col32_2r_4r4 = CUBLASLT_ORDER_COL32_2R_4R4;
+  cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+  cublasLtMatrixLayoutSetAttribute(Bdesc_col32_2r_4r4,
+                                   CUBLASLT_MATRIX_LAYOUT_ORDER, &col32_2r_4r4,
+                                   sizeof(col32_2r_4r4));
+  cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+
+  transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+  transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col32_2r_4r4, 8 * 32,
+            Bdesc_col32_2r_4r4);
+
+  // Matmul
+  igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A_col32, B_col32_2r_4r4,
+                            C_col32, nullptr, m * 32,
+                            ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+  // Convert C
+  transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+  cudaStreamSynchronize(0);
+
+  // Check result
+  int8_t Chost[m * n];
+  cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+  bool error = false;
+  int8_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+  for (int i = 0; i < m * n; i++) {
+    if (Chost[i] != C_ref[i]) {
+      error = true;
+      break;
+    }
+  }
+  printf("c:\n");
+  for (int i = 0; i < m * n; i++)
+    printf("%d, ", Chost[i]);
+  printf("\n");
+
+  if (error) {
+    printf("error\n");
+  } else {
+    printf("success\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+  cublasLtMatrixLayoutDestroy(Adesc_col32);
+  cublasLtMatrixLayoutDestroy(Bdesc_col32_2r_4r4);
+  cublasLtMatrixLayoutDestroy(Cdesc_col32);
+  cublasLtMatrixLayoutDestroy(Adesc_col_major);
+  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+  cudaFree(Adev);
+  cudaFree(Bdev);
+  cudaFree(Cdev);
+
+  return !error;
+}
+
+// igemmlt
+bool test6() {
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  const constexpr int m = 4;
+  const constexpr int n = 2;
+  const constexpr int k = 3;
+  int lda = m;
+  int ldb = n;
+  int ldc = m;
+  void *Adev;
+  void *Bdev;
+  void *Cdev;
+  cudaMalloc(&Adev, m * k * sizeof(int8_t));
+  cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+  cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+  int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+  cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+  cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+                         Cdesc_col_major = NULL;
+  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+  // Convert A and B
+  cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col32_2r_4r4 = NULL,
+                         Cdesc_col32 = NULL;
+  int8_t *A_col32, *B_col32_2r_4r4;
+  int8_t *C_col32;
+  cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+  cudaMalloc(&B_col32_2r_4r4,
+             ((n + 32 - 1) / 32) * 32 * 32 * sizeof(std::int8_t));
+  cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+  cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+  cublasLtMatrixLayoutCreate(&Bdesc_col32_2r_4r4, CUDA_R_8I, k, n,
+                             ((n + 32 - 1) / 32) * 32 * 32);
+  cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+  cublasLtOrder_t col32_2r_4r4 = CUBLASLT_ORDER_COL32_2R_4R4;
+  cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+  cublasLtMatrixLayoutSetAttribute(Bdesc_col32_2r_4r4,
+                                   CUBLASLT_MATRIX_LAYOUT_ORDER, &col32_2r_4r4,
+                                   sizeof(col32_2r_4r4));
+  cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &col32, sizeof(col32));
+
+  transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+  transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col32_2r_4r4, 8 * 32,
+            Bdesc_col32_2r_4r4);
+
+  float *alpha;
+  cudaMallocManaged(&alpha, 4 * sizeof(float));
+  alpha[0] = 0;
+  alpha[1] = 1;
+  alpha[2] = 2;
+  alpha[3] = 3;
+
+  // Matmul
+  igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A_col32, B_col32_2r_4r4,
+      C_col32, alpha, m * 32, ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+  // Convert C
+  transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+  cudaStreamSynchronize(0);
+
+  // Check result
+  int8_t Chost[m * n];
+  cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+  bool error = false;
+  int8_t C_ref[m * n] = {0, 17, 40, 69, 0, 6, 16, 30};
+  for (int i = 0; i < m * n; i++) {
+    if (Chost[i] != C_ref[i]) {
+      error = true;
+      break;
+    }
+  }
+  printf("c:\n");
+  for (int i = 0; i < m * n; i++)
+    printf("%d, ", Chost[i]);
+  printf("\n");
+
+  if (error) {
+    printf("error\n");
+  } else {
+    printf("success\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+  cublasLtMatrixLayoutDestroy(Adesc_col32);
+  cublasLtMatrixLayoutDestroy(Bdesc_col32_2r_4r4);
+  cublasLtMatrixLayoutDestroy(Cdesc_col32);
+  cublasLtMatrixLayoutDestroy(Adesc_col_major);
+  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+  cudaFree(Adev);
+  cudaFree(Bdev);
+  cudaFree(Cdev);
+  cudaFree(alpha);
+
+  return !error;
+}
+
+// clang-format off
+// A (4*3)          B (2*3)
+// 6 10 14          5 -3 1
+// 7 11 15          4 -2 0
+// 8 12 16
+// 9 13 17
+//
+// alpha * A        * op(B)  = alpha * C       = C
+// 0       6 10 14    5  4     0       14  4     0  0
+// 1       7 11 15   -3 -2     1       17  6    17  6
+// 2       8 12 16    1  0     2       20  8    40 16
+// 3       9 13 17             3       23 10    69 30
+//
+// alpha * A        * op(B)  = alpha * C       = C
+// 1       6 10 14    5  4     1       14  4    14  4
+//         7 11 15   -3 -2             17  6    17  6
+//         8 12 16    1  0             20  8    20  8
+//         9 13 17                     23 10    23 10
+// clang-format on
+
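+// main() runs the six igemmlt variants: B in COL4_4R2_8C (Turing) for
+// test1-test3 and COL32_2R_4R4 (Ampere) for test4-test6, with int32 output,
+// int8 output, and int8 output scaled by a per-row alpha vector, in that order.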
+int main() {
+  bool pass = true;
+  pass = test1() && pass;
+  pass = test2() && pass;
+  pass = test3() && pass;
+  pass = test4() && pass;
+  pass = test5() && pass;
+  pass = test6() && pass;
+  return pass ? 0 : 1;
+}
diff --git a/features/feature_case/cublasLt/transform.cu b/features/feature_case/cublasLt/transform.cu
new file mode 100644
index 000000000..3a1205add
--- /dev/null
+++ b/features/feature_case/cublasLt/transform.cu
@@ -0,0 +1,602 @@
+// ===------------ transform.cu -------------------------- *- CUDA -* ----=== //
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+
+#include "cublasLt.h"
+#include <cstring>
+
+void transform(cublasLtHandle_t ltHandle, void *in, int ld_in,
+               cublasLtOrder_t order_in, void *out, int ld_out,
+               cublasLtOrder_t order_out, int dim1, int dim2) {
+  cublasLtMatrixLayout_t in_desc = NULL, out_desc = NULL;
+  cublasLtMatrixTransformDesc_t transform_desc = NULL;
+
+  cublasLtMatrixLayoutCreate(&in_desc, CUDA_R_8I, dim1, dim2, ld_in);
+  cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ld_out);
+
+  cublasLtMatrixLayoutSetAttribute(in_desc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &order_in, sizeof(order_in));
+  cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+                                   &order_out, sizeof(order_out));
+
+  cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F);
+
+  float alpha = 1.0f, beta = 0.0f;
+  cublasLtMatrixTransform(ltHandle, transform_desc, &alpha, in, in_desc, &beta,
+                          NULL, NULL, out, out_desc, 0);
+
+  cublasLtMatrixLayoutDestroy(in_desc);
+  cublasLtMatrixLayoutDestroy(out_desc);
+  cublasLtMatrixTransformDescDestroy(transform_desc);
+}
+
+bool test_ROW() {
+  const constexpr int m = 2;
+  const constexpr int n = 33;
+  const constexpr int in_ld = 4;
+  void *in_dev;
+  cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+  int8_t in_host[n * in_ld];
+  int8_t value = 0;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % 4 < 2) {
+      in_host[i] = value;
+      value++;
+    } else
+      in_host[i] = 99;
+  }
+  int8_t ref_2nd[n * in_ld];
+  std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+  cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+             cudaMemcpyHostToDevice);
+
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+  void *out_dev;
+  const constexpr int out_ld = 36;
+  cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+  cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+  transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+            CUBLASLT_ORDER_ROW, m, n);
+
+  int8_t out_host[out_ld * m];
+  cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_1st = true;
+  int8_t ref_1st[out_ld * m] =
+      {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 0, 0, 0,
+       1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 0, 0, 0};
+  for (int i = 0; i < out_ld * m; i++) {
+    if (i % out_ld < n) {
+      if (out_host[i] != ref_1st[i]) {
+        pass_1st = false;
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < out_ld * m; i++) {
+    printf("%d, ", out_host[i]);
+  }
+  printf("\n");
+  if (pass_1st) {
+    printf("ROW 1st pass\n");
+  } else {
+    printf("ROW 1st fail\n");
+  }
+
+  cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+  std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+  transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_ROW, in_dev, in_ld,
+            CUBLASLT_ORDER_COL, m, n);
+  cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_2nd = true;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % in_ld < m) {
+      if (in_host[i] != ref_2nd[i]) {
+        pass_2nd = false;
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < n * in_ld; i++) {
+    printf("%d, ", in_host[i]);
+  }
+  printf("\n");
+  if (pass_2nd) {
+    printf("ROW 2nd pass\n");
+  } else {
+    printf("ROW 2nd fail\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+
+  return pass_1st && pass_2nd;
+}
+
+bool test_COL32() {
+  const constexpr int m = 2;
+  const constexpr int n = 33;
+  const constexpr int in_ld = 4;
+  void *in_dev;
+  cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+  int8_t in_host[n * in_ld];
+  int8_t value = 0;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % 4 < 2) {
+      in_host[i] = value;
+      value++;
+    } else
+      in_host[i] = 99;
+  }
+  int8_t ref_2nd[n * in_ld];
+  std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+  cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+             cudaMemcpyHostToDevice);
+
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+  void *out_dev;
+  const constexpr int out_ld = 64;
+  cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+  cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+  transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+            CUBLASLT_ORDER_COL32, m, n);
+
+  int8_t out_host[out_ld * m];
+  cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_1st = true;
+  int8_t ref_1st[out_ld * m] =
+      {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62,
+       1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63,
+       64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  for (int i = 0; i < out_ld * m; i++) {
+    if (i % out_ld < n) {
+      if (out_host[i] != ref_1st[i]) {
+        pass_1st = false;
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < out_ld * m; i++) {
+    printf("%d, ", out_host[i]);
+  }
+  printf("\n");
+  if (pass_1st) {
+    printf("COL32 1st pass\n");
+  } else {
+    printf("COL32 1st fail\n");
+  }
+
+  cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+  std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+  transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_COL32, in_dev, in_ld,
+            CUBLASLT_ORDER_COL, m, n);
+  cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_2nd = true;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % in_ld < m) {
+      if (in_host[i] != ref_2nd[i]) {
+        pass_2nd = false;
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < n * in_ld; i++) {
+    printf("%d, ", in_host[i]);
+  }
+  printf("\n");
+  if (pass_2nd) {
+    printf("COL32 2nd pass\n");
+  } else {
+    printf("COL32 2nd fail\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+
+  return pass_1st && pass_2nd;
+}
+
+bool test_COL4_4R2_8C() {
+  const constexpr int m = 2;
+  const constexpr int n = 33;
+  const constexpr int in_ld = 4;
+  void *in_dev;
+  cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+  int8_t in_host[n * in_ld];
+  int8_t value = 0;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % 4 < 2) {
+      in_host[i] = value;
+      value++;
+    } else
+      in_host[i] = 99;
+  }
+  int8_t ref_2nd[n * in_ld];
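+  // ref_2nd keeps an untouched copy of the input so the round trip
+  // (COL -> COL4_4R2_8C -> COL) can be checked against it below.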
+  std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+  cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+             cudaMemcpyHostToDevice);
+
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+  void *out_dev;
+  const constexpr int out_ld = (32 * 8) * 2;
+  cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+  cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+  transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+            CUBLASLT_ORDER_COL4_4R2_8C, m, n);
+
+  int8_t out_host[out_ld * m];
+  cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_1st = true;
+  int8_t ref_1st[out_ld * m] =
+      {0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       8, 10, 12, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       16, 18, 20, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       32, 34, 36, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       40, 42, 44, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       48, 50, 52, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       56, 58, 60, 62, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       1, 3, 5, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       9, 11, 13, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       17, 19, 21, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       25, 27, 29, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       33, 35, 37, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       41, 43, 45, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       49, 51, 53, 55, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       57, 59, 61, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  for (int i = 0; i < out_ld * m; i++) {
+    if (i % out_ld < n) {
+      if (out_host[i] != ref_1st[i]) {
+        pass_1st = false;
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < out_ld * m; i++) {
+    printf("%d, ", out_host[i]);
+  }
+  printf("\n");
+  if (pass_1st) {
+    printf("COL4_4R2_8C 1st pass\n");
+  } else {
+    printf("COL4_4R2_8C 1st fail\n");
+  }
+
+  cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+  std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+  transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_COL4_4R2_8C, in_dev,
+            in_ld, CUBLASLT_ORDER_COL, m, n);
+  cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_2nd = true;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % in_ld < m) {
+      if (in_host[i] != ref_2nd[i]) {
+        pass_2nd = false;
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < n * in_ld; i++) {
+    printf("%d, ", in_host[i]);
+  }
+  printf("\n");
+  if (pass_2nd) {
+    printf("COL4_4R2_8C 2nd pass\n");
+  } else {
+    printf("COL4_4R2_8C 2nd fail\n");
+  }
+
+  cublasLtDestroy(ltHandle);
+
+  return pass_1st && pass_2nd;
+}
+
+bool test_COL32_2R_4R4() {
+  const constexpr int m = 2;
+  const constexpr int n = 33;
+  const constexpr int in_ld = 4;
+  void *in_dev;
+  cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+  int8_t in_host[n * in_ld];
+  int8_t value = 0;
+  for (int i = 0; i < n * in_ld; i++) {
+    if (i % 4 < 2) {
+      in_host[i] = value;
+      value++;
+    } else
+      in_host[i] = 99;
+  }
+  int8_t ref_2nd[n * in_ld];
+  std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+  cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+             cudaMemcpyHostToDevice);
+
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+  void *out_dev;
+  const constexpr int out_ld = (32 * 32) * 2;
+  cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+  cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+  transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+            CUBLASLT_ORDER_COL32_2R_4R4, m, n);
+
+  int8_t out_host[out_ld * m];
+  cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+             cudaMemcpyDeviceToHost);
+
+  bool pass_1st = true;
+  int8_t ref_1st[out_ld * m] =
+      {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62,
+       1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+       0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + for (int i = 0; i < out_ld * m; i++) { + if (i % out_ld < n) { + if (out_host[i] != ref_1st[i]) { + pass_1st = false; + break; + } + } + } + + for (int i = 0; i < out_ld * m; i++) { + printf("%d, ", out_host[i]); + } + printf("\n"); + if (pass_1st) { + printf("COL32_2R_4R4 1st pass\n"); + } else { + printf("COL32_2R_4R4 1st fail\n"); + } + + cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t)); + std::memset(in_host, 0, n * in_ld * sizeof(int8_t)); + transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_COL32_2R_4R4, in_dev, + in_ld, CUBLASLT_ORDER_COL, m, n); + cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t), + cudaMemcpyDeviceToHost); + + bool pass_2nd = true; + for (int i = 0; i < n * in_ld; i++) { + if (i % in_ld < m) { + if (in_host[i] != ref_2nd[i]) { + pass_2nd = false; + break; + } + } + } + + for (int i = 0; i < n * in_ld; i++) { + printf("%d, ", in_host[i]); + } + printf("\n"); + if (pass_2nd) { + printf("COL32_2R_4R4 2nd pass\n"); + } else { + printf("COL32_2R_4R4 2nd fail\n"); + } + + cublasLtDestroy(ltHandle); + + return pass_1st && pass_2nd; +} + +// Input col_major matrix: +// 2 rows * 33 columns, ld is 4 +int main() { + bool pass = true; + pass = test_ROW() && pass; + pass = test_COL32() && pass; + pass = test_COL4_4R2_8C() && pass; + pass = test_COL32_2R_4R4() && pass; + return pass ? 
0 : 1; +} diff --git a/features/features.xml b/features/features.xml index 8ad6e3964..ff7838f2e 100644 --- a/features/features.xml +++ b/features/features.xml @@ -342,5 +342,7 @@ + + diff --git a/features/test_feature.py b/features/test_feature.py index 2f6bab329..bbd6e08c9 100644 --- a/features/test_feature.py +++ b/features/test_feature.py @@ -60,7 +60,7 @@ 'thrust_swap_ranges', 'thrust_uninitialized_fill_n', 'thrust_equal', 'system_atomic', 'thrust_detail_types', 'operator_eq', 'operator_neq', 'operator_lege', 'thrust_system', 'thrust_reverse_copy', 'thrust_device_new_delete', 'thrust_temporary_buffer', 'thrust_malloc_free', 'codepin', 'thrust_unique_count', - 'thrust_advance_trans_op_itr', 'cuda_stream_query'] + 'thrust_advance_trans_op_itr', 'cuda_stream_query', "matmul", "transform"] occupancy_calculation_exper = ['occupancy_calculation'] @@ -166,7 +166,7 @@ def build_test(): 'cudnn-binary', 'cudnn-bnp1', 'cudnn-bnp2', 'cudnn-bnp3', 'cudnn-normp1', 'cudnn-normp2', 'cudnn-normp3', 'cudnn-convp1', 'cudnn-convp2', 'cudnn-convp3', 'cudnn-convp4', 'cudnn-convp5', 'cudnn-convp6', 'cudnn-rnn', 'cudnn-GetErrorString', 'cudnn-convp7', - 'cudnn-types', 'cudnn-version', 'cudnn-dropout' + 'cudnn-types', 'cudnn-version', 'cudnn-dropout', 'matmul' ] no_fast_math_tests = ['math-emu-half-after11', 'math-emu-half2-after11', 'math-ext-half-after11', 'math-ext-half2-after11', diff --git a/help_function/help_function.xml b/help_function/help_function.xml index 3bdbd26ed..fc047e53e 100644 --- a/help_function/help_function.xml +++ b/help_function/help_function.xml @@ -215,5 +215,6 @@ + diff --git a/help_function/src/blas_gemm_utils_interface.cpp b/help_function/src/blas_gemm_utils_interface.cpp new file mode 100644 index 000000000..551021b31 --- /dev/null +++ b/help_function/src/blas_gemm_utils_interface.cpp @@ -0,0 +1,152 @@ +// ===------ blas_gemm_utils_interface.cpp ----------------- *- C++ -* ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+
+#include <sycl/sycl.hpp>
+#include <dpct/dpct.hpp>
+#include <dpct/blas_gemm_utils.hpp>
+#include <cstdint>
+
+void foo1 () {
+  dpct::blas_gemm::experimental::descriptor_ptr ltHandle;
+  ltHandle = new dpct::blas_gemm::experimental::descriptor();
+  delete (ltHandle);
+
+  dpct::blas_gemm::experimental::matrix_layout_ptr matLayout;
+  dpct::library_data_t type;
+  uint64_t rows;
+  uint64_t cols;
+  int64_t ld;
+  matLayout =
+      new dpct::blas_gemm::experimental::matrix_layout_t(type, rows, cols, ld);
+
+  dpct::blas_gemm::experimental::matrix_layout_t::attribute attr1;
+  void *buf1;
+  size_t sizeInBytes1;
+  size_t *sizeWritten1;
+  matLayout->get_attribute(attr1, buf1);
+  matLayout->set_attribute(attr1, buf1);
+  delete (matLayout);
+
+  dpct::blas_gemm::experimental::matmul_desc_ptr matmulDesc;
+  dpct::compute_type computeType;
+  dpct::library_data_t scaleType;
+  matmulDesc =
+      new dpct::blas_gemm::experimental::matmul_desc_t(computeType, scaleType);
+
+  dpct::blas_gemm::experimental::matmul_desc_t::attribute attr2;
+  void *buf2;
+  size_t sizeInBytes2;
+  size_t *sizeWritten2;
+  matmulDesc->get_attribute(attr2, buf2);
+  matmulDesc->set_attribute(attr2, buf2);
+  delete (matmulDesc);
+
+  int matmulPreference;
+  void *buf3;
+  size_t sizeInBytes3;
+  size_t *sizeWritten3;
+
+  dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Ddesc;
+
+  int requestedAlgoCount = 1;
+  int heuristicResultsArray;
+  int returnAlgoCount;
+  returnAlgoCount = 1;
+}
+
+void foo2() {
+  dpct::blas_gemm::experimental::descriptor_ptr lightHandle;
+  dpct::blas_gemm::experimental::matmul_desc_ptr computeDesc;
+  const void *alpha;
+  const void *A;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+  const void *B;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+  const void *beta;
+  const void *C;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+  void *D;
+  dpct::blas_gemm::experimental::matrix_layout_ptr Ddesc;
+  const int *algo;
+  void *workspace;
+  size_t workspaceSizeInBytes;
+  dpct::queue_ptr stream;
+  dpct::blas_gemm::experimental::matmul(lightHandle, computeDesc, alpha, A,
+                                        Adesc, B, Bdesc, beta, C, Cdesc, D,
+                                        Ddesc, stream);
+}
+
+void foo3() {
+  dpct::blas_gemm::experimental::order_t a;
+  a = dpct::blas_gemm::experimental::order_t::col;
+  a = dpct::blas_gemm::experimental::order_t::row;
+  a = dpct::blas_gemm::experimental::order_t::col32;
+  a = dpct::blas_gemm::experimental::order_t::col4_4r2_8c;
+  a = dpct::blas_gemm::experimental::order_t::col32_2r_4r4;
+
+  dpct::blas_gemm::experimental::pointer_mode_t b;
+  b = dpct::blas_gemm::experimental::pointer_mode_t::host;
+  b = dpct::blas_gemm::experimental::pointer_mode_t::device;
+  b = dpct::blas_gemm::experimental::pointer_mode_t::device_vector;
+  b = dpct::blas_gemm::experimental::pointer_mode_t::
+      alpha_device_vector_beta_zero;
+  b = dpct::blas_gemm::experimental::pointer_mode_t::
+      alpha_device_vector_beta_host;
+
+  dpct::blas_gemm::experimental::matrix_layout_t::attribute c;
+  c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::type;
+  c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::order;
+  c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::rows;
+  c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::cols;
+  c = 
dpct::blas_gemm::experimental::matrix_layout_t::attribute::ld; + + dpct::blas_gemm::experimental::matmul_desc_t::attribute d; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::compute_type; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::scale_type; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::pointer_mode; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_a; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_b; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_c; + d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue; +} + +void foo4() { + dpct::blas_gemm::experimental::transform_desc_ptr transformDesc; + dpct::library_data_t scaleType; + transformDesc = + new dpct::blas_gemm::experimental::transform_desc_t(scaleType); + oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans; + size_t sizeWritten; + transformDesc->set_attribute( + dpct::blas_gemm::experimental::transform_desc_t::attribute::trans_a, + &opT); + transformDesc->get_attribute( + dpct::blas_gemm::experimental::transform_desc_t::attribute::trans_a, + &opT); + delete (transformDesc); + + dpct::blas_gemm::experimental::descriptor_ptr lightHandle; + const void *alpha; + const void *A; + dpct::blas_gemm::experimental::matrix_layout_ptr Adesc; + const void *beta; + const void *B; + dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc; + void *C; + dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc; + dpct::queue_ptr stream; + dpct::blas_gemm::experimental::matrix_transform( + transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream); +} + +int main() { + return 0; +} diff --git a/help_function/test_help.py b/help_function/test_help.py index b6c585270..52588464a 100644 --- a/help_function/test_help.py +++ b/help_function/test_help.py @@ -45,7 +45,7 @@ def build_test(): "dnnl_utils_batch_normalization_2", "dnnl_utils_batch_normalization_3", "dnnl_utils_convolution_1", "dnnl_utils_convolution_2", "dnnl_utils_convolution_3", "dnnl_utils_convolution_4", "dnnl_utils_convolution_5", "dnnl_utils_normalization_1", "dnnl_utils_normalization_2", "dnnl_utils_normalization_3", "dnnl_utils_rnn", - "dnnl_utils_version", "dnnl_utils_dropout"] + "dnnl_utils_version", "dnnl_utils_dropout", "blas_gemm_utils_interface"] fft_cases = ["fft_utils_engine_buffer", "fft_utils_engine_usm", "fft_workspace_interface", "fft_set_workspace"] lapack_cases = ["lapack_utils_buffer", "lapack_utils_usm"] rng_cases = ["rng_generator", "rng_generator_vec_size_1", "rng_host"]
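--
For reference: a minimal sketch of how the cublasLtMatmul flow tested in
matmul.cu maps onto the migrated interface, using only the
dpct::blas_gemm::experimental names exercised by blas_gemm_utils_interface.cpp
above. As in foo2() there, the declarations stand in for real values; this is
an illustrative sketch, not part of the patch:

// Sketch only -- call shape taken from foo2() in blas_gemm_utils_interface.cpp.
#include <dpct/dpct.hpp>
#include <dpct/blas_gemm_utils.hpp>

void matmul_migration_sketch() {
  namespace exp = dpct::blas_gemm::experimental;
  exp::descriptor_ptr handle;                             // was cublasLtHandle_t
  exp::matmul_desc_ptr compute_desc;                      // was cublasLtMatmulDesc_t
  exp::matrix_layout_ptr a_desc, b_desc, c_desc, d_desc;  // was cublasLtMatrixLayout_t
  const void *alpha, *A, *B, *beta, *C;
  void *D;
  dpct::queue_ptr stream;                                 // was cudaStream_t
  // cublasLtMatmul(handle, desc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc,
  //                D, Ddesc, algo, workspace, workspaceSize, stream) becomes:
  exp::matmul(handle, compute_desc, alpha, A, a_desc, B, b_desc, beta, C,
              c_desc, D, d_desc, stream);
}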