From 0752a372e36fe4f16ffd904aabb1f8a612e800ab Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Thu, 15 Feb 2024 19:36:06 +0900
Subject: [PATCH 1/3] support dynamic cuda wrapper

---
 CMakeLists.txt                  |  18 ++-
 bitsandbytes/cextension.py      |   5 +
 bitsandbytes/cuda_setup/main.py |  11 +-
 csrc/ops.cu                     |   4 +-
 csrc/ops.cuh                    | 253 ++++++++++++++++++++++++++++++++
 csrc/pythonInterface.cpp        | 216 +++++++++++++++++++++++++++
 6 files changed, 500 insertions(+), 7 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1b9f1854b..754d2fc4b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,7 @@ list(APPEND SRC_FILES ${CPP_FILES})
 set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
 set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
 option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
+option(USE_CUDA_WRAPPER "Dynamic CUDA Linking" OFF)
 
 if(APPLE)
     set(CMAKE_OSX_DEPLOYMENT_TARGET 13.1)
@@ -43,6 +44,7 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
     set(BUILD_CUDA ON)
     set(BUILD_MPS OFF)
     message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
+    message(STATUS "USE_CUDA_WRAPPER := ${USE_CUDA_WRAPPER}")
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -111,7 +113,13 @@ if(BUILD_CUDA)
 
     list(APPEND SRC_FILES ${CUDA_FILES})
 
-    string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
+    if(USE_CUDA_WRAPPER)
+        string(APPEND BNB_OUTPUT_NAME "_cuda")
+        add_compile_definitions(USE_CUDA_WRAPPER)
+    else()
+        string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
+    endif()
+
     if(NO_CUBLASLT)
         string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
     endif()
@@ -160,11 +168,15 @@ target_include_directories(bitsandbytes PUBLIC csrc include)
 
 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
+    if(NOT USE_CUDA_WRAPPER)
+        
target_link_libraries(bitsandbytes PUBLIC CUDA::cublas CUDA::cusparse) + endif() if(NO_CUBLASLT) target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) else() - target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) + if(NOT USE_CUDA_WRAPPER) + target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) + endif() endif() set_target_properties(bitsandbytes diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 858365f02..835a35d9c 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -26,6 +26,11 @@ lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p + try: + lib.initCudaLibs() + except Exception: + # ignore + pass COMPILED_WITH_CUDA = True except AttributeError as ex: warn("The installed version of bitsandbytes was compiled without GPU support. " diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 14c7abbd8..74ddf874d 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -30,7 +30,7 @@ DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so") if platform.system() == "Windows": # Windows - CUDA_RUNTIME_LIBS = ["nvcuda.dll"] + CUDA_RUNTIME_LIBS = ["cudart64_110.dll", "cudart64_12.dll"] else: # Linux or other # these are the most common libs names # libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead @@ -383,7 +383,14 @@ def evaluate_cuda_setup(): # we use ls -l instead of nvcc to determine the cuda version # since most installations will have the libcudart.so installed, but not the compiler - binary_name = f"libbitsandbytes_cuda{cuda_version_string}" + binary = f"libbitsandbytes_cuda{DYNAMIC_LIBRARY_SUFFIX}" + package_dir = Path(__file__).parent.parent + binary_path = package_dir / binary + # check binary_path without cuda_version_string + if binary_path.exists(): + binary_name = "libbitsandbytes_cuda" + 
else: + binary_name = f"libbitsandbytes_cuda{cuda_version_string}" if not has_cublaslt: # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt binary_name += "_nocublaslt" diff --git a/csrc/ops.cu b/csrc/ops.cu index 796211fed..6eb524d74 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -255,7 +255,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in m, n, k, alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, C, CUDA_R_32I, ldc, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); if (status != CUBLAS_STATUS_SUCCESS) { @@ -285,7 +285,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i m, n, k, alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, C, CUDA_R_32I, ldc, (long long int)strideC, batchCount, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT); + CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) { diff --git a/csrc/ops.cuh b/csrc/ops.cuh index da9df6af0..aaa88cfaf 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -22,7 +22,260 @@ #include #include +#if USE_CUDA_WRAPPER +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#else +#include +#endif + +typedef const char* (*cudaGetErrorString_t)(cudaError_t err); +typedef const char* (*cusparseGetErrorString_t)(cusparseStatus_t status); +typedef cusparseStatus_t (*cusparseCreate_t)(cusparseHandle_t* handle); +typedef cublasStatus_t (*cublasCreate_v2_t)(cublasHandle_t* handle); +typedef cublasStatus_t (*cublasLtCreate_t)(cublasLtHandle_t* lightHandle); +//typedef cudaError_t (*cudaMallocManaged_t)(void **devPtr, size_t size, unsigned int flags); +//typedef cudaError_t (*cudaMemPrefetchAsync_t)(const void *devPtr, size_t count, int dstDevice, cudaStream_t stream); +//typedef cudaError_t (*cudaDeviceGetAttribute_t)(int *value, enum cudaDeviceAttr attr, int device); + +typedef cusparseStatus_t (*cusparseCreateCoo_t)(cusparseSpMatDescr_t* 
spMatDescr, + int64_t rows, + int64_t cols, + int64_t nnz, + void* cooRowInd, + void* cooColInd, + void* cooValues, + cusparseIndexType_t cooIdxType, + cusparseIndexBase_t idxBase, + cudaDataType valueType); + +typedef cusparseStatus_t (*cusparseCreateDnMat_t)(cusparseDnMatDescr_t* dnMatDescr, + int64_t rows, + int64_t cols, + int64_t ld, + void* values, + cudaDataType valueType, + cusparseOrder_t order); + +typedef cusparseStatus_t (*cusparseSpMM_bufferSize_t)(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const void* alpha, + cusparseSpMatDescr_t matA, + cusparseDnMatDescr_t matB, + const void* beta, + cusparseDnMatDescr_t matC, + cudaDataType computeType, + cusparseSpMMAlg_t alg, + size_t* bufferSize); + +typedef cusparseStatus_t (*cusparseSpMM_t)(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const void* alpha, + cusparseSpMatDescr_t matA, + cusparseDnMatDescr_t matB, + const void* beta, + cusparseDnMatDescr_t matC, + cudaDataType computeType, + cusparseSpMMAlg_t alg, + void* externalBuffer); + +typedef cusparseStatus_t (*cusparseDestroySpMat_t)(cusparseSpMatDescr_t spMatDescr); +typedef cusparseStatus_t (*cusparseDestroyDnMat_t)(cusparseDnMatDescr_t dnMatDescr); + +typedef cudaError_t (*cudaMemset_t)(void *devPtr, int value, size_t count); +typedef cudaError_t (*cudaMalloc_t)(void **devPtr, size_t size); +typedef cudaError_t (*cudaFree_t)(void *devPtr); +typedef cudaError_t (*cudaPeekAtLastError_t)(void); + +typedef cublasStatus_t (*cublasGemmEx_t)(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, /* host or device pointer */ + const void* A, + cudaDataType Atype, + int lda, + const void* B, + cudaDataType Btype, + int ldb, + const void* beta, /* host or device pointer */ + void* C, + cudaDataType Ctype, + int ldc, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +typedef cublasStatus_t 
(*cublasGemmStridedBatchedEx_t)(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, /* host or device pointer */ + const void* A, + cudaDataType Atype, + int lda, + long long int strideA, /* purposely signed */ + const void* B, + cudaDataType Btype, + int ldb, + long long int strideB, + const void* beta, /* host or device pointer */ + void* C, + cudaDataType Ctype, + int ldc, + long long int strideC, + int batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +typedef cublasStatus_t (*cublasLtMatrixLayoutCreate_t)( // + cublasLtMatrixLayout_t* matLayout, + cudaDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld); + +typedef cublasStatus_t (*cublasLtMatrixLayoutSetAttribute_t)( // + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + const void* buf, + size_t sizeInBytes); + +typedef cublasStatus_t (*cublasLtMatrixTransform_t)(cublasLtHandle_t lightHandle, + cublasLtMatrixTransformDesc_t transformDesc, + const void* alpha, /* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* beta, /* host or device pointer */ + const void* B, + cublasLtMatrixLayout_t Bdesc, + void* C, + cublasLtMatrixLayout_t Cdesc, + cudaStream_t stream); + +typedef cublasStatus_t (*cublasLtMatrixTransformDescCreate_t)(cublasLtMatrixTransformDesc_t* transformDesc, + cudaDataType scaleType); + +typedef cublasStatus_t (*cublasLtMatrixTransformDescSetAttribute_t)( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, + const void* buf, + size_t sizeInBytes); + +typedef cublasStatus_t (*cublasLtMatrixLayoutDestroy_t)(cublasLtMatrixLayout_t matLayout); + +typedef cublasStatus_t (*cublasLtMatrixTransformDescDestroy_t)(cublasLtMatrixTransformDesc_t transformDesc); + +typedef cublasStatus_t (*cublasLtMatmul_t)(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t computeDesc, + const void* alpha, 
/* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* B, + cublasLtMatrixLayout_t Bdesc, + const void* beta, /* host or device pointer */ + const void* C, + cublasLtMatrixLayout_t Cdesc, + void* D, + cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t* algo, + void* workspace, + size_t workspaceSizeInBytes, + cudaStream_t stream); + +typedef cublasStatus_t (*cublasLtMatmulDescCreate_t)(cublasLtMatmulDesc_t* matmulDesc, + cublasComputeType_t computeType, + cudaDataType_t scaleType); + +typedef cublasStatus_t (*cublasLtMatmulDescDestroy_t)(cublasLtMatmulDesc_t matmulDesc); + +typedef cublasStatus_t (*cublasLtMatmulDescSetAttribute_t)( // + cublasLtMatmulDesc_t matmulDesc, + cublasLtMatmulDescAttributes_t attr, + const void* buf, + size_t sizeInBytes); + + +/* externs */ +extern cudaGetErrorString_t _cudaGetErrorString; +extern cusparseGetErrorString_t _cusparseGetErrorString; +//extern cudaMallocManaged_t _cudaMallocManaged; +//extern cudaMemPrefetchAsync_t _cudaMemPrefetchAsync; +//extern cudaDeviceGetAttribute_t _cudaDeviceGetAttribute; + +extern cusparseCreate_t _cusparseCreate; +extern cublasCreate_v2_t _cublasCreate_v2; +extern cublasLtCreate_t _cublasLtCreate; + +extern cusparseDestroySpMat_t _cusparseDestroySpMat; +extern cusparseDestroyDnMat_t _cusparseDestroyDnMat; +extern cusparseCreateCoo_t _cusparseCreateCoo; +extern cusparseSpMM_t _cusparseSpMM; +extern cusparseSpMM_bufferSize_t _cusparseSpMM_bufferSize; +extern cusparseCreateDnMat_t _cusparseCreateDnMat; + +extern cudaMemset_t _cudaMemset; +extern cudaMalloc_t _cudaMalloc; +extern cudaFree_t _cudaFree; +extern cudaPeekAtLastError_t _cudaPeekAtLastError; + +extern cublasGemmEx_t _cublasGemmEx; +extern cublasGemmStridedBatchedEx_t _cublasGemmStridedBatchedEx; + +extern cublasLtMatrixLayoutCreate_t _cublasLtMatrixLayoutCreate; +extern cublasLtMatrixLayoutSetAttribute_t _cublasLtMatrixLayoutSetAttribute; +extern cublasLtMatrixTransform_t _cublasLtMatrixTransform; 
+extern cublasLtMatrixTransformDescCreate_t _cublasLtMatrixTransformDescCreate; +extern cublasLtMatrixTransformDescSetAttribute_t _cublasLtMatrixTransformDescSetAttribute; +extern cublasLtMatrixLayoutDestroy_t _cublasLtMatrixLayoutDestroy; +extern cublasLtMatrixTransformDescDestroy_t _cublasLtMatrixTransformDescDestroy; +extern cublasLtMatmul_t _cublasLtMatmul; +extern cublasLtMatmulDescCreate_t _cublasLtMatmulDescCreate; +extern cublasLtMatmulDescDestroy_t _cublasLtMatmulDescDestroy; +extern cublasLtMatmulDescSetAttribute_t _cublasLtMatmulDescSetAttribute; + + +#define cudaGetErrorString _cudaGetErrorString +#define cusparseGetErrorString _cusparseGetErrorString +#define cusparseCreate _cusparseCreate +#define cublasCreate_v2 _cublasCreate_v2 +#define cublasLtCreate _cublasLtCreate + +#define cudaMemset _cudaMemset +#define cudaMalloc _cudaMalloc +#define cudaFree _cudaFree +#define cudaPeekAtLastError _cudaPeekAtLastError + +#define cusparseCreateCoo _cusparseCreateCoo +#define cusparseDestroySpMat _cusparseDestroySpMat +#define cusparseDestroyDnMat _cusparseDestroyDnMat +#define cusparseSpMM _cusparseSpMM +#define cusparseSpMM_bufferSize _cusparseSpMM_bufferSize +#define cusparseCreateDnMat _cusparseCreateDnMat + +#define cublasGemmEx _cublasGemmEx +#define cublasGemmStridedBatchedEx _cublasGemmStridedBatchedEx +#define cublasLtMatrixLayoutCreate _cublasLtMatrixLayoutCreate +#define cublasLtMatrixLayoutSetAttribute _cublasLtMatrixLayoutSetAttribute +#define cublasLtMatrixTransform _cublasLtMatrixTransform +#define cublasLtMatrixTransformDescCreate _cublasLtMatrixTransformDescCreate + +#define cublasLtMatrixTransformDescSetAttribute _cublasLtMatrixTransformDescSetAttribute +#define cublasLtMatrixLayoutDestroy _cublasLtMatrixLayoutDestroy +#define cublasLtMatrixTransformDescDestroy _cublasLtMatrixTransformDescDestroy +#define cublasLtMatmul _cublasLtMatmul +#define cublasLtMatmulDescCreate _cublasLtMatmulDescCreate +#define cublasLtMatmulDescDestroy 
_cublasLtMatmulDescDestroy +#define cublasLtMatmulDescSetAttribute _cublasLtMatmulDescSetAttribute + +#endif /* USE_CUDA_WRAPPER */ #define CUDA_CHECK_RETURN(value) { \ cudaError_t _m_cudaStat = value; \ diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index ea2283504..2ae419ab4 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -19,6 +19,222 @@ //=================================================================================== #if BUILD_CUDA +#if USE_CUDA_WRAPPER + +#ifndef _WIN32 +#define LDSYMCUDA(func) do { \ + _##func = (func##_t) dlsym(libCudaHandle, #func); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from libcuda.so"); \ + exit(1); \ + } \ +} while(0) + +#define LDSYMCUSPARSE(func) do { \ + _##func = (func##_t) dlsym(libCusparseHandle, #func); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from libcudaparse.so"); \ + exit(1); \ + } \ +} while(0) + +#define LDSYMCUBLAS(func) do { \ + _##func = (func##_t) dlsym(libCublasHandle, #func); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from libcublas.so"); \ + exit(1); \ + } \ +} while(0) + +#define LDSYMCUBLASLT(func) do { \ + _##func = (func##_t) dlsym(libCublasLtHandle, #func); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from libcublaslt.so"); \ + exit(1); \ + } \ +} while(0) + +#else +#define LDSYMCUDA(func) do { \ + _##func = (func##_t) GetProcAddress((HMODULE)libCudaHandle, #func); \ + /* Check for errors */ \ + long err = GetLastError(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from nvcuda.dll: %d\n", err); \ + exit(1); \ + } \ +} while(0) + +#define LDSYMCUSPARSE(func) do { \ + _##func = (func##_t) 
GetProcAddress((HMODULE)libCusparseHandle, #func); \ + /* Check for errors */ \ + long err = GetLastError(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from cusparse64_*.dll: %d\n", err); \ + exit(1); \ + } \ +} while(0) + +#define LDSYMCUBLAS(func) do { \ + _##func = (func##_t) GetProcAddress((HMODULE)libCublasHandle, #func); \ + /* Check for errors */ \ + long err = GetLastError(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from cublas64_*.dll: %d\n", err); \ + exit(1); \ + } \ +} while(0) + +#define LDSYMCUBLASLT(func) do { \ + _##func = (func##_t) GetProcAddress((HMODULE)libCublasLtHandle, #func); \ + /* Check for errors */ \ + long err = GetLastError(); \ + if (err) { \ + fprintf(stderr, "Failed to retrieve " #func " from cublasLt64_*.dll: %d\n", err);\ + exit(1); \ + } \ +} while(0) + +#endif /* _WIN32 */ + +cudaGetErrorString_t cudaGetErrorString; +cusparseGetErrorString_t cusparseGetErrorString; +//static cudaMallocManaged_t cudaMallocManaged; +//static cudaMemPrefetchAsync_t cudaMemPrefetchAsync; +//static cudaDeviceGetAttribute_t cudaDeviceGetAttribute; +cudaMemset_t cudaMemset; +cudaMalloc_t cudaMalloc; +cudaFree_t cudaFree; +cudaPeekAtLastError_t cudaPeekAtLastError; + +cusparseCreate_t cusparseCreate; +cublasCreate_v2_t cublasCreate_v2; +cublasLtCreate_t cublasLtCreate; + +cusparseDestroySpMat_t cusparseDestroySpMat; +cusparseDestroyDnMat_t cusparseDestroyDnMat; +cusparseCreateCoo_t cusparseCreateCoo; +cusparseSpMM_t cusparseSpMM; +cusparseSpMM_bufferSize_t cusparseSpMM_bufferSize; +cusparseCreateDnMat_t cusparseCreateDnMat; + +cublasGemmEx_t cublasGemmEx; +cublasGemmStridedBatchedEx_t cublasGemmStridedBatchedEx; + +cublasLtMatrixLayoutCreate_t cublasLtMatrixLayoutCreate; +cublasLtMatrixLayoutSetAttribute_t cublasLtMatrixLayoutSetAttribute; +cublasLtMatrixTransform_t cublasLtMatrixTransform; +cublasLtMatrixTransformDescCreate_t cublasLtMatrixTransformDescCreate; +cublasLtMatrixTransformDescSetAttribute_t 
cublasLtMatrixTransformDescSetAttribute;
+cublasLtMatrixLayoutDestroy_t cublasLtMatrixLayoutDestroy;
+cublasLtMatrixTransformDescDestroy_t cublasLtMatrixTransformDescDestroy;
+cublasLtMatmul_t cublasLtMatmul;
+cublasLtMatmulDescCreate_t cublasLtMatmulDescCreate;
+cublasLtMatmulDescDestroy_t cublasLtMatmulDescDestroy;
+cublasLtMatmulDescSetAttribute_t cublasLtMatmulDescSetAttribute;
+
+
+extern "C" int initCudaLibs() {
+    // cuda libs handles
+    void *libCudaHandle = NULL;
+    void *libCusparseHandle = NULL;
+    void *libCublasHandle = NULL;
+    void *libCublasLtHandle = NULL;
+
+    void *libs[4] = { NULL, NULL, NULL, NULL };
+
+#ifndef _WIN32
+    const char *libname[] = { "libcudart.so", "libcusparse.so", "libcublas.so", "libcublasLt.so" };
+    const char *libname11[] = { "libcudart.so.11", "libcusparse.so.11", "libcublas.so.11", "libcublasLt.so.11" };
+    const char *libname12[] = { "libcudart.so.12", "libcusparse.so.12", "libcublas.so.12", "libcublasLt.so.12" };
+#else
+    const char *libname[] = { "cudart64_110.dll", "cusparse64_11.dll", "cublas64_11.dll", "cublasLt64_11.dll" };
+    const char *libname12[] = { "cudart64_12.dll", "cusparse64_12.dll", "cublas64_12.dll", "cublasLt64_12.dll" };
+#endif
+
+    if (cudaGetErrorString != NULL) {
+        return 0;
+    }
+    for (int i = 0; i < 4; i++) {
+#ifndef _WIN32
+        void *handle = dlopen(libname[i], RTLD_LAZY);
+        if (!handle) {
+            handle = dlopen(libname11[i], RTLD_LAZY);
+            if (!handle) {
+                handle = dlopen(libname12[i], RTLD_LAZY);
+            }
+        }
+#else
+        HANDLE handle = (HANDLE)LoadLibraryA(libname[i]);
+        if (!handle) {
+            handle = (HANDLE)LoadLibraryA(libname12[i]);
+        }
+#endif
+        if (!handle) {
+            fprintf(stderr, "Failed to load CUDA library '%s'\n", libname[i]);
+            exit(1);
+        }
+#ifndef _WIN32
+        dlerror();
+#endif
+        libs[i] = handle;
+    }
+
+    libCudaHandle = libs[0];
+    libCusparseHandle = libs[1];
+    libCublasHandle = libs[2];
+    libCublasLtHandle = libs[3];
+
+    LDSYMCUDA(cudaGetErrorString);
+    LDSYMCUSPARSE(cusparseGetErrorString);
+    
LDSYMCUSPARSE(cusparseCreate); + LDSYMCUBLAS(cublasCreate_v2); + LDSYMCUBLASLT(cublasLtCreate); + //LDSYMCUDA(cudaMallocManaged); + //LDSYMCUDA(cudaMemPrefetchAsync); + //LDSYMCUDA(cudaDeviceGetAttribute); + LDSYMCUDA(cudaMemset); + LDSYMCUDA(cudaMalloc); + LDSYMCUDA(cudaFree); + LDSYMCUDA(cudaPeekAtLastError); + + LDSYMCUSPARSE(cusparseCreateCoo); + LDSYMCUSPARSE(cusparseDestroySpMat); + LDSYMCUSPARSE(cusparseDestroyDnMat); + LDSYMCUSPARSE(cusparseSpMM); + LDSYMCUSPARSE(cusparseSpMM_bufferSize); + LDSYMCUSPARSE(cusparseCreateDnMat); + + LDSYMCUBLAS(cublasGemmEx); + LDSYMCUBLAS(cublasGemmStridedBatchedEx); + +#ifndef NO_CUBLASLT + LDSYMCUBLASLT(cublasLtMatrixLayoutCreate); + LDSYMCUBLASLT(cublasLtMatrixLayoutSetAttribute); + LDSYMCUBLASLT(cublasLtMatrixTransform); + LDSYMCUBLASLT(cublasLtMatrixTransformDescCreate); + LDSYMCUBLASLT(cublasLtMatrixTransformDescSetAttribute); + LDSYMCUBLASLT(cublasLtMatrixLayoutDestroy); + LDSYMCUBLASLT(cublasLtMatrixTransformDescDestroy); + LDSYMCUBLASLT(cublasLtMatmul); + LDSYMCUBLASLT(cublasLtMatmulDescCreate); + LDSYMCUBLASLT(cublasLtMatmulDescDestroy); + LDSYMCUBLASLT(cublasLtMatmulDescSetAttribute); +#endif + + return 0; +} +#endif /* USE_CUDA_WRAPPER */ + + void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } From 162c99808080d146ea148be07825a4b91e1b37ba Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 15 Feb 2024 19:49:16 +0900 Subject: [PATCH 2/3] add -DUSE_CUDA_WRAPPER --- .github/workflows/python-package.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 07c3b5217..97c1dccda 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -125,10 +125,10 @@ jobs: docker run --platform linux/$build_arch -i -w /src 
-v $PWD:/src $image sh -c \ "apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ - && cmake -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} . \ + && cmake -DCOMPUTE_BACKEND=cuda -DUSE_CUDA_WRAPPER=ON -DNO_CUBLASLT=${NO_CUBLASLT} . \ && cmake --build ." else - cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S . + cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DUSE_CUDA_WRAPPER=ON -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S . cmake --build . --config Release fi done From aaa2dcc3b456666ef4facc6970160abbd1224652 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 15 Feb 2024 21:18:04 +0900 Subject: [PATCH 3/3] use ubuntu20.04 images * fix cmake version: default cmake version of ubuntu20.04 is too low --- .github/workflows/python-package.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 97c1dccda..ebe0c93d6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -120,11 +120,12 @@ jobs: [[ "${{ matrix.os }}" = windows-* ]] && python3 -m pip install ninja for NO_CUBLASLT in ON OFF; do if [ ${build_os:0:6} == ubuntu ]; then - image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu22.04 + image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu20.04 echo "Using image $image" docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \ "apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends python3 python3-pip \ + && pip install cmake==3.27.9 \ && cmake -DCOMPUTE_BACKEND=cuda -DUSE_CUDA_WRAPPER=ON -DNO_CUBLASLT=${NO_CUBLASLT} . \ && cmake --build ." else